加密传输,代码优化

This commit is contained in:
刘河
2019-01-03 01:44:45 +08:00
parent 4dad726129
commit 1d89e7dae2
16 changed files with 725 additions and 208 deletions

View File

@@ -3,18 +3,12 @@ package lib
import (
"bufio"
"bytes"
"compress/gzip"
"crypto/md5"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"math/rand"
"net"
"net/http"
"net/http/httputil"
@@ -22,19 +16,25 @@ import (
"regexp"
"strconv"
"strings"
"time"
"sync"
)
var (
disabledRedirect = errors.New("disabled redirect.")
bufPool = &sync.Pool{
New: func() interface{} {
return make([]byte, 32*1024)
},
}
)
//pool 实现
type bufType [32 * 1024]byte
const (
COMPRESS_NONE = iota
COMPRESS_NONE_ENCODE = iota
COMPRESS_NONE_DECODE
COMPRESS_SNAPY_ENCODE
COMPRESS_SNAPY_DECODE
COMPRESS_GZIP_ENCODE
COMPRESS_GZIP_DECODE
)
func BadRequest(w http.ResponseWriter) {
@@ -134,83 +134,36 @@ func replaceHost(resp []byte) []byte {
return []byte(str)
}
func relay(in, out *Conn, compressType int) {
buf := make([]byte, 32*1024)
func relay(in, out *Conn, compressType int, crypt bool) {
fmt.Println(crypt)
switch compressType {
case COMPRESS_GZIP_ENCODE:
//TODO:GZIP压缩存在问题有待解决
w := gzip.NewWriter(in)
for {
n, err := out.Read(buf)
if err != nil || err == io.EOF {
break
}
if _, err = w.Write(buf[:n]); err != nil {
break
}
if err = w.Flush(); err != nil {
log.Println(err)
break
}
}
w.Close()
case COMPRESS_SNAPY_ENCODE:
io.Copy(NewSnappyConn(in.conn), out)
case COMPRESS_GZIP_DECODE:
io.Copy(in, NewGzipConn(out.conn))
copyBuffer(NewSnappyConn(in.conn, crypt), out)
case COMPRESS_SNAPY_DECODE:
io.Copy(in, NewSnappyConn(out.conn))
default:
io.Copy(in, out)
copyBuffer(in, NewSnappyConn(out.conn, crypt))
case COMPRESS_NONE_ENCODE:
copyBuffer(NewCryptConn(in.conn, crypt), out)
case COMPRESS_NONE_DECODE:
copyBuffer(in, NewCryptConn(out.conn, crypt))
}
out.Close()
in.Close()
}
type Site struct {
Host string
Url string
Port int
}
type Config struct {
SiteList []Site
Replace int
}
type JsonStruct struct {
}
func NewJsonStruct() *JsonStruct {
return &JsonStruct{}
}
func (jst *JsonStruct) Load(filename string) (Config, error) {
data, err := ioutil.ReadFile(filename)
config := Config{}
if err != nil {
return config, errors.New("配置文件打开错误")
}
err = json.Unmarshal(data, &config)
if err != nil {
return config, errors.New("配置文件解析错误")
}
return config, nil
}
//判断压缩方式
func getCompressType(compress string) (int, int) {
switch compress {
case "":
return COMPRESS_NONE, COMPRESS_NONE
case "gzip":
return COMPRESS_GZIP_DECODE, COMPRESS_GZIP_ENCODE
return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
case "snappy":
return COMPRESS_SNAPY_DECODE, COMPRESS_SNAPY_ENCODE
default:
log.Fatalln("数据压缩格式错误")
}
return COMPRESS_NONE, COMPRESS_NONE
return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
}
// 简单的一个校验值
//简单的一个校验值
func getverifyval(vkey string) string {
//单客户端模式
if *verifyKey != "" {
@@ -219,6 +172,7 @@ func getverifyval(vkey string) string {
return Md5(vkey)
}
//验证
func verify(verifyKeyMd5 string) bool {
if *verifyKey != "" && getverifyval(*verifyKey) == verifyKeyMd5 {
return true
@@ -233,6 +187,7 @@ func verify(verifyKeyMd5 string) bool {
return false
}
//get key by host from x
func getKeyByHost(host string) (h *HostList, t *TaskList, err error) {
for _, v := range CsvDb.Hosts {
if strings.Contains(host, v.Host) {
@@ -245,25 +200,6 @@ func getKeyByHost(host string) (h *HostList, t *TaskList, err error) {
return
}
//生成32位md5字串
func Md5(s string) string {
h := md5.New()
h.Write([]byte(s))
return hex.EncodeToString(h.Sum(nil))
}
//生成随机验证密钥
func GetRandomString(l int) string {
str := "0123456789abcdefghijklmnopqrstuvwxyz"
bytes := []byte(str)
result := []byte{}
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for i := 0; i < l; i++ {
result = append(result, bytes[r.Intn(len(bytes))])
}
return string(result)
}
//通过host获取对应的ip地址
func Gethostbyname(hostname string) string {
if !DomainCheck(hostname) {
@@ -310,3 +246,58 @@ func checkAuth(r *http.Request, user, passwd string) bool {
}
return pair[0] == user && pair[1] == passwd
}
//get bool by str
func GetBoolByStr(s string) bool {
switch s {
case "1", "true":
return true
}
return false
}
//get str by bool
func GetStrByBool(b bool) string {
if b {
return "1"
}
return "0"
}
// io.copy的优化版读取buffer长度原为32*1024与snappy不同导致读取出的内容存在差异不利于解密
func copyBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
// If the reader has a WriteTo method, use it to do the copy.
// Avoids an allocation and a copy.
if wt, ok := src.(io.WriterTo); ok {
return wt.WriteTo(dst)
}
// Similarly, if the writer has a ReadFrom method, use it to do the copy.
if rt, ok := dst.(io.ReaderFrom); ok {
return rt.ReadFrom(src)
}
buf := make([]byte, 65535)
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
}
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return written, err
}