mirror of
https://github.com/ehang-io/nps.git
synced 2025-09-25 17:39:20 +00:00
添加多种模式
This commit is contained in:
220
server.go
220
server.go
@@ -1,8 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -10,115 +8,69 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TRPServer struct {
|
||||
tcpPort int
|
||||
const (
|
||||
VERIFY_EER = "vkey"
|
||||
WORK_MAIN = "main"
|
||||
WORK_CHAN = "chan"
|
||||
RES_SIGN = "sign"
|
||||
RES_MSG = "msg0"
|
||||
)
|
||||
|
||||
type HttpModeServer struct {
|
||||
Tunnel
|
||||
httpPort int
|
||||
listener *net.TCPListener
|
||||
connList chan net.Conn
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func NewRPServer(tcpPort, httpPort int) *TRPServer {
|
||||
s := new(TRPServer)
|
||||
s.tcpPort = tcpPort
|
||||
func NewHttpModeServer(tcpPort, httpPort int) *HttpModeServer {
|
||||
s := new(HttpModeServer)
|
||||
s.tunnelPort = tcpPort
|
||||
s.httpPort = httpPort
|
||||
s.connList = make(chan net.Conn, 1000)
|
||||
s.signalList = make(chan *Conn, 1000)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *TRPServer) Start() error {
|
||||
var err error
|
||||
s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.tcpPort, ""})
|
||||
//开始
|
||||
func (s *HttpModeServer) Start() (error) {
|
||||
err := s.StartTunnel()
|
||||
if err != nil {
|
||||
log.Fatalln("开启客户端失败!", err)
|
||||
return err
|
||||
}
|
||||
go s.httpserver()
|
||||
return s.tcpserver()
|
||||
s.startHttpServer()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TRPServer) Close() error {
|
||||
if s.listener != nil {
|
||||
err := s.listener.Close()
|
||||
s.listener = nil
|
||||
return err
|
||||
}
|
||||
return errors.New("TCP实例未创建!")
|
||||
}
|
||||
|
||||
func (s *TRPServer) tcpserver() error {
|
||||
var err error
|
||||
for {
|
||||
conn, err := s.listener.AcceptTCP()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
go s.cliProcess(conn)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func badRequest(w http.ResponseWriter) {
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func (s *TRPServer) httpserver() {
|
||||
//开启http端口监听
|
||||
func (s *HttpModeServer) startHttpServer() {
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
retry:
|
||||
if len(s.connList) == 0 {
|
||||
badRequest(w)
|
||||
if len(s.signalList) == 0 {
|
||||
BadRequest(w)
|
||||
return
|
||||
}
|
||||
conn := <-s.connList
|
||||
log.Println(r.RequestURI)
|
||||
err := s.write(r, conn)
|
||||
conn := <-s.signalList
|
||||
if err := s.writeRequest(r, conn); err != nil {
|
||||
log.Println(err)
|
||||
conn.Close()
|
||||
goto retry
|
||||
return
|
||||
}
|
||||
err = s.writeResponse(w, conn)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
conn.Close()
|
||||
goto retry
|
||||
return
|
||||
}
|
||||
err = s.read(w, conn)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
conn.Close()
|
||||
goto retry
|
||||
return
|
||||
}
|
||||
s.connList <- conn
|
||||
conn = nil
|
||||
s.signalList <- conn
|
||||
})
|
||||
log.Fatalln(http.ListenAndServe(fmt.Sprintf(":%d", s.httpPort), nil))
|
||||
}
|
||||
|
||||
func (s *TRPServer) cliProcess(conn *net.TCPConn) error {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Duration(5) * time.Second))
|
||||
vval := make([]byte, 20)
|
||||
_, err := conn.Read(vval)
|
||||
if err != nil {
|
||||
log.Println("客户端读超时。客户端地址为::", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
if bytes.Compare(vval, getverifyval()[:]) != 0 {
|
||||
log.Println("当前客户端连接校验错误,关闭此客户端:", conn.RemoteAddr())
|
||||
conn.Write([]byte("vkey"))
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
log.Println("连接新的客户端:", conn.RemoteAddr())
|
||||
conn.SetKeepAlive(true)
|
||||
conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
|
||||
s.connList <- conn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TRPServer) write(r *http.Request, conn net.Conn) error {
|
||||
//req转为bytes发送给client端
|
||||
func (s *HttpModeServer) writeRequest(r *http.Request, conn *Conn) error {
|
||||
raw, err := EncodeRequest(r)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -133,41 +85,21 @@ func (s *TRPServer) write(r *http.Request, conn net.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TRPServer) read(w http.ResponseWriter, conn net.Conn) (error) {
|
||||
val := make([]byte, 4)
|
||||
_, err := conn.Read(val)
|
||||
//从client读取出Response
|
||||
func (s *HttpModeServer) writeResponse(w http.ResponseWriter, c *Conn) error {
|
||||
flags, err := c.ReadFlag()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
flags := string(val)
|
||||
switch flags {
|
||||
case "sign":
|
||||
_, err = conn.Read(val)
|
||||
case RES_SIGN:
|
||||
nlen, err := c.GetLen()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nlen := int(binary.LittleEndian.Uint32(val))
|
||||
if nlen == 0 {
|
||||
return errors.New("读取客户端长度错误。")
|
||||
}
|
||||
log.Println("收到客户端数据,需要读取长度:", nlen)
|
||||
raw := make([]byte, 0)
|
||||
buff := make([]byte, 1024)
|
||||
c := 0
|
||||
for {
|
||||
clen, err := conn.Read(buff)
|
||||
if err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
raw = append(raw, buff[:clen]...)
|
||||
c += clen
|
||||
if c >= nlen {
|
||||
break
|
||||
}
|
||||
}
|
||||
log.Println("读取完成,长度:", c, "实际raw长度:", len(raw))
|
||||
if c != nlen {
|
||||
return fmt.Errorf("已读取长度错误,已读取%dbyte,需要读取%dbyte。", c, nlen)
|
||||
raw, err := c.ReadLen(nlen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := DecodeResponse(raw)
|
||||
if err != nil {
|
||||
@@ -184,10 +116,70 @@ func (s *TRPServer) read(w http.ResponseWriter, conn net.Conn) (error) {
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
w.Write(bodyBytes)
|
||||
case "msg0":
|
||||
return nil
|
||||
case RES_MSG:
|
||||
BadRequest(w)
|
||||
return errors.New("客户端请求出错")
|
||||
default:
|
||||
log.Println("无法解析此错误", string(val))
|
||||
BadRequest(w)
|
||||
return errors.New("无法解析此错误")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type TunnelModeServer struct {
|
||||
Tunnel
|
||||
httpPort int
|
||||
tunnelTarget string
|
||||
}
|
||||
|
||||
func NewTunnelModeServer(tcpPort, httpPort int, tunnelTarget string) *TunnelModeServer {
|
||||
s := new(TunnelModeServer)
|
||||
s.tunnelPort = tcpPort
|
||||
s.httpPort = httpPort
|
||||
s.tunnelTarget = tunnelTarget
|
||||
s.tunnelList = make(chan *Conn, 1000)
|
||||
s.signalList = make(chan *Conn, 10)
|
||||
return s
|
||||
}
|
||||
|
||||
//开始
|
||||
func (s *TunnelModeServer) Start() (error) {
|
||||
err := s.StartTunnel()
|
||||
if err != nil {
|
||||
log.Fatalln("开启客户端失败!", err)
|
||||
return err
|
||||
}
|
||||
s.startTunnelServer()
|
||||
return nil
|
||||
}
|
||||
|
||||
//隧道模式server
|
||||
func (s *TunnelModeServer) startTunnelServer() {
|
||||
listener, err := net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.httpPort, ""})
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
for {
|
||||
conn, err := listener.AcceptTCP()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
go s.process(NewConn(conn))
|
||||
}
|
||||
}
|
||||
|
||||
//监听连接处理
|
||||
func (s *TunnelModeServer) process(c *Conn) error {
|
||||
retry:
|
||||
if len(s.tunnelList) < 10 { //新建通道
|
||||
go s.newChan()
|
||||
}
|
||||
link := <-s.tunnelList
|
||||
if _, err := link.WriteHost(s.tunnelTarget); err != nil {
|
||||
goto retry
|
||||
}
|
||||
go io.Copy(link, c)
|
||||
io.Copy(c, link.conn)
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user