添加多种模式

This commit is contained in:
刘河
2018-11-29 19:55:24 +08:00
parent 2463116b37
commit 3ea895feb5
8 changed files with 831 additions and 215 deletions

220
server.go
View File

@@ -1,8 +1,6 @@
package main
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
@@ -10,115 +8,69 @@ import (
"log"
"net"
"net/http"
"sync"
"time"
)
type TRPServer struct {
tcpPort int
const (
VERIFY_EER = "vkey"
WORK_MAIN = "main"
WORK_CHAN = "chan"
RES_SIGN = "sign"
RES_MSG = "msg0"
)
type HttpModeServer struct {
Tunnel
httpPort int
listener *net.TCPListener
connList chan net.Conn
sync.RWMutex
}
func NewRPServer(tcpPort, httpPort int) *TRPServer {
s := new(TRPServer)
s.tcpPort = tcpPort
func NewHttpModeServer(tcpPort, httpPort int) *HttpModeServer {
s := new(HttpModeServer)
s.tunnelPort = tcpPort
s.httpPort = httpPort
s.connList = make(chan net.Conn, 1000)
s.signalList = make(chan *Conn, 1000)
return s
}
func (s *TRPServer) Start() error {
var err error
s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.tcpPort, ""})
//开始
func (s *HttpModeServer) Start() (error) {
err := s.StartTunnel()
if err != nil {
log.Fatalln("开启客户端失败!", err)
return err
}
go s.httpserver()
return s.tcpserver()
s.startHttpServer()
return nil
}
func (s *TRPServer) Close() error {
if s.listener != nil {
err := s.listener.Close()
s.listener = nil
return err
}
return errors.New("TCP实例未创建")
}
func (s *TRPServer) tcpserver() error {
var err error
for {
conn, err := s.listener.AcceptTCP()
if err != nil {
log.Println(err)
continue
}
go s.cliProcess(conn)
}
return err
}
func badRequest(w http.ResponseWriter) {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
}
func (s *TRPServer) httpserver() {
//开启http端口监听
func (s *HttpModeServer) startHttpServer() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
retry:
if len(s.connList) == 0 {
badRequest(w)
if len(s.signalList) == 0 {
BadRequest(w)
return
}
conn := <-s.connList
log.Println(r.RequestURI)
err := s.write(r, conn)
conn := <-s.signalList
if err := s.writeRequest(r, conn); err != nil {
log.Println(err)
conn.Close()
goto retry
return
}
err = s.writeResponse(w, conn)
if err != nil {
log.Println(err)
conn.Close()
goto retry
return
}
err = s.read(w, conn)
if err != nil {
log.Println(err)
conn.Close()
goto retry
return
}
s.connList <- conn
conn = nil
s.signalList <- conn
})
log.Fatalln(http.ListenAndServe(fmt.Sprintf(":%d", s.httpPort), nil))
}
func (s *TRPServer) cliProcess(conn *net.TCPConn) error {
conn.SetReadDeadline(time.Now().Add(time.Duration(5) * time.Second))
vval := make([]byte, 20)
_, err := conn.Read(vval)
if err != nil {
log.Println("客户端读超时。客户端地址为::", conn.RemoteAddr())
conn.Close()
return err
}
if bytes.Compare(vval, getverifyval()[:]) != 0 {
log.Println("当前客户端连接校验错误,关闭此客户端:", conn.RemoteAddr())
conn.Write([]byte("vkey"))
conn.Close()
return err
}
conn.SetReadDeadline(time.Time{})
log.Println("连接新的客户端:", conn.RemoteAddr())
conn.SetKeepAlive(true)
conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
s.connList <- conn
return nil
}
func (s *TRPServer) write(r *http.Request, conn net.Conn) error {
//req转为bytes发送给client端
func (s *HttpModeServer) writeRequest(r *http.Request, conn *Conn) error {
raw, err := EncodeRequest(r)
if err != nil {
return err
@@ -133,41 +85,21 @@ func (s *TRPServer) write(r *http.Request, conn net.Conn) error {
return nil
}
func (s *TRPServer) read(w http.ResponseWriter, conn net.Conn) (error) {
val := make([]byte, 4)
_, err := conn.Read(val)
//从client读取出Response
func (s *HttpModeServer) writeResponse(w http.ResponseWriter, c *Conn) error {
flags, err := c.ReadFlag()
if err != nil {
return err
}
flags := string(val)
switch flags {
case "sign":
_, err = conn.Read(val)
case RES_SIGN:
nlen, err := c.GetLen()
if err != nil {
return err
}
nlen := int(binary.LittleEndian.Uint32(val))
if nlen == 0 {
return errors.New("读取客户端长度错误。")
}
log.Println("收到客户端数据,需要读取长度:", nlen)
raw := make([]byte, 0)
buff := make([]byte, 1024)
c := 0
for {
clen, err := conn.Read(buff)
if err != nil && err != io.EOF {
return err
}
raw = append(raw, buff[:clen]...)
c += clen
if c >= nlen {
break
}
}
log.Println("读取完成,长度:", c, "实际raw长度", len(raw))
if c != nlen {
return fmt.Errorf("已读取长度错误,已读取%dbyte需要读取%dbyte。", c, nlen)
raw, err := c.ReadLen(nlen)
if err != nil {
return err
}
resp, err := DecodeResponse(raw)
if err != nil {
@@ -184,10 +116,70 @@ func (s *TRPServer) read(w http.ResponseWriter, conn net.Conn) (error) {
}
w.WriteHeader(resp.StatusCode)
w.Write(bodyBytes)
case "msg0":
return nil
case RES_MSG:
BadRequest(w)
return errors.New("客户端请求出错")
default:
log.Println("无法解析此错误", string(val))
BadRequest(w)
return errors.New("无法解析此错误")
}
return nil
}
type TunnelModeServer struct {
Tunnel
httpPort int
tunnelTarget string
}
func NewTunnelModeServer(tcpPort, httpPort int, tunnelTarget string) *TunnelModeServer {
s := new(TunnelModeServer)
s.tunnelPort = tcpPort
s.httpPort = httpPort
s.tunnelTarget = tunnelTarget
s.tunnelList = make(chan *Conn, 1000)
s.signalList = make(chan *Conn, 10)
return s
}
//开始
func (s *TunnelModeServer) Start() (error) {
err := s.StartTunnel()
if err != nil {
log.Fatalln("开启客户端失败!", err)
return err
}
s.startTunnelServer()
return nil
}
//隧道模式server
func (s *TunnelModeServer) startTunnelServer() {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.httpPort, ""})
if err != nil {
log.Fatalln(err)
}
for {
conn, err := listener.AcceptTCP()
if err != nil {
log.Println(err)
continue
}
go s.process(NewConn(conn))
}
}
//监听连接处理
func (s *TunnelModeServer) process(c *Conn) error {
retry:
if len(s.tunnelList) < 10 { //新建通道
go s.newChan()
}
link := <-s.tunnelList
if _, err := link.WriteHost(s.tunnelTarget); err != nil {
goto retry
}
go io.Copy(link, c)
io.Copy(c, link.conn)
return nil
}