From f44f43212faf5220b2fba127938be32f4c8c83b9 Mon Sep 17 00:00:00 2001 From: Yanyutin753 <3254822118@qq.com> Date: Mon, 5 Feb 2024 18:52:05 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9Erefresh=5Ftoken=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E8=BD=ACaccess=5Ftoken?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitattributes | 2 + .idea/.gitignore | 8 + .idea/PandoraToV1Api-main (2).iml | 12 + .idea/inspectionProfiles/Project_Default.xml | 24 + .../inspectionProfiles/profiles_settings.xml | 6 + .idea/misc.xml | 4 + .idea/modules.xml | 8 + .idea/vcs.xml | 6 + Dockerfile | 2 +- LICENSE | 21 + Readme.md | 54 +- data/config.json | 4 + docker-compose.yml | 4 +- main.py | 501 ++++++++++++------ 14 files changed, 466 insertions(+), 190 deletions(-) create mode 100644 .gitattributes create mode 100644 .idea/.gitignore create mode 100644 .idea/PandoraToV1Api-main (2).iml create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml create mode 100644 LICENSE diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..dfe0770 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/PandoraToV1Api-main (2).iml b/.idea/PandoraToV1Api-main (2).iml new file mode 100644 index 0000000..b3e9b48 --- /dev/null +++ b/.idea/PandoraToV1Api-main (2).iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..2777045 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,24 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..0be5fc9 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..ab2306c --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 4313dcf..f1ddb66 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,7 +18,7 @@ RUN apt update && apt install -y jq # RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple # 安装任何所需的依赖项 -RUN pip install --no-cache-dir flask gunicorn requests Pillow flask-cors tiktoken fake_useragent redis websocket-client pysocks requests[socks] websocket-client[optional] +RUN pip install --no-cache-dir flask flask_apscheduler gunicorn requests Pillow flask-cors tiktoken fake_useragent redis websocket-client pysocks requests[socks] websocket-client[optional] # 在容器启动时运行 Flask 应用 CMD ["/app/start.sh"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..bbc8ac6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Yanyutin753 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Readme.md b/Readme.md index 4cf29a9..611cc2d 100644 --- a/Readme.md +++ b/Readme.md @@ -2,21 +2,19 @@ > [!IMPORTANT] > -> Respect Zhile大佬, Respect Pandora! +> Respect Zhile大佬 , Respect Wizerd! -为了方便大家将 [Pandora-Next](https://github.com/pandora-next/deploy) 项目与各种其他项目结合完成了本项目。 +感谢pandoraNext和Wizerd的付出,敬礼!!! 本项目支持: -1. 将 Pandora-Next `proxy` 模式下的 `backend-api` 转为 `/v1/chat/completions` 接口,支持流式和非流式响应。 +1. 将 ninja `proxy` 模式下的 `backend-api` 转为 `/v1/chat/completions` 接口,支持流式和非流式响应。 -2. 将 Pandora-Next `proxy` 模式下的 `backend-api` 转为 `/v1/images/generations` 接口 - -如果你想要尝试自己生成Arkose Token从而将PandoraNext的额度消耗降低到`1:4`,你可以看看这个项目:[GenerateArkose](https://github.com/Ink-Osier/GenerateArkose),但是请注意,**该项目并不保证使用该项目生成的Arkose Token不会封号,使用该项目造成的一切后果由使用者自行承担**。 +2. 将 ninja `proxy` 模式下的 `backend-api` 转为 `/v1/images/generations` 接口 如果本项目对你有帮助的话,请点个小星星吧~ -如果有什么在项目的使用过程中的疑惑或需求,欢迎提 `Issue`,或者加入 Community Telegram Channel: [Inker 的魔法世界](https://t.me/InkerWorld) 来和大家一起交流一下~ +如果有什么在项目的使用过程中的疑惑或需求,欢迎加入 Community Telegram Channel: [Inker 的魔法世界](https://t.me/InkerWorld) 来和大家一起交流一下~ ## 更新日志 @@ -57,15 +55,13 @@ ## 注意 > [!CAUTION] -> 1. 本项目的运行需要 Pandora-Next 开启 `auto_conv_arkose:true`,同时请尽量升级最新版本的 Pandora-Next,以确保支持此功能。 +> 1. 本项目的运行需要 ninja > -> 2. 本项目对话次数对Pandora-Next的对话额度消耗比例为: -> - `gpt-4-s`、`gpt-4-mobile`、`GPTS`:`1:14`; -> - `gpt-3.5-turbo`:`1:4`; +> 2. 本项目实际为将来自 `/v1/chat/completions` 的请求转发到ninja的 `/backend-api/conversation` 接口,因此本项目并不支持高并发操作,请不要接入如 `沉浸式翻译` 等高并发项目。 > -> 3. 本项目实际为将来自 `/v1/chat/completions` 的请求转发到Pandora-Next的 `/backend-api/conversation` 接口,因此本项目并不支持高并发操作,请不要接入如 `沉浸式翻译` 等高并发项目。 +> 3. 本项目支持使用apple平台的refresh_token作为请求key. > -> 4. 本项目并不能绕过 OpenAI 和 PandoraNext 官方的限制,只提供便利,不提供绕过。 +> 4. 本项目并不能绕过 OpenAI 和 ninja 官方的限制,只提供便利,不提供绕过。 > > 5. 提问的艺术:当出现项目不能正常运行时,请携带 `DEBUG` 级别的日志在 `Issue` 或者社区群内提问,否则将开启算命模式~ @@ -95,9 +91,9 @@ - `process_threads`: 用于设置线程数,如果不需要设置,可以保持不变,如果需要设置,可以设置为需要设置的值,如果设置为 `1`,则会强制设置为单线程模式。 -- `pandora_base_url`: Pandora-Next 的部署地址,如:`https://pandoranext.com`,注意:不要以 `/` 结尾。可以填写为本项目可以访问到的 PandoraNext 的内网地址。 +- `upstream_base_url`: ninja 的部署地址,如:`https://pandoranext.com`,注意:不要以 `/` 结尾。可以填写为本项目可以访问到的 PandoraNext 的内网地址。 -- `pandora_api_prefix`: PandoraNext Proxy 模式下的 API 前缀 +- `upstream_api_prefix`: PandoraNext Proxy 模式下的 API 前缀 - `backend_container_url`: 用于dalle模型生成图片的时候展示所用,需要设置为使用如 [ChatGPT-Next-Web](https://github.com/ChatGPTNextWebTeam/ChatGPT-Next-Web) 的用户可以访问到的本项目地址,如:`http://1.2.3.4:50011`,同原环境变量中的 `UPLOAD_BASE_URL` @@ -136,7 +132,13 @@ PS. 注意,arkose_urls中的地址需要支持PandoraNext的Arkose Token获取 - `enabled_bing_reference_output`: 用于设置是否开启 Bot 模式下联网插件的引用输出,可选值为:`true`、`false`,默认为 `false`,开启后,将会输出联网插件的引用,仅在 `bot_mode.enabled` 为 `true` 时生效。 - `enabled_plugin_output`: 用于设置是否开启 Bot 模式下插件执行过程的输出,可选值为:`true`、`false`,默认为 `false`,开启后,将会输出插件执行过程的输出,仅在 `bot_mode.enabled` 为 `true` 时生效。 + +- `refresh_ToAccess` + - `enableOai`:用于设置是否使用官网通过refresh_token刷新access_token,仅在 `enableOai` 为 `true` 时生效。 + + - `ninja_refreshToAccess_Url`:用于设置使用ninja来进行使用refresh_token刷新access_token,enableOai为false的时候必填。 + - `redis` - `host`: Redis的ip地址,例如:1.2.3.4,默认是 redis 容器 @@ -174,7 +176,7 @@ PS. 注意,arkose_urls中的地址需要支持PandoraNext的Arkose Token获取 请求方式:`POST` -请求头:正常携带 `Authorization` 和 `Content-Type` 即可,`Authorization` 的值为 `Bearer `,`Content-Type` 的值为 `application/json` +请求头:正常携带 `Authorization` 和 `Content-Type` 即可,`Authorization` 的值为 `Bearer `,`Content-Type` 的值为 `application/json` 请求体格式示例: @@ -317,7 +319,7 @@ services: ports: - "50013:3000" environment: - - OPENAI_API_KEY= + - OPENAI_API_KEY= - BASE_URL= - CUSTOM_MODELS=+gpt-4-s,+gpt-4-mobile,+ @@ -354,21 +356,3 @@ services: #### 关闭 Bot 模式 ![image](https://github.com/Ink-Osier/PandoraToV1Api/assets/133617214/c1d3457f-b912-4572-b4e0-1118b48102d8) - -## 贡献者们 - -> 感谢所有让这个项目变得更好的贡献者们! - -[![Contributors](https://contrib.rocks/image?repo=Ink-Osier/PandoraToV1Api)](https://github.com/Ink-Osier/PandoraToV1Api/graphs/contributors) - -## 平台推荐 - -> [Cloudflare](https://www.cloudflare.com/) - -世界领先的互联网基础设施和安全公司,为您的网站提供 CDN、DNS、DDoS 保护和安全性服务,可以帮助你的项目尽可能避免遭受网络攻击或爬虫行为。 -[![Cloudflare](https://www.cloudflare.com/img/logo-cloudflare.svg)](https://www.cloudflare.com/) - - -## Star 历史 - -![Stargazers over time](https://api.star-history.com/svg?repos=Ink-Osier/PandoraToV1Api&type=Date) \ No newline at end of file diff --git a/data/config.json b/data/config.json index 1772929..aa3a32f 100644 --- a/data/config.json +++ b/data/config.json @@ -24,6 +24,10 @@ "enabled_bing_reference_output": "false", "enabled_plugin_output": "false" }, + "refresh_ToAccess": { + "enableOai":"true", + "ninja_refreshToAccess_Url": "" + }, "redis": { "host": "redis", "port": 6379, diff --git a/docker-compose.yml b/docker-compose.yml index 20e41c2..6c6e89e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: '3' services: backend-to-api: - image: wizerd/pandora-to-api:latest + image: yangclivia/pandora-to-api:latest restart: always ports: - "50011:33333" @@ -17,4 +17,4 @@ services: command: redis-server --appendonly yes volumes: - ./redis-data:/data - \ No newline at end of file + diff --git a/main.py b/main.py index 49d913c..5b73015 100644 --- a/main.py +++ b/main.py @@ -26,6 +26,7 @@ import base64 from fake_useragent import UserAgent import os from urllib.parse import urlparse +from flask_apscheduler import APScheduler # 读取配置文件 @@ -33,6 +34,7 @@ def load_config(file_path): with open(file_path, 'r', encoding='utf-8') as file: return json.load(file) + CONFIG = load_config('./data/config.json') LOG_LEVEL = CONFIG.get('log_level', 'INFO').upper() @@ -41,6 +43,7 @@ NEED_LOG_TO_FILE = CONFIG.get('need_log_to_file', 'true').lower() == 'true' # 使用 get 方法获取配置项,同时提供默认值 BASE_URL = CONFIG.get('upstream_base_url', '') PROXY_API_PREFIX = CONFIG.get('upstream_api_prefix', '') + if PROXY_API_PREFIX != '': PROXY_API_PREFIX = "/" + PROXY_API_PREFIX UPLOAD_BASE_URL = CONFIG.get('backend_container_url', '') @@ -58,7 +61,12 @@ BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT = BOT_MODE.get('enabled_plugin_output', 'fals BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT = BOT_MODE.get('enabled_plain_image_url_output', 'false').lower() == 'true' -NEED_DELETE_CONVERSATION_AFTER_RESPONSE = CONFIG.get('need_delete_conversation_after_response', 'true').lower() == 'true' +REFRESH_TOACCESS = CONFIG.get('refresh_ToAccess', {}) +REFRESH_TOACCESS_ENABLEOAI = REFRESH_TOACCESS.get('enableOai', 'true').lower() == 'true' +REFRESH_TOACCESS_NINJA_REFRESHTOACCESS_URL = REFRESH_TOACCESS.get('ninja_refreshToAccess_Url', '') + +NEED_DELETE_CONVERSATION_AFTER_RESPONSE = CONFIG.get('need_delete_conversation_after_response', + 'true').lower() == 'true' USE_OAIUSERCONTENT_URL = CONFIG.get('use_oaiusercontent_url', 'false').lower() == 'true' @@ -79,6 +87,10 @@ REDIS_CONFIG_DB = REDIS_CONFIG.get('db', 0) REDIS_CONFIG_POOL_SIZE = REDIS_CONFIG.get('pool_size', 10) REDIS_CONFIG_POOL_TIMEOUT = REDIS_CONFIG.get('pool_timeout', 30) +# 定义全部变量,用于缓存refresh_token和access_token +# 其中refresh_token 为 key +# access_token 为 value +refresh_dict = {} # 设置日志级别 log_level_dict = { @@ -94,17 +106,15 @@ log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] - %(message)s') logger = logging.getLogger() logger.setLevel(log_level_dict.get(LOG_LEVEL, logging.DEBUG)) - - import redis # 假设您已经有一个Redis客户端的实例 redis_client = redis.StrictRedis(host=REDIS_CONFIG_HOST, - port=REDIS_CONFIG_PORT, - password=REDIS_CONFIG_PASSWORD, - db=REDIS_CONFIG_DB, - retry_on_timeout=True - ) + port=REDIS_CONFIG_PORT, + password=REDIS_CONFIG_PASSWORD, + db=REDIS_CONFIG_DB, + retry_on_timeout=True + ) # 如果环境变量指示需要输出到文件 if NEED_LOG_TO_FILE: @@ -121,6 +131,7 @@ logger.addHandler(stream_handler) # 创建FakeUserAgent对象 ua = UserAgent() + def generate_unique_id(prefix): # 生成一个随机的 UUID random_uuid = uuid.uuid4() @@ -141,11 +152,13 @@ def find_model_config(model_name): return config return None + # 从 gpts.json 读取配置 def load_gpts_config(file_path): with open(file_path, 'r', encoding='utf-8') as file: return json.load(file) + # 根据 ID 发送请求并获取配置信息 def fetch_gizmo_info(base_url, proxy_api_prefix, model_id): url = f"{base_url}{proxy_api_prefix}/backend-api/gizmos/{model_id}" @@ -160,6 +173,7 @@ def fetch_gizmo_info(base_url, proxy_api_prefix, model_id): else: return None + # gpts_configurations = [] # 将配置添加到全局列表 @@ -192,12 +206,13 @@ def add_config_to_global_list(base_url, proxy_api_prefix, gpts_data): 'config': gizmo_info }) + def generate_gpts_payload(model, messages): model_config = find_model_config(model) if model_config: gizmo_info = model_config['config'] gizmo_id = gizmo_info['gizmo']['id'] - + payload = { "action": "next", "messages": messages, @@ -217,11 +232,13 @@ def generate_gpts_payload(model, messages): else: return None + # 创建 Flask 应用 app = Flask(__name__) CORS(app, resources={r"/images/*": {"origins": "*"}}) - - +scheduler = APScheduler() +scheduler.init_app(app) +scheduler.start() # PANDORA_UPLOAD_URL = 'files.pandoranext.com' @@ -229,23 +246,26 @@ CORS(app, resources={r"/images/*": {"origins": "*"}}) VERSION = '0.7.6' # VERSION = 'test' UPDATE_INFO = '支持proxy参数' + + # UPDATE_INFO = '【仅供临时测试使用】 ' # 解析响应中的信息 def parse_oai_ip_info(): tmp_ua = ua.random - res = requests.get("https://auth0.openai.com/cdn-cgi/trace", headers={"User-Agent":tmp_ua}, proxies=proxies) + res = requests.get("https://auth0.openai.com/cdn-cgi/trace", headers={"User-Agent": tmp_ua}, proxies=proxies) lines = res.text.strip().split("\n") info_dict = {line.split('=')[0]: line.split('=')[1] for line in lines if '=' in line} return {key: info_dict[key] for key in ["ip", "loc", "colo", "warp"] if key in info_dict} + with app.app_context(): global gpts_configurations # 移到作用域的最开始 global proxies global proxy_type global proxy_host global proxy_port - + # 获取环境变量 proxy_url = CONFIG.get('proxy', None) @@ -310,9 +330,6 @@ with app.app_context(): logger.info(f"enabled_bing_reference_output: {BOT_MODE_ENABLED_BING_REFERENCE_OUTPUT}") logger.info(f"enabled_plugin_output: {BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT}") - - - if not BASE_URL: raise Exception('upstream_base_url is not set') else: @@ -329,14 +346,13 @@ with app.app_context(): if not os.path.exists('./files'): os.makedirs('./files') - if not UPLOAD_BASE_URL: if USE_OAIUSERCONTENT_URL: logger.info("backend_container_url 未设置,将使用 oaiusercontent.com 作为图片域名") else: logger.warning("backend_container_url 未设置,图片生成功能将无法正常使用") - - + + else: logger.info(f"backend_container_url: {UPLOAD_BASE_URL}") @@ -355,7 +371,7 @@ with app.app_context(): logger.info(f'绘图接口 URI: /{API_PREFIX}/v1/images/generations') logger.info(f"need_delete_conversation_after_response: {NEED_DELETE_CONVERSATION_AFTER_RESPONSE}") - + logger.info(f"use_oaiusercontent_url: {USE_OAIUSERCONTENT_URL}") logger.info(f"use_pandora_file_server: False") @@ -369,8 +385,6 @@ with app.app_context(): logger.info(f"==========================================") - - # 更新 gpts_configurations 列表,支持多个映射 gpts_configurations = [] for name in GPT_4_S_New_Names: @@ -389,7 +403,6 @@ with app.app_context(): "ori_name": "gpt-3.5-turbo" }) - logger.info(f"GPTS 配置信息") # 加载配置并添加到全局列表 @@ -423,7 +436,7 @@ def get_token(): full_url = f"{url}/api/arkose/token" payload = {'type': 'gpt-4'} - + try: response = requests.post(full_url, data=payload) if response.status_code == 200: @@ -442,22 +455,25 @@ def get_token(): raise Exception("获取 arkose token 失败") return None + import os + def get_image_dimensions(file_content): with Image.open(BytesIO(file_content)) as img: return img.width, img.height + def determine_file_use_case(mime_type): multimodal_types = ["image/jpeg", "image/webp", "image/png", "image/gif"] - my_files_types = ["text/x-php", "application/msword", "text/x-c", "text/html", + my_files_types = ["text/x-php", "application/msword", "text/x-c", "text/html", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "application/json", "text/javascript", "application/pdf", + "application/json", "text/javascript", "application/pdf", "text/x-java", "text/x-tex", "text/x-typescript", "text/x-sh", "text/x-csharp", "application/vnd.openxmlformats-officedocument.presentationml.presentation", - "text/x-c++", "application/x-latext", "text/markdown", "text/plain", + "text/x-c++", "application/x-latext", "text/markdown", "text/plain", "text/x-ruby", "text/x-script.python"] - + if mime_type in multimodal_types: return "multimodal" elif mime_type in my_files_types: @@ -465,6 +481,7 @@ def determine_file_use_case(mime_type): else: return "ace_upload" + def upload_file(file_content, mime_type, api_key): logger.debug("文件上传开始") @@ -487,8 +504,6 @@ def upload_file(file_content, mime_type, api_key): file_name = f"{sha256_hash}{file_extension}" logger.debug(f"文件名: {file_name}") - - logger.debug(f"Use Case: {determine_file_use_case(mime_type)}") if determine_file_use_case(mime_type) == "ace_upload": @@ -547,6 +562,7 @@ def upload_file(file_content, mime_type, api_key): "height": height } + def get_file_metadata(file_content, mime_type, api_key): sha256_hash = hashlib.sha256(file_content).hexdigest() logger.debug(f"sha256_hash: {sha256_hash}") @@ -626,6 +642,7 @@ def get_file_extension(mime_type): } return extension_mapping.get(mime_type, "") + my_files_types = [ "text/x-php", "application/msword", "text/x-c", "text/html", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", @@ -636,6 +653,7 @@ my_files_types = [ "text/x-ruby", "text/x-script.python" ] + # 定义发送请求的函数 def send_text_prompt_and_get_response(messages, api_key, stream, model): url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/conversation" @@ -738,7 +756,7 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): } formatted_messages.append(formatted_message) logger.critical(f"formatted_message: {formatted_message}") - + else: # 处理单个文本消息的情况 formatted_message = { @@ -767,11 +785,11 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): "action": "next", "messages": formatted_messages, "parent_message_id": str(uuid.uuid4()), - "model":"gpt-4", + "model": "gpt-4", "timezone_offset_min": -480, - "suggestions":[], + "suggestions": [], "history_and_training_disabled": False, - "conversation_mode":{"kind":"primary_assistant"},"force_paragen":False,"force_rate_limit":False + "conversation_mode": {"kind": "primary_assistant"}, "force_paragen": False, "force_rate_limit": False } elif ori_model_name == 'gpt-4-mobile': payload = { @@ -779,13 +797,17 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): "action": "next", "messages": formatted_messages, "parent_message_id": str(uuid.uuid4()), - "model":"gpt-4-mobile", + "model": "gpt-4-mobile", "timezone_offset_min": -480, - "suggestions":["Give me 3 ideas about how to plan good New Years resolutions. Give me some that are personal, family, and professionally-oriented.","Write a text asking a friend to be my plus-one at a wedding next month. I want to keep it super short and casual, and offer an out.","Design a database schema for an online merch store.","Compare Gen Z and Millennial marketing strategies for sunglasses."], + "suggestions": [ + "Give me 3 ideas about how to plan good New Years resolutions. Give me some that are personal, family, and professionally-oriented.", + "Write a text asking a friend to be my plus-one at a wedding next month. I want to keep it super short and casual, and offer an out.", + "Design a database schema for an online merch store.", + "Compare Gen Z and Millennial marketing strategies for sunglasses."], "history_and_training_disabled": False, - "conversation_mode":{"kind":"primary_assistant"},"force_paragen":False,"force_rate_limit":False + "conversation_mode": {"kind": "primary_assistant"}, "force_paragen": False, "force_rate_limit": False } - elif ori_model_name =='gpt-3.5-turbo': + elif ori_model_name == 'gpt-3.5-turbo': payload = { # 构建 payload "action": "next", @@ -799,13 +821,13 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): "Come up with 5 concepts for a retro-style arcade game.", "I have a photoshoot tomorrow. Can you recommend me some colors and outfit options that will look good on camera?" ], - "history_and_training_disabled":False, - "arkose_token":None, + "history_and_training_disabled": False, + "arkose_token": None, "conversation_mode": { "kind": "primary_assistant" }, - "force_paragen":False, - "force_rate_limit":False + "force_paragen": False, + "force_rate_limit": False } else: payload = generate_gpts_payload(model, formatted_messages) @@ -824,6 +846,7 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): # print(response) return response + def delete_conversation(conversation_id, api_key): logger.info(f"准备删除的会话id: {conversation_id}") if not NEED_DELETE_CONVERSATION_AFTER_RESPONSE: @@ -842,8 +865,11 @@ def delete_conversation(conversation_id, api_key): else: logger.error(f"PATCH 请求失败: {response.text}") + from PIL import Image import io + + def save_image(image_data, path='images'): try: # print(f"image_data: {image_data}") @@ -866,15 +892,16 @@ def save_image(image_data, path='images'): logger.error(f"保存图片时出现异常: {e}") - def unicode_to_chinese(unicode_string): # 首先将字符串转换为标准的 JSON 格式字符串 json_formatted_str = json.dumps(unicode_string) # 然后将 JSON 格式的字符串解析回正常的字符串 return json.loads(json_formatted_str) + import re + # 辅助函数:检查是否为合法的引用格式或正在构建中的引用格式 def is_valid_citation_format(text): # 完整且合法的引用格式,允许紧跟另一个起始引用标记 @@ -882,7 +909,7 @@ def is_valid_citation_format(text): return True # 完整且合法的引用格式 - + if re.fullmatch(r'\u3010\d+\u2020(source|\u6765\u6e90)\u3011', text): return True @@ -893,6 +920,7 @@ def is_valid_citation_format(text): # 不合法的格式 return False + # 辅助函数:检查是否为完整的引用格式 # 检查是否为完整的引用格式 def is_complete_citation_format(text): @@ -911,7 +939,8 @@ def replace_complete_citation(text, citations): logger.debug(f"citation: {citation}") if cited_message_idx == int(citation_number): url = citation.get("metadata", {}).get("url", "") - if ((BOT_MODE_ENABLED == False) or (BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_BING_REFERENCE_OUTPUT == True)): + if ((BOT_MODE_ENABLED == False) or ( + BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_BING_REFERENCE_OUTPUT == True)): return f"[[{citation_number}]({url})]" else: return "" @@ -941,26 +970,28 @@ def replace_complete_citation(text, citations): if is_potential_citation: replaced_text = replaced_text[:-len(remaining_text)] - return replaced_text, remaining_text, is_potential_citation + def is_valid_sandbox_combined_corrected_final_v2(text): # 更新正则表达式以包含所有合法格式 patterns = [ - r'.*\(sandbox:\/[^)]*\)?', # sandbox 后跟路径,包括不完整路径 - r'.*\(', # 只有 "(" 也视为合法格式 - r'.*\(sandbox(:|$)', # 匹配 "(sandbox" 或 "(sandbox:",确保后面不跟其他字符或字符串结束 - r'.*\(sandbox:.*\n*', # 匹配 "(sandbox:" 后跟任意数量的换行符 + r'.*\(sandbox:\/[^)]*\)?', # sandbox 后跟路径,包括不完整路径 + r'.*\(', # 只有 "(" 也视为合法格式 + r'.*\(sandbox(:|$)', # 匹配 "(sandbox" 或 "(sandbox:",确保后面不跟其他字符或字符串结束 + r'.*\(sandbox:.*\n*', # 匹配 "(sandbox:" 后跟任意数量的换行符 ] # 检查文本是否符合任一合法格式 return any(bool(re.fullmatch(pattern, text)) for pattern in patterns) + def is_complete_sandbox_format(text): # 完整格式应该类似于 (sandbox:/xx/xx/xx 或 (sandbox:/xx/xx) pattern = r'.*\(sandbox\:\/[^)]+\)\n*' # 匹配 "(sandbox:" 后跟任意数量的换行符 return bool(re.fullmatch(pattern, text)) + import urllib.parse from urllib.parse import unquote @@ -1006,10 +1037,10 @@ def replace_sandbox(text, conversation_id, message_id, api_key): def timestamp_filename(filename): # 在文件名前加上当前时间戳 timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - + # 解码URL编码的filename decoded_filename = unquote(filename) - + return f"{timestamp}_{decoded_filename}" def download_file(download_url, filename): @@ -1028,8 +1059,8 @@ def replace_sandbox(text, conversation_id, message_id, api_key): return replaced_text - -def generate_actions_allow_payload(author_role, author_name, target_message_id, operation_hash, conversation_id, message_id, model): +def generate_actions_allow_payload(author_role, author_name, target_message_id, operation_hash, conversation_id, + message_id, model): model_config = find_model_config(model) if model_config: gizmo_info = model_config['config'] @@ -1083,8 +1114,10 @@ def generate_actions_allow_payload(author_role, author_name, target_message_id, else: return None + # 定义发送请求的函数 -def send_allow_prompt_and_get_response(message_id, author_role, author_name, target_message_id, operation_hash, conversation_id, model, api_key): +def send_allow_prompt_and_get_response(message_id, author_role, author_name, target_message_id, operation_hash, + conversation_id, model, api_key): url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/conversation" headers = { "Authorization": f"Bearer {api_key}" @@ -1095,7 +1128,8 @@ def send_allow_prompt_and_get_response(message_id, author_role, author_name, tar if model_config: # 检查是否有 ori_name ori_model_name = model_config.get('ori_name', model) - payload = generate_actions_allow_payload(author_role, author_name, target_message_id, operation_hash, conversation_id, message_id, model) + payload = generate_actions_allow_payload(author_role, author_name, target_message_id, operation_hash, + conversation_id, message_id, model) token = None payload['arkose_token'] = token logger.debug(f"payload: {payload}") @@ -1106,7 +1140,7 @@ def send_allow_prompt_and_get_response(message_id, author_role, author_name, tar logger.debug(f"payload: {payload}") logger.info(f"继续请求上游接口") try: - response = requests.post(url, headers=headers, json=payload, stream=True, verify=False, timeout=30) + response = requests.post(url, headers=headers, json=payload, stream=True, verify=False, timeout=30) logger.info(f"成功与上游接口建立连接") # print(response) return response @@ -1114,7 +1148,12 @@ def send_allow_prompt_and_get_response(message_id, author_role, author_name, tar # 处理超时情况 logger.error("请求超时") -def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, timestamp, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, all_new_text): + +def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, + response_format, timestamp, first_output, last_full_text, last_full_code, last_full_code_result, + last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, + file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, + all_new_text): # print(f"data_json: {data_json}") message = data_json.get("message", {}) @@ -1160,7 +1199,8 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key break conversation_id = data_json.get("conversation_id", "") - upstream_response = send_allow_prompt_and_get_response(message_id, author_role, author_name, target_message_id, operation_hash, conversation_id, model, api_key) + upstream_response = send_allow_prompt_and_get_response(message_id, author_role, author_name, target_message_id, + operation_hash, conversation_id, model, api_key) if upstream_response == None: complete_data = 'data: [DONE]\n\n' logger.info(f"会话超时") @@ -1220,7 +1260,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key { "index": 0, "delta": { - "content": ''.join("```\n{\n\"error\": \""+ tmp_message +"\"\n}\n```") + "content": ''.join("```\n{\n\"error\": \"" + tmp_message + "\"\n}\n```") }, "finish_reason": None } @@ -1228,14 +1268,13 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key } q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' data_queue.put(q_data) - q_data = complete_data - data_queue.put(('all_new_text', "```\n{\n\"error\": \""+ tmp_message +"\"\n}```")) + data_queue.put(('all_new_text', "```\n{\n\"error\": \"" + tmp_message + "\"\n}```")) data_queue.put(q_data) last_data_time[0] = time.time() return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None - + logger.info(f"action确认事件处理成功, 上游响应数据结构类型: {type(upstream_response)}") upstream_response_json = upstream_response.json() @@ -1256,7 +1295,6 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, upstream_response_id - if (role == "user" or message_status == "finished_successfully" or role == "system") and role != "tool": # 如果是用户发来的消息,直接舍弃 return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None @@ -1271,7 +1309,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key new_text = "" is_img_message = False parts = content.get("parts", []) - for part in parts: + for part in parts: try: # print(f"part: {part}") # print(f"part type: {part.get('content_type')}") @@ -1291,7 +1329,8 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key download_url = image_response.json().get('download_url') logger.debug(f"download_url: {download_url}") if USE_OAIUSERCONTENT_URL == True: - if ((BOT_MODE_ENABLED == False) or (BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): + if ((BOT_MODE_ENABLED == False) or ( + BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): new_text = f"\n![image]({download_url})\n[下载链接]({download_url})\n" if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: if all_new_text != "": @@ -1323,7 +1362,8 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key # 使用base64编码图片 image_base64 = base64.b64encode(image_data).decode('utf-8') data_queue.put(('image_url', image_base64)) - if ((BOT_MODE_ENABLED == False) or (BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): + if ((BOT_MODE_ENABLED == False) or ( + BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): new_text = f"\n![image]({UPLOAD_BASE_URL}/{today_image_url})\n[下载链接]({UPLOAD_BASE_URL}/{today_image_url})\n" if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: if all_new_text != "": @@ -1343,7 +1383,6 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key logger.error(f"获取图片下载链接失败: {image_response.text}") except: pass - if is_img_message == False: # print(f"data_json: {data_json}") @@ -1364,7 +1403,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key last_full_code = full_code # 更新完整代码以备下次比较 if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: new_text = "" - + elif last_content_type == "code" and content_type != "code" and content_type != None: full_code = ''.join(content.get("text", "")) new_text = "\n```\n" + full_code[len(last_full_code):] @@ -1384,7 +1423,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key last_full_code = full_code # 更新完整代码以备下次比较 if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: new_text = "" - + else: # 只获取新的 parts parts = content.get("parts", []) @@ -1409,12 +1448,13 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key if is_complete_citation_format(citation_buffer): # 替换完整的引用格式 - replaced_text, remaining_text, is_potential_citation = replace_complete_citation(citation_buffer, citations) + replaced_text, remaining_text, is_potential_citation = replace_complete_citation( + citation_buffer, citations) # print(replaced_text) # 输出替换后的文本 - + new_text = replaced_text - - if(is_potential_citation): + + if (is_potential_citation): citation_buffer = remaining_text else: citation_accumulating = False @@ -1432,7 +1472,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key if "(" in new_text and not file_output_accumulating and not citation_accumulating: file_output_accumulating = True file_output_buffer = file_output_buffer + new_text - + logger.debug(f"开始积累文件输出: {file_output_buffer}") logger.debug(f"file_output_buffer: {file_output_buffer}") logger.debug(f"new_text: {new_text}") @@ -1463,7 +1503,6 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key # Python 工具执行输出特殊处理 if role == "tool" and name == "python" and last_content_type != "execution_output" and content_type != None: - full_code_result = ''.join(content.get("text", "")) new_text = "`Result:` \n```\n" + full_code_result[len(last_full_code_result):] @@ -1484,7 +1523,8 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key new_text = "" tmp_new_text = new_text if execution_output_image_url_buffer != "": - if ((BOT_MODE_ENABLED == False) or (BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): + if ((BOT_MODE_ENABLED == False) or ( + BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): logger.debug(f"BOT_MODE_ENABLED: {BOT_MODE_ENABLED}") logger.debug(f"BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT: {BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT}") new_text = tmp_new_text + f"![image]({execution_output_image_url_buffer})\n[下载链接]({execution_output_image_url_buffer})\n" @@ -1493,9 +1533,9 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key logger.debug(f"BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT: {BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT}") new_text = tmp_new_text + f"图片链接:{execution_output_image_url_buffer}\n" execution_output_image_url_buffer = "" - + if content_type == "code": - new_text = new_text + "\n```\n" + new_text = new_text + "\n```\n" # print(f"full_code_result: {full_code_result}") # print(f"last_full_code_result: {last_full_code_result}") # print(f"new_text: {new_text}") @@ -1512,7 +1552,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key last_full_code_result = full_code_result # 其余Action执行输出特殊处理 - if role == "tool" and name != "python" and name != "dalle.text2im" and last_content_type != "execution_output" and content_type != None: + if role == "tool" and name != "python" and name != "dalle.text2im" and last_content_type != "execution_output" and content_type != None: new_text = "" if last_content_type == "code": if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: @@ -1546,7 +1586,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key logger.debug(f"download_url: {download_url}") if USE_OAIUSERCONTENT_URL == True: execution_output_image_url_buffer = download_url - + else: # 从URL下载图片 # image_data = requests.get(download_url).content @@ -1557,7 +1597,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key image_data = image_download_response.content today_image_url = save_image(image_data) # 保存图片,并获取文件名 execution_output_image_url_buffer = f"{UPLOAD_BASE_URL}/{today_image_url}" - + else: logger.error(f"下载图片失败: {image_download_response.text}") @@ -1586,7 +1626,7 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key "choices": [ { "index": 0, - "delta": {"role":"assistant"}, + "delta": {"role": "assistant"}, "finish_reason": None } ] @@ -1620,7 +1660,6 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key tmp_t = new_text.replace('\n', '\\n') logger.info(f"Send: {tmp_t}") - # if new_text != None: q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' data_queue.put(q_data) @@ -1629,10 +1668,13 @@ def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None + import websocket import base64 -def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages): + +def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, + messages): headers = { "Sec-Ch-Ua-Mobile": "?0", "User-Agent": ua.random @@ -1650,7 +1692,7 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m "citation_buffer": "", "citation_accumulating": False, "file_output_buffer": "", - "file_output_accumulating": False, + "file_output_accumulating": False, "execution_output_image_url_buffer": "", "execution_output_image_id_buffer": "", "is_sse": False, @@ -1682,7 +1724,31 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m data_json = json.loads(complete_data.replace('data: ', '')) logger.debug(f"data_json: {data_json}") - context["all_new_text"], context["first_output"], context["last_full_text"], context["last_full_code"], context["last_full_code_result"], context["last_content_type"], context["conversation_id"], context["citation_buffer"], context["citation_accumulating"], context["file_output_buffer"], context["file_output_accumulating"], context["execution_output_image_url_buffer"], context["execution_output_image_id_buffer"], allow_id = process_data_json(data_json, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, context["timestamp"], context["first_output"], context["last_full_text"], context["last_full_code"], context["last_full_code_result"], context["last_content_type"], context["conversation_id"], context["citation_buffer"], context["citation_accumulating"], context["file_output_buffer"], context["file_output_accumulating"], context["execution_output_image_url_buffer"], context["execution_output_image_id_buffer"], context["all_new_text"]) + context["all_new_text"], context["first_output"], context["last_full_text"], context["last_full_code"], \ + context["last_full_code_result"], context["last_content_type"], context["conversation_id"], context[ + "citation_buffer"], context["citation_accumulating"], context["file_output_buffer"], context[ + "file_output_accumulating"], context["execution_output_image_url_buffer"], context[ + "execution_output_image_id_buffer"], allow_id = process_data_json(data_json, data_queue, stop_event, + last_data_time, api_key, + chat_message_id, model, + response_format, + context["timestamp"], + context["first_output"], + context["last_full_text"], + context["last_full_code"], + context["last_full_code_result"], + context["last_content_type"], + context["conversation_id"], + context["citation_buffer"], + context["citation_accumulating"], + context["file_output_buffer"], + context[ + "file_output_accumulating"], + context[ + "execution_output_image_url_buffer"], + context[ + "execution_output_image_id_buffer"], + context["all_new_text"]) if allow_id: context["response_id"] = allow_id @@ -1696,7 +1762,7 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m q_data = complete_data data_queue.put(q_data) stop_event.set() - ws.close() + ws.close() def on_error(ws, error): logger.error(error) @@ -1706,7 +1772,8 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m def on_open(ws): logger.debug(f"on_open: wss") - upstream_response = send_text_prompt_and_get_response(context["messages"], context["api_key"], True, context["model"]) + upstream_response = send_text_prompt_and_get_response(context["messages"], context["api_key"], True, + context["model"]) # upstream_wss_url = None # 检查 Content-Type 是否为 SSE 响应 content_type = upstream_response.headers.get('Content-Type') @@ -1720,7 +1787,8 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m return else: if upstream_response.status_code != 200: - logger.error(f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}") + logger.error( + f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}") complete_data = 'data: [DONE]\n\n' timestamp = context["timestamp"] @@ -1755,6 +1823,7 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m context["response_id"] = upstream_response_id except json.JSONDecodeError: pass + def run(*args): while True: if stop_event.is_set(): @@ -1764,10 +1833,10 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m logger.debug(f"start wss...") ws = websocket.WebSocketApp(wss_url, - on_message = on_message, - on_error = on_error, - on_close = on_close, - on_open = on_open) + on_message=on_message, + on_error=on_error, + on_close=on_close, + on_open=on_open) ws.on_open = on_open # 使用HTTP代理 if proxy_type: @@ -1779,11 +1848,12 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m logger.debug(f"end wss...") if context["is_sse"] == True: logger.debug(f"process sse...") - old_data_fetcher(context["upstream_response"], data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format) + old_data_fetcher(context["upstream_response"], data_queue, stop_event, last_data_time, api_key, chat_message_id, + model, response_format) - -def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format): +def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, + response_format): all_new_text = "" first_output = True @@ -1820,8 +1890,6 @@ def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, buffer = re.sub(r'data: \d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{6}(\r\n|\r|\n){2}', '', buffer) # print("应用正则表达式之后的 buffer:", buffer.replace('\n', '\\n')) - - while 'data:' in buffer and '\n\n' in buffer: end_index = buffer.index('\n\n') + 2 complete_data, buffer = buffer[:end_index], buffer[end_index:] @@ -1830,7 +1898,12 @@ def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, data_json = json.loads(complete_data.replace('data: ', '')) logger.debug(f"data_json: {data_json}") # print(f"data_json: {data_json}") - all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, allow_id = process_data_json(data_json, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, timestamp, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, all_new_text) + all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, allow_id = process_data_json( + data_json, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, + response_format, timestamp, first_output, last_full_text, last_full_code, + last_full_code_result, last_content_type, conversation_id, citation_buffer, + citation_accumulating, file_output_buffer, file_output_accumulating, + execution_output_image_url_buffer, execution_output_image_id_buffer, all_new_text) except json.JSONDecodeError: # print("JSON 解析错误") logger.info(f"发送数据: {complete_data}") @@ -1862,7 +1935,7 @@ def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, # print(f"发送数据: {tmp}") # 累积 new_text all_new_text += citation_buffer - q_data = 'data: ' + json.dumps(new_data) + '\n\n' + q_data = 'data: ' + json.dumps(new_data) + '\n\n' data_queue.put(q_data) last_data_time[0] = time.time() if buffer: @@ -1873,20 +1946,20 @@ def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, logger.info(f"最后的缓存数据: {buffer_json}") error_message = buffer_json.get("detail", {}).get("message", "未知错误") error_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": "error", - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join("```\n" + error_message + "\n```") - }, - "finish_reason": None - } - ] + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": "error", + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join("```\n" + error_message + "\n```") + }, + "finish_reason": None } + ] + } tmp = 'data: ' + json.dumps(error_data) + '\n\n' logger.info(f"发送最后的数据: {tmp}") # 累积 new_text @@ -1904,22 +1977,22 @@ def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, # print("JSON 解析错误") logger.info(f"发送最后的数据: {buffer}") error_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": "error", - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join("```\n" + buffer + "\n```") - }, - "finish_reason": None - } - ] + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": "error", + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join("```\n" + buffer + "\n```") + }, + "finish_reason": None } + ] + } tmp = 'data: ' + json.dumps(error_data) + '\n\n' - q_data = tmp + q_data = tmp data_queue.put(q_data) last_data_time[0] = time.time() complete_data = 'data: [DONE]\n\n' @@ -1937,6 +2010,7 @@ def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, data_queue.put(q_data) last_data_time[0] = time.time() + def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages): all_new_text = "" @@ -1957,22 +2031,24 @@ def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_i file_output_accumulating = False execution_output_image_url_buffer = "" execution_output_image_id_buffer = "" - + wss_url = register_websocket(api_key) # response_json = upstream_response.json() # wss_url = response_json.get("wss_url", None) # logger.info(f"wss_url: {wss_url}") - # 如果存在 wss_url,使用 WebSocket 连接获取数据 + # 如果存在 wss_url,使用 WebSocket 连接获取数据 + + process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, + messages) - process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages) - while True: if stop_event.is_set(): logger.info(f"接受到停止信号,停止数据处理线程-外层") - + break + def register_websocket(api_key): url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/register-websocket" headers = { @@ -1989,9 +2065,9 @@ def register_websocket(api_key): return None -def keep_alive(last_data_time, stop_event, queue, model, chat_message_id): +def keep_alive(last_data_time, stop_event, queue, model, chat_message_id): while not stop_event.is_set(): - if time.time() - last_data_time[0] >=1: + if time.time() - last_data_time[0] >= 1: # logger.debug(f"发送保活消息") # 当前时间戳 timestamp = int(time.time()) @@ -2018,8 +2094,10 @@ def keep_alive(last_data_time, stop_event, queue, model, chat_message_id): logger.debug(f"接受到停止信号,停止保活线程") return + import tiktoken + def count_tokens(text, model_name): """ Count the number of tokens for a given text using a specified model. @@ -2039,6 +2117,7 @@ def count_tokens(text, model_name): token_list = encoder.encode(text) return len(token_list) + def count_total_input_words(messages, model): """ Count the total number of words in all messages' content. @@ -2057,8 +2136,74 @@ def count_total_input_words(messages, model): return total_words + +# 官方refresh_token刷新access_token +def oaiGetAccessToken(refresh_token): + url = "https://auth0.openai.com/oauth/token" + headers = { + "Content-Type": "application/json" + } + data = { + "redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", + "grant_type": "refresh_token", + "client_id": "pdlLIX2Y72MIl2rhLhTE9VV9bN905kBh", + "refresh_token": refresh_token + } + try: + response = requests.post(url, headers=headers, json=data) + # 如果响应的状态码不是 200,将引发 HTTPError 异常 + response.raise_for_status() + + # 拿到access_token + json_response = response.json() + access_token = json_response.get('access_token') + + # 检查 access_token 是否有效 + if not access_token or not access_token.startswith("eyJhb"): + logger.error("access_token 无效.") + return None + + return access_token + + except requests.HTTPError as http_err: + logger.error(f"HTTP error occurred: {http_err}") + except Exception as err: + logger.error(f"Other error occurred: {err}") + return None + + +# ninja获得access_token +def ninjaGetAccessToken(refresh_token, getAccessTokenUrl): + try: + logger.info("将通过这个网址请求access_token:" + getAccessTokenUrl) + headers = {"Authorization": "Bearer " + refresh_token} + response = requests.post(getAccessTokenUrl, headers=headers) + if not response.ok: + logger.error("Request 失败: " + response.text.strip()) + return None + access_token = None + try: + jsonResponse = response.json() + access_token = jsonResponse.get("access_token") + except json.JSONDecodeError: + logger.exception("Failed to decode JSON response.") + if response.status_code == 200 and access_token and access_token.startswith("eyJhb"): + return access_token + except Exception as e: + logger.exception("获取access token失败.") + return None + + +# 添加缓存 +def add_to_dict(key, value): + global refresh_dict + refresh_dict[key] = value + + import threading import time + + # 定义 Flask 路由 @app.route(f'/{API_PREFIX}/v1/chat/completions' if API_PREFIX else '/v1/chat/completions', methods=['POST']) def chat_completions(): @@ -2075,17 +2220,27 @@ def chat_completions(): auth_header = request.headers.get('Authorization') if not auth_header or not auth_header.startswith('Bearer '): return jsonify({"error": "Authorization header is missing or invalid"}), 401 - api_key = auth_header.split(' ')[1] + api_key = auth_header.split(' ')[1] + if not api_key.startswith("eyJhb"): + if api_key in refresh_dict: + api_key = refresh_dict.get(api_key) + else: + if REFRESH_TOACCESS_ENABLEOAI: + refresh_token = api_key + api_key = oaiGetAccessToken(api_key) + else: + api_key = ninjaGetAccessToken(REFRESH_TOACCESS_NINJA_REFRESHTOACCESS_URL, api_key) + if not api_key.startswith("eyJhb"): + return jsonify({"error": "refresh_token is wrong or refresh_token url is wrong!"}), 401 + add_to_dict(refresh_token, api_key) logger.info(f"api_key: {api_key}") - # upstream_response = send_text_prompt_and_get_response(messages, api_key, stream, model) # 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text all_new_text = "" image_urls = [] - # 处理流式响应 def generate(): nonlocal all_new_text # 引用外部变量 @@ -2099,11 +2254,13 @@ def chat_completions(): conversation_id = '' # 启动数据处理线程 - fetcher_thread = threading.Thread(target=data_fetcher, args=(data_queue, stop_event, last_data_time, api_key, chat_message_id, model,"url", messages)) + fetcher_thread = threading.Thread(target=data_fetcher, args=( + data_queue, stop_event, last_data_time, api_key, chat_message_id, model, "url", messages)) fetcher_thread.start() # 启动保活线程 - keep_alive_thread = threading.Thread(target=keep_alive, args=(last_data_time, stop_event, data_queue, model, chat_message_id)) + keep_alive_thread = threading.Thread(target=keep_alive, + args=(last_data_time, stop_event, data_queue, model, chat_message_id)) keep_alive_thread.start() try: @@ -2134,9 +2291,9 @@ def chat_completions(): "model": model, "choices": [ { - "delta":{}, - "index":0, - "finish_reason":"stop" + "delta": {}, + "index": 0, + "finish_reason": "stop" } ] } @@ -2158,8 +2315,7 @@ def chat_completions(): # if conversation_id: # # print(f"准备删除的会话id: {conversation_id}") - # delete_conversation(conversation_id, api_key) - + # delete_conversation(conversation_id, api_key) if not stream: # 执行流式响应的生成函数来累积 all_new_text @@ -2199,7 +2355,7 @@ def chat_completions(): # 返回 JSON 响应 return jsonify(response_json) - else: + else: return Response(generate(), mimetype='text/event-stream') @@ -2213,7 +2369,7 @@ def images_generations(): accessible_model_list = get_accessible_model_list() if model not in accessible_model_list: return jsonify({"error": "model is not accessible"}), 401 - + prompt = data.get('prompt', '') prompt = DALLE_PROMPT_PREFIX + prompt @@ -2226,7 +2382,21 @@ def images_generations(): auth_header = request.headers.get('Authorization') if not auth_header or not auth_header.startswith('Bearer '): return jsonify({"error": "Authorization header is missing or invalid"}), 401 - api_key = auth_header.split(' ')[1] + api_key = auth_header.split(' ')[1] + if not api_key.startswith("eyJhb"): + if api_key in refresh_dict: + api_key = refresh_dict.get(api_key) + + else: + if REFRESH_TOACCESS_ENABLEOAI: + refresh_token = api_key + api_key = oaiGetAccessToken(api_key) + else: + api_key = ninjaGetAccessToken(REFRESH_TOACCESS_NINJA_REFRESHTOACCESS_URL, api_key) + if not api_key.startswith("eyJhb"): + return jsonify({"error": "refresh_token is wrong or refresh_token url is wrong!"}), 401 + add_to_dict(refresh_token, api_key) + logger.info(f"api_key: {api_key}") image_urls = [] @@ -2257,11 +2427,13 @@ def images_generations(): conversation_id = '' # 启动数据处理线程 - fetcher_thread = threading.Thread(target=data_fetcher, args=(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages)) + fetcher_thread = threading.Thread(target=data_fetcher, args=( + data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages)) fetcher_thread.start() # 启动保活线程 - keep_alive_thread = threading.Thread(target=keep_alive, args=(last_data_time, stop_event, data_queue, model, chat_message_id)) + keep_alive_thread = threading.Thread(target=keep_alive, + args=(last_data_time, stop_event, data_queue, model, chat_message_id)) keep_alive_thread.start() try: @@ -2298,7 +2470,7 @@ def images_generations(): # if conversation_id: # # print(f"准备删除的会话id: {conversation_id}") - # delete_conversation(conversation_id, cookie, x_authorization) + # delete_conversation(conversation_id, cookie, x_authorization) # 执行流式响应的生成函数来累积 all_new_text # 迭代生成器对象以执行其内部逻辑 @@ -2359,6 +2531,7 @@ def options_handler(): logger.info(f"Options Request") return Response(status=200) + @app.route('/', defaults={'path': ''}, methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"]) @app.route('/', methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"]) def catch_all(path): @@ -2378,6 +2551,7 @@ def get_image(filename): return "文件不存在哦!", 404 return send_from_directory('images', filename) + @app.route('/files/') @cross_origin() # 使用装饰器来允许跨域请求 def get_file(filename): @@ -2386,7 +2560,30 @@ def get_file(filename): return "文件不存在哦!", 404 return send_from_directory('files', filename) + + +# 内置自动刷新access_token +def updateRefresh_dict(): + success_num = 0 + error_num = 0 + logging.info("开始更新access_token.........") + for key in refresh_dict: + if REFRESH_TOACCESS_ENABLEOAI: + refresh_token = key + access_token = oaiGetAccessToken(key) + else: + access_token = ninjaGetAccessToken(REFRESH_TOACCESS_NINJA_REFRESHTOACCESS_URL, key) + if not access_token.startswith("eyJhb"): + logger.debug("refresh_token is wrong or refresh_token url is wrong!") + error_num += 1 + add_to_dict(refresh_token, access_token) + success_num += 1 + logging.info("更新成功: " + str(success_num) + ", 失败: " + str(error_num)) + +# 在 APScheduler 中注册你的任务 +# 每天3点自动刷新 +scheduler.add_job(id='updateRefresh_run', func=updateRefresh_dict, trigger='cron', hour=3, minute=0) + # 运行 Flask 应用 if __name__ == '__main__': app.run(host='0.0.0.0') -