[feat] 支持proxy参数

This commit is contained in:
Wizerd
2024-02-03 23:32:06 +08:00
parent ef2015b3f6
commit 3401203087
3 changed files with 88 additions and 14 deletions

View File

@@ -18,7 +18,7 @@ RUN apt update && apt install -y jq
# RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
# 安装任何所需的依赖项
RUN pip install --no-cache-dir flask gunicorn requests Pillow flask-cors tiktoken fake_useragent redis websocket-client
RUN pip install --no-cache-dir flask gunicorn requests Pillow flask-cors tiktoken fake_useragent redis websocket-client pysocks requests[socks] websocket-client[optional]
# 在容器启动时运行 Flask 应用
CMD ["/app/start.sh"]

View File

@@ -3,6 +3,7 @@
"need_log_to_file": "true",
"process_workers": 2,
"process_threads": 2,
"proxy": "",
"upstream_base_url": "",
"upstream_api_prefix": "",
"backend_container_url": "",

99
main.py
View File

@@ -24,6 +24,9 @@ from io import BytesIO
from urllib.parse import urlparse, urlunparse
import base64
from fake_useragent import UserAgent
import os
from urllib.parse import urlparse
# 读取配置文件
def load_config(file_path):
@@ -92,6 +95,7 @@ logger = logging.getLogger()
logger.setLevel(log_level_dict.get(LOG_LEVEL, logging.DEBUG))
import redis
# 假设您已经有一个Redis客户端的实例
@@ -222,13 +226,73 @@ CORS(app, resources={r"/images/*": {"origins": "*"}})
# PANDORA_UPLOAD_URL = 'files.pandoranext.com'
VERSION = '0.7.5'
VERSION = '0.7.6'
# VERSION = 'test'
UPDATE_INFO = '支持缓存GPTS配置'
UPDATE_INFO = '支持proxy参数'
# UPDATE_INFO = '【仅供临时测试使用】 '
# 解析响应中的信息
def parse_oai_ip_info():
tmp_ua = ua.random
res = requests.get("https://auth0.openai.com/cdn-cgi/trace", headers={"User-Agent":tmp_ua}, proxies=proxies)
lines = res.text.strip().split("\n")
info_dict = {line.split('=')[0]: line.split('=')[1] for line in lines if '=' in line}
return {key: info_dict[key] for key in ["ip", "loc", "colo", "warp"] if key in info_dict}
with app.app_context():
global gpts_configurations # 移到作用域的最开始
global proxies
global proxy_type
global proxy_host
global proxy_port
# 获取环境变量
proxy_url = CONFIG.get('proxy', None)
logger.info(f"==========================================")
if proxy_url and proxy_url != '':
parsed_url = urlparse(proxy_url)
scheme = parsed_url.scheme
hostname = parsed_url.hostname
port = parsed_url.port
# 构建requests支持的代理格式
if scheme in ['http']:
proxy_address = f"{scheme}://{hostname}:{port}"
proxies = {
'http': proxy_address,
'https': proxy_address,
}
proxy_type = scheme
proxy_host = hostname
proxy_port = port
elif scheme in ['socks5']:
proxy_address = f"{scheme}://{hostname}:{port}"
proxies = {
'http': proxy_address,
'https': proxy_address,
}
proxy_type = scheme
proxy_host = hostname
proxy_port = port
else:
raise ValueError("Unsupport proxy scheme: " + scheme)
# 打印当前使用的代理设置
logger.info(f"Use Proxy: {scheme}://{proxy_host}:{proxy_port}")
else:
# 如果没有设置代理
proxies = {}
proxy_type = None
http_proxy_host = None
http_proxy_port = None
logger.info("No Proxy")
ip_info = parse_oai_ip_info()
logger.info(f"The ip you are using to access oai is: {ip_info['ip']}")
logger.info(f"The location of this ip is: {ip_info['loc']}")
logger.info(f"The colo of this ip is: {ip_info['colo']}")
logger.info(f"Is this ip a Warp ip: {ip_info['warp']}")
# 输出版本信息
logger.info(f"==========================================")
@@ -457,7 +521,7 @@ def upload_file(file_content, mime_type, api_key):
'Content-Type': mime_type,
'x-ms-blob-type': 'BlockBlob' # 添加这个头部
}
put_response = requests.put(upload_url, data=file_content, headers=put_headers)
put_response = requests.put(upload_url, data=file_content, headers=put_headers, proxies=proxies)
if put_response.status_code != 201:
logger.debug(f"put_response: {put_response.text}")
logger.debug(f"put_response status_code: {put_response.status_code}")
@@ -621,7 +685,7 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model):
tmp_headers = {
'User-Agent': tmp_user_agent
}
file_response = requests.get(url=file_url, headers=tmp_headers)
file_response = requests.get(url=file_url, headers=tmp_headers, proxies=proxies)
file_content = file_response.content
mime_type = file_response.headers.get('Content-Type', '').split(';')[0].strip()
except Exception as e:
@@ -954,7 +1018,7 @@ def replace_sandbox(text, conversation_id, message_id, api_key):
if not os.path.exists("./files"):
os.makedirs("./files")
file_path = f"./files/{filename}"
with requests.get(download_url, stream=True) as r:
with requests.get(download_url, stream=True, proxies=proxies) as r:
with open(file_path, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
@@ -1237,7 +1301,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key
if response_format == "url":
data_queue.put(('image_url', f"{download_url}"))
else:
image_download_response = requests.get(download_url)
image_download_response = requests.get(download_url, proxies=proxies)
if image_download_response.status_code == 200:
logger.debug(f"下载图片成功")
image_data = image_download_response.content
@@ -1247,7 +1311,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key
else:
# 从URL下载图片
# image_data = requests.get(download_url).content
image_download_response = requests.get(download_url)
image_download_response = requests.get(download_url, proxies=proxies)
# print(f"image_download_response: {image_download_response.text}")
if image_download_response.status_code == 200:
logger.debug(f"下载图片成功")
@@ -1486,7 +1550,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key
else:
# 从URL下载图片
# image_data = requests.get(download_url).content
image_download_response = requests.get(download_url)
image_download_response = requests.get(download_url, proxies=proxies)
# print(f"image_download_response: {image_download_response.text}")
if image_download_response.status_code == 200:
logger.debug(f"下载图片成功")
@@ -1705,7 +1769,12 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
on_close = on_close,
on_open = on_open)
ws.on_open = on_open
ws.run_forever()
# 使用HTTP代理
if proxy_type:
logger.debug(f"通过代理: {proxy_type}://{proxy_host}:{proxy_port} 连接wss...")
ws.run_forever(http_proxy_host=proxy_host, http_proxy_port=proxy_port, proxy_type=proxy_type)
else:
ws.run_forever()
logger.debug(f"end wss...")
if context["is_sse"] == True:
@@ -1910,10 +1979,14 @@ def register_websocket(api_key):
"Authorization": f"Bearer {api_key}"
}
response = requests.post(url, headers=headers)
response_json = response.json()
logger.debug(f"register_websocket response: {response_json}")
wss_url = response_json.get("wss_url", None)
return wss_url
try:
response_json = response.json()
logger.debug(f"register_websocket response: {response_json}")
wss_url = response_json.get("wss_url", None)
return wss_url
except json.JSONDecodeError:
raise Exception(f"Wss register fail: {response.text}")
return None
def keep_alive(last_data_time, stop_event, queue, model, chat_message_id):