From 3401203087fb4a4b3a53d13866d2678af9b8f74d Mon Sep 17 00:00:00 2001 From: Wizerd Date: Sat, 3 Feb 2024 23:32:06 +0800 Subject: [PATCH] =?UTF-8?q?[feat]=20=E6=94=AF=E6=8C=81proxy=E5=8F=82?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Dockerfile | 2 +- data/config.json | 1 + main.py | 99 +++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 88 insertions(+), 14 deletions(-) diff --git a/Dockerfile b/Dockerfile index a22d120..4313dcf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,7 +18,7 @@ RUN apt update && apt install -y jq # RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple # 安装任何所需的依赖项 -RUN pip install --no-cache-dir flask gunicorn requests Pillow flask-cors tiktoken fake_useragent redis websocket-client +RUN pip install --no-cache-dir flask gunicorn requests Pillow flask-cors tiktoken fake_useragent redis websocket-client pysocks requests[socks] websocket-client[optional] # 在容器启动时运行 Flask 应用 CMD ["/app/start.sh"] diff --git a/data/config.json b/data/config.json index 244c9a0..1772929 100644 --- a/data/config.json +++ b/data/config.json @@ -3,6 +3,7 @@ "need_log_to_file": "true", "process_workers": 2, "process_threads": 2, + "proxy": "", "upstream_base_url": "", "upstream_api_prefix": "", "backend_container_url": "", diff --git a/main.py b/main.py index 2e0c101..4bb6951 100644 --- a/main.py +++ b/main.py @@ -24,6 +24,9 @@ from io import BytesIO from urllib.parse import urlparse, urlunparse import base64 from fake_useragent import UserAgent +import os +from urllib.parse import urlparse + # 读取配置文件 def load_config(file_path): @@ -92,6 +95,7 @@ logger = logging.getLogger() logger.setLevel(log_level_dict.get(LOG_LEVEL, logging.DEBUG)) + import redis # 假设您已经有一个Redis客户端的实例 @@ -222,13 +226,73 @@ CORS(app, resources={r"/images/*": {"origins": "*"}}) # PANDORA_UPLOAD_URL = 'files.pandoranext.com' -VERSION = '0.7.5' +VERSION = '0.7.6' # VERSION = 'test' -UPDATE_INFO = '支持缓存GPTS配置' +UPDATE_INFO = '支持proxy参数' # UPDATE_INFO = '【仅供临时测试使用】 ' +# 解析响应中的信息 +def parse_oai_ip_info(): + tmp_ua = ua.random + res = requests.get("https://auth0.openai.com/cdn-cgi/trace", headers={"User-Agent":tmp_ua}, proxies=proxies) + lines = res.text.strip().split("\n") + info_dict = {line.split('=')[0]: line.split('=')[1] for line in lines if '=' in line} + return {key: info_dict[key] for key in ["ip", "loc", "colo", "warp"] if key in info_dict} + with app.app_context(): global gpts_configurations # 移到作用域的最开始 + global proxies + global proxy_type + global proxy_host + global proxy_port + + # 获取环境变量 + proxy_url = CONFIG.get('proxy', None) + + logger.info(f"==========================================") + if proxy_url and proxy_url != '': + parsed_url = urlparse(proxy_url) + scheme = parsed_url.scheme + hostname = parsed_url.hostname + port = parsed_url.port + + # 构建requests支持的代理格式 + if scheme in ['http']: + proxy_address = f"{scheme}://{hostname}:{port}" + proxies = { + 'http': proxy_address, + 'https': proxy_address, + } + proxy_type = scheme + proxy_host = hostname + proxy_port = port + elif scheme in ['socks5']: + proxy_address = f"{scheme}://{hostname}:{port}" + proxies = { + 'http': proxy_address, + 'https': proxy_address, + } + proxy_type = scheme + proxy_host = hostname + proxy_port = port + else: + raise ValueError("Unsupport proxy scheme: " + scheme) + + # 打印当前使用的代理设置 + logger.info(f"Use Proxy: {scheme}://{proxy_host}:{proxy_port}") + else: + # 如果没有设置代理 + proxies = {} + proxy_type = None + http_proxy_host = None + http_proxy_port = None + logger.info("No Proxy") + + ip_info = parse_oai_ip_info() + logger.info(f"The ip you are using to access oai is: {ip_info['ip']}") + logger.info(f"The location of this ip is: {ip_info['loc']}") + logger.info(f"The colo of this ip is: {ip_info['colo']}") + logger.info(f"Is this ip a Warp ip: {ip_info['warp']}") # 输出版本信息 logger.info(f"==========================================") @@ -457,7 +521,7 @@ def upload_file(file_content, mime_type, api_key): 'Content-Type': mime_type, 'x-ms-blob-type': 'BlockBlob' # 添加这个头部 } - put_response = requests.put(upload_url, data=file_content, headers=put_headers) + put_response = requests.put(upload_url, data=file_content, headers=put_headers, proxies=proxies) if put_response.status_code != 201: logger.debug(f"put_response: {put_response.text}") logger.debug(f"put_response status_code: {put_response.status_code}") @@ -621,7 +685,7 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): tmp_headers = { 'User-Agent': tmp_user_agent } - file_response = requests.get(url=file_url, headers=tmp_headers) + file_response = requests.get(url=file_url, headers=tmp_headers, proxies=proxies) file_content = file_response.content mime_type = file_response.headers.get('Content-Type', '').split(';')[0].strip() except Exception as e: @@ -954,7 +1018,7 @@ def replace_sandbox(text, conversation_id, message_id, api_key): if not os.path.exists("./files"): os.makedirs("./files") file_path = f"./files/{filename}" - with requests.get(download_url, stream=True) as r: + with requests.get(download_url, stream=True, proxies=proxies) as r: with open(file_path, 'wb') as f: for chunk in r.iter_content(chunk_size=8192): f.write(chunk) @@ -1237,7 +1301,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key if response_format == "url": data_queue.put(('image_url', f"{download_url}")) else: - image_download_response = requests.get(download_url) + image_download_response = requests.get(download_url, proxies=proxies) if image_download_response.status_code == 200: logger.debug(f"下载图片成功") image_data = image_download_response.content @@ -1247,7 +1311,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key else: # 从URL下载图片 # image_data = requests.get(download_url).content - image_download_response = requests.get(download_url) + image_download_response = requests.get(download_url, proxies=proxies) # print(f"image_download_response: {image_download_response.text}") if image_download_response.status_code == 200: logger.debug(f"下载图片成功") @@ -1486,7 +1550,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key else: # 从URL下载图片 # image_data = requests.get(download_url).content - image_download_response = requests.get(download_url) + image_download_response = requests.get(download_url, proxies=proxies) # print(f"image_download_response: {image_download_response.text}") if image_download_response.status_code == 200: logger.debug(f"下载图片成功") @@ -1705,7 +1769,12 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m on_close = on_close, on_open = on_open) ws.on_open = on_open - ws.run_forever() + # 使用HTTP代理 + if proxy_type: + logger.debug(f"通过代理: {proxy_type}://{proxy_host}:{proxy_port} 连接wss...") + ws.run_forever(http_proxy_host=proxy_host, http_proxy_port=proxy_port, proxy_type=proxy_type) + else: + ws.run_forever() logger.debug(f"end wss...") if context["is_sse"] == True: @@ -1910,10 +1979,14 @@ def register_websocket(api_key): "Authorization": f"Bearer {api_key}" } response = requests.post(url, headers=headers) - response_json = response.json() - logger.debug(f"register_websocket response: {response_json}") - wss_url = response_json.get("wss_url", None) - return wss_url + try: + response_json = response.json() + logger.debug(f"register_websocket response: {response_json}") + wss_url = response_json.get("wss_url", None) + return wss_url + except json.JSONDecodeError: + raise Exception(f"Wss register fail: {response.text}") + return None def keep_alive(last_data_time, stop_event, queue, model, chat_message_id):