mirror of
https://github.com/ehang-io/nps.git
synced 2025-09-15 07:55:46 +00:00
加密传输,代码优化
This commit is contained in:
165
lib/util.go
165
lib/util.go
@@ -3,18 +3,12 @@ package lib
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
@@ -22,19 +16,25 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
disabledRedirect = errors.New("disabled redirect.")
|
||||
bufPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 32*1024)
|
||||
},
|
||||
}
|
||||
)
|
||||
//pool 实现
|
||||
type bufType [32 * 1024]byte
|
||||
|
||||
const (
|
||||
COMPRESS_NONE = iota
|
||||
COMPRESS_NONE_ENCODE = iota
|
||||
COMPRESS_NONE_DECODE
|
||||
COMPRESS_SNAPY_ENCODE
|
||||
COMPRESS_SNAPY_DECODE
|
||||
COMPRESS_GZIP_ENCODE
|
||||
COMPRESS_GZIP_DECODE
|
||||
)
|
||||
|
||||
func BadRequest(w http.ResponseWriter) {
|
||||
@@ -134,83 +134,36 @@ func replaceHost(resp []byte) []byte {
|
||||
return []byte(str)
|
||||
}
|
||||
|
||||
func relay(in, out *Conn, compressType int) {
|
||||
buf := make([]byte, 32*1024)
|
||||
func relay(in, out *Conn, compressType int, crypt bool) {
|
||||
fmt.Println(crypt)
|
||||
switch compressType {
|
||||
case COMPRESS_GZIP_ENCODE:
|
||||
//TODO:GZIP压缩存在问题有待解决
|
||||
w := gzip.NewWriter(in)
|
||||
for {
|
||||
n, err := out.Read(buf)
|
||||
if err != nil || err == io.EOF {
|
||||
break
|
||||
}
|
||||
if _, err = w.Write(buf[:n]); err != nil {
|
||||
break
|
||||
}
|
||||
if err = w.Flush(); err != nil {
|
||||
log.Println(err)
|
||||
break
|
||||
}
|
||||
}
|
||||
w.Close()
|
||||
case COMPRESS_SNAPY_ENCODE:
|
||||
io.Copy(NewSnappyConn(in.conn), out)
|
||||
case COMPRESS_GZIP_DECODE:
|
||||
io.Copy(in, NewGzipConn(out.conn))
|
||||
copyBuffer(NewSnappyConn(in.conn, crypt), out)
|
||||
case COMPRESS_SNAPY_DECODE:
|
||||
io.Copy(in, NewSnappyConn(out.conn))
|
||||
default:
|
||||
io.Copy(in, out)
|
||||
copyBuffer(in, NewSnappyConn(out.conn, crypt))
|
||||
case COMPRESS_NONE_ENCODE:
|
||||
copyBuffer(NewCryptConn(in.conn, crypt), out)
|
||||
case COMPRESS_NONE_DECODE:
|
||||
copyBuffer(in, NewCryptConn(out.conn, crypt))
|
||||
}
|
||||
out.Close()
|
||||
in.Close()
|
||||
}
|
||||
|
||||
type Site struct {
|
||||
Host string
|
||||
Url string
|
||||
Port int
|
||||
}
|
||||
type Config struct {
|
||||
SiteList []Site
|
||||
Replace int
|
||||
}
|
||||
type JsonStruct struct {
|
||||
}
|
||||
|
||||
func NewJsonStruct() *JsonStruct {
|
||||
return &JsonStruct{}
|
||||
}
|
||||
func (jst *JsonStruct) Load(filename string) (Config, error) {
|
||||
data, err := ioutil.ReadFile(filename)
|
||||
config := Config{}
|
||||
if err != nil {
|
||||
return config, errors.New("配置文件打开错误")
|
||||
}
|
||||
err = json.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return config, errors.New("配置文件解析错误")
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
//判断压缩方式
|
||||
func getCompressType(compress string) (int, int) {
|
||||
switch compress {
|
||||
case "":
|
||||
return COMPRESS_NONE, COMPRESS_NONE
|
||||
case "gzip":
|
||||
return COMPRESS_GZIP_DECODE, COMPRESS_GZIP_ENCODE
|
||||
return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
|
||||
case "snappy":
|
||||
return COMPRESS_SNAPY_DECODE, COMPRESS_SNAPY_ENCODE
|
||||
default:
|
||||
log.Fatalln("数据压缩格式错误")
|
||||
}
|
||||
return COMPRESS_NONE, COMPRESS_NONE
|
||||
return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
|
||||
}
|
||||
|
||||
// 简单的一个校验值
|
||||
//简单的一个校验值
|
||||
func getverifyval(vkey string) string {
|
||||
//单客户端模式
|
||||
if *verifyKey != "" {
|
||||
@@ -219,6 +172,7 @@ func getverifyval(vkey string) string {
|
||||
return Md5(vkey)
|
||||
}
|
||||
|
||||
//验证
|
||||
func verify(verifyKeyMd5 string) bool {
|
||||
if *verifyKey != "" && getverifyval(*verifyKey) == verifyKeyMd5 {
|
||||
return true
|
||||
@@ -233,6 +187,7 @@ func verify(verifyKeyMd5 string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
//get key by host from x
|
||||
func getKeyByHost(host string) (h *HostList, t *TaskList, err error) {
|
||||
for _, v := range CsvDb.Hosts {
|
||||
if strings.Contains(host, v.Host) {
|
||||
@@ -245,25 +200,6 @@ func getKeyByHost(host string) (h *HostList, t *TaskList, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
//生成32位md5字串
|
||||
func Md5(s string) string {
|
||||
h := md5.New()
|
||||
h.Write([]byte(s))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
//生成随机验证密钥
|
||||
func GetRandomString(l int) string {
|
||||
str := "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
bytes := []byte(str)
|
||||
result := []byte{}
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
for i := 0; i < l; i++ {
|
||||
result = append(result, bytes[r.Intn(len(bytes))])
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
//通过host获取对应的ip地址
|
||||
func Gethostbyname(hostname string) string {
|
||||
if !DomainCheck(hostname) {
|
||||
@@ -310,3 +246,58 @@ func checkAuth(r *http.Request, user, passwd string) bool {
|
||||
}
|
||||
return pair[0] == user && pair[1] == passwd
|
||||
}
|
||||
|
||||
//get bool by str
|
||||
func GetBoolByStr(s string) bool {
|
||||
switch s {
|
||||
case "1", "true":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
//get str by bool
|
||||
func GetStrByBool(b bool) string {
|
||||
if b {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
}
|
||||
|
||||
// io.copy的优化版,读取buffer长度原为32*1024,与snappy不同,导致读取出的内容存在差异,不利于解密
|
||||
func copyBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
||||
// If the reader has a WriteTo method, use it to do the copy.
|
||||
// Avoids an allocation and a copy.
|
||||
if wt, ok := src.(io.WriterTo); ok {
|
||||
return wt.WriteTo(dst)
|
||||
}
|
||||
// Similarly, if the writer has a ReadFrom method, use it to do the copy.
|
||||
if rt, ok := dst.(io.ReaderFrom); ok {
|
||||
return rt.ReadFrom(src)
|
||||
}
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
nr, er := src.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := dst.Write(buf[0:nr])
|
||||
if nw > 0 {
|
||||
written += int64(nw)
|
||||
}
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
if nr != nw {
|
||||
err = io.ErrShortWrite
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er != io.EOF {
|
||||
err = er
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return written, err
|
||||
}
|
||||
|
Reference in New Issue
Block a user