From 0f90a47da4202e2f7636acf7baf93433b808e8c5 Mon Sep 17 00:00:00 2001
From: Wizerd
Date: Thu, 14 Dec 2023 11:00:58 +0800
Subject: [PATCH] [init] Initialize project
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 Dockerfile         |  19 ++
 Readme.md          |  29 +++
 docker-compose.yml |  22 ++
 main.py            | 575 +++++++++++++++++++++++++++++++++++++++++++++
 start.sh           |  10 +
 upload.py          |  17 ++
 6 files changed, 672 insertions(+)
 create mode 100644 Dockerfile
 create mode 100644 Readme.md
 create mode 100644 docker-compose.yml
 create mode 100644 main.py
 create mode 100755 start.sh
 create mode 100644 upload.py

diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..4530c86
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,19 @@
+# Use the official Python runtime as the base image
+FROM python:3.9-slim
+
+# Set the working directory to /app
+WORKDIR /app
+
+# Copy the current directory contents into the container at /app
+COPY . /app
+
+# Set environment variables
+ENV PYTHONUNBUFFERED=1
+
+RUN chmod +x /app/start.sh
+
+# Install the required dependencies
+RUN pip install --no-cache-dir flask gunicorn requests Pillow
+
+# Run the Flask app when the container starts
+CMD ["/app/start.sh"]
diff --git a/Readme.md b/Readme.md
new file mode 100644
index 0000000..599bbb6
--- /dev/null
+++ b/Readme.md
@@ -0,0 +1,29 @@
+# Project Overview
+
+This project was created to make it easy to combine [Pandora-Next](https://github.com/pandora-next/deploy) with various other projects.
+
+# Docker-Compose Deployment
+
+The repository already contains all required files and directories. Clone it, edit the environment variables in docker-compose.yml, then run `docker-compose up -d`.
+
+# Environment Variables
+
+- UPLOAD_BASE_URL: used to display images generated by the DALL·E model; set it to an Uploader container address that users of clients such as chatgpt-next-web can reach, e.g. http://127.0.0.1:50012
+
+# Example
+
+Taking a docker-compose deployment of the ChatGPT-Next-Web project as an example, here is a simple deployment configuration:
+
+```
+version: '3'
+services:
+  chatgpt-next-web:
+    image: yidadaa/chatgpt-next-web
+    ports:
+      - "50013:3000"
+    environment:
+      - OPENAI_API_KEY=
+      - BASE_URL=
+      - CUSTOM_MODELS=+gpt-4-s,+gpt-4-classic
+
+```
\ No newline at end of file
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..486954d
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,22 @@
+version: '3'
+
+services:
+  backend-to-api:
+    build: .
+    ports:
+      - "50011:33333"
+    environment:
+      - BASE_URL=
+      - PROXY_API_PREFIX=
+      - UPLOAD_BASE_URL=
+    volumes:
+      - .:/app
+
+  uploader:
+    build: .
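+    # Reuses the same image as backend-to-api, but overrides the entrypoint to run
+    # upload.py, which serves the files saved under ./images on port 23333
+    # (published as 50012 below).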
+ entrypoint: ["python3", "/app/upload.py"] + volumes: + - .:/app + ports: + - "50012:23333" + \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..05bbbd3 --- /dev/null +++ b/main.py @@ -0,0 +1,575 @@ +# 导入所需的库 +from flask import Flask, request, jsonify, Response +import requests +import uuid +import json +import time +from datetime import datetime +import os + + + + +# 创建 Flask 应用 +app = Flask(__name__) + + +# 添加环境变量配置 +BASE_URL = os.getenv('BASE_URL', '') +PROXY_API_PREFIX = os.getenv('PROXY_API_PREFIX', '') +UPLOAD_BASE_URL = os.getenv('UPLOAD_BASE_URL', '') + +with app.app_context(): + if not BASE_URL: + raise Exception('BASE_URL is not set') + else: + print(f"BASE_URL: {BASE_URL}") + if not PROXY_API_PREFIX: + raise Exception('PROXY_API_PREFIX is not set') + else: + print(f"PROXY_API_PREFIX: {PROXY_API_PREFIX}") + +# 定义获取 token 的函数 +def get_token(): + url = f"{BASE_URL}/{PROXY_API_PREFIX}/api/arkose/token" + payload = {'type': 'gpt-4'} + response = requests.post(url, data=payload) + if response.status_code == 200: + return response.json().get('token') + else: + return None + +import os + +# 定义发送请求的函数 +def send_text_prompt_and_get_response(messages, api_key, stream, model): + url = f"{BASE_URL}/{PROXY_API_PREFIX}/backend-api/conversation" + headers = { + "Authorization": f"Bearer {api_key}" + } + + formatted_messages = [] + for message in messages: + message_id = str(uuid.uuid4()) + formatted_message = { + "id": message_id, + "author": {"role": message.get("role")}, + "content": {"content_type": "text", "parts": [message.get("content")]}, + "metadata": {} + } + formatted_messages.append(formatted_message) + + payload = {} + + print(f"model: {model}") + + if model == 'gpt-4-classic': + payload = { + # 构建 payload + "action": "next", + "messages": formatted_messages, + "parent_message_id": str(uuid.uuid4()), + "model": "gpt-4-gizmo", + "timezone_offset_min": -480, + "history_and_training_disabled": False, + "conversation_mode": { + "gizmo": { + "gizmo": { + "id": "g-YyyyMT9XH", + "organization_id": "org-OROoM5KiDq6bcfid37dQx4z4", + "short_url": "g-YyyyMT9XH-chatgpt-classic", + "author": { + "user_id": "user-u7SVk5APwT622QC7DPe41GHJ", + "display_name": "ChatGPT", + "link_to":None, + "selected_display": "name", + "is_verified":True + }, + "voice": { + "id": "ember" + }, + "workspace_id":None, + "model":None, + "instructions":None, + "settings":None, + "display": { + "name": "ChatGPT Classic", + "description": "The latest version of GPT-4 with no additional capabilities", + "welcome_message": "Hello", + "prompt_starters":None, + "profile_picture_url": "", + "categories": [] + }, + "share_recipient": "marketplace", + "updated_at": "2023-11-26T17:46:07.341305+00:00", + "last_interacted_at": "2023-12-11T09:49:34.943245+00:00", + "tags": [ + "public", + "first_party" + ], + "version":None, + "live_version":None, + "training_disabled":None, + "allowed_sharing_recipients":None, + "review_info":None, + "appeal_info":None, + "vanity_metrics":None + }, + "tools": [], + "files": [], + "product_features": { + "attachments": { + "type": "retrieval", + "accepted_mime_types": [ + "text/x-script.python", + "application/x-latext", + "text/x-c++", + "text/javascript", + "text/x-java", + "text/x-typescript", + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + "text/x-csharp", + "text/plain", + "application/pdf", + "text/x-sh", + "text/markdown", + "text/x-c", + "text/x-ruby", + "text/x-tex", + "text/x-php", + 
"application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/json", + "text/html", + "application/msword" + ], + "image_mime_types": [ + "image/webp", + "image/jpeg", + "image/png", + "image/gif" + ], + "can_accept_all_mime_types":True + } + } + }, + "kind": "gizmo_interaction", + "gizmo_id": "g-YyyyMT9XH" + }, + "force_paragen":False, + "force_rate_limit":False + } + elif model == 'gpt-4-s': + payload = { + # 构建 payload + "action": "next", + "messages": formatted_messages, + "parent_message_id": str(uuid.uuid4()), + "model":"gpt-4", + "timezone_offset_min": -480, + "suggestions":[], + "history_and_training_disabled": False, + "conversation_mode":{"kind":"primary_assistant"},"force_paragen":False,"force_rate_limit":False + } + response = requests.post(url, headers=headers, json=payload, stream=stream) + # print(response) + return response + +def delete_conversation(conversation_id, api_key): + print(f"[{datetime.now()}] 准备删除的会话id: {conversation_id}") + if conversation_id: + patch_url = f"{BASE_URL}/{PROXY_API_PREFIX}/backend-api/conversation/{conversation_id}" + patch_headers = { + "Authorization": f"Bearer {api_key}", + } + patch_data = {"is_visible": False} + response = requests.patch(patch_url, headers=patch_headers, json=patch_data) + + if response.status_code == 200: + print(f"[{datetime.now()}] 删除会话 {conversation_id} 成功") + else: + print(f"[{datetime.now()}] PATCH 请求失败: {response.text}") + +from PIL import Image +import io +def save_image(image_data, path='./images'): + if not os.path.exists(path): + os.makedirs(path) + current_time = datetime.now().strftime('%Y%m%d%H%M%S') + filename = f'image_{current_time}.png' + print(f"filename: {filename}") + # 使用 PIL 打开图像数据 + with Image.open(io.BytesIO(image_data)) as image: + # 保存为 PNG 格式 + image.save(os.path.join(path, filename), 'PNG') + + print(f"保存图片成功: {filename}") + + return os.path.join(path, filename) + + +import re + +# 辅助函数:检查是否为合法的引用格式或正在构建中的引用格式 +def is_valid_citation_format(text): + # 完整且合法的引用格式,允许紧跟另一个起始引用标记 + if re.fullmatch(r'\u3010\d+\u2020(source|\u6765\u6e90)\u3011\u3010?', text): + return True + + # 完整且合法的引用格式 + + if re.fullmatch(r'\u3010\d+\u2020(source|\u6765\u6e90)\u3011', text): + return True + + # 合法的部分构建格式 + if re.fullmatch(r'\u3010(\d+)?(\u2020(source|\u6765\u6e90)?)?', text): + return True + + # 不合法的格式 + return False + +# 辅助函数:检查是否为完整的引用格式 +# 检查是否为完整的引用格式 +def is_complete_citation_format(text): + return bool(re.fullmatch(r'\u3010\d+\u2020(source|\u6765\u6e90)\u3011\u3010?', text)) + + +# 替换完整的引用格式 +def replace_complete_citation(text, citations): + def replace_match(match): + citation_number = match.group(1) + for citation in citations: + cited_message_idx = citation.get('metadata', {}).get('extra', {}).get('cited_message_idx') + print(f"cited_message_idx: {cited_message_idx}") + print(f"citation_number: {citation_number}") + print(f"is citation_number == cited_message_idx: {cited_message_idx == int(citation_number)}") + print(f"citation: {citation}") + if cited_message_idx == int(citation_number): + url = citation.get("metadata", {}).get("url", "") + return f"[[{citation_number}]({url})]" + return match.group(0) # 如果没有找到对应的引用,返回原文本 + + # 使用 finditer 找到第一个匹配项 + match_iter = re.finditer(r'\u3010(\d+)\u2020(source|\u6765\u6e90)\u3011', text) + first_match = next(match_iter, None) + + if first_match: + start, end = first_match.span() + replaced_text = text[:start] + replace_match(first_match) + text[end:] + remaining_text = text[end:] + else: + replaced_text = text + remaining_text = "" + + 
+    is_potential_citation = is_valid_citation_format(remaining_text)
+
+    # Strip remaining_text from the end of replaced_text
+
+    print(f"replaced_text: {replaced_text}")
+    print(f"remaining_text: {remaining_text}")
+    print(f"is_potential_citation: {is_potential_citation}")
+    if is_potential_citation:
+        replaced_text = replaced_text[:-len(remaining_text)]
+
+
+    return replaced_text, remaining_text, is_potential_citation
+
+accessable_model_list = ['gpt-4-classic', 'gpt-4-s']
+# Flask route definitions
+@app.route('/v1/chat/completions', methods=['POST'])
+def chat_completions():
+    print(f"[{datetime.now()}] New Request")
+    data = request.json
+    messages = data.get('messages')
+    model = data.get('model')
+    if model not in accessable_model_list:
+        return jsonify({"error": "model is not accessible"}), 401
+    stream = data.get('stream', False)
+
+    auth_header = request.headers.get('Authorization')
+    if not auth_header or not auth_header.startswith('Bearer '):
+        return jsonify({"error": "Authorization header is missing or invalid"}), 401
+    api_key = auth_header.split(' ')[1]
+    print(f"api_key: {api_key}")
+
+
+    upstream_response = send_text_prompt_and_get_response(messages, api_key, stream, model)
+
+    if not stream:
+        return Response(upstream_response)
+    else:
+        # Handle the streaming response
+        def generate():
+            buffer = ""
+            last_full_text = ""  # Full text assembled from all parts seen so far
+            last_full_code = ""
+            last_full_code_result = ""
+            last_content_type = None  # Content type of the previous message
+            conversation_id = ''
+            citation_buffer = ""
+            citation_accumulating = False
+            for chunk in upstream_response.iter_content(chunk_size=1024):
+                if chunk:
+                    buffer += chunk.decode('utf-8')
+                    while 'data:' in buffer and '\n\n' in buffer:
+                        end_index = buffer.index('\n\n') + 2
+                        complete_data, buffer = buffer[:end_index], buffer[end_index:]
+                        # Parse the data block
+                        try:
+                            data_json = json.loads(complete_data.replace('data: ', ''))
+                            message = data_json.get("message", {})
+                            content = message.get("content", {})
+                            role = message.get("author", {}).get("role")
+                            content_type = content.get("content_type")
+                            print(f"content_type: {content_type}")
+                            print(f"last_content_type: {last_content_type}")
+                            metadata = {}
+                            citations = []
+                            try:
+                                metadata = message.get("metadata", {})
+                                citations = metadata.get("citations", [])
+                            except:
+                                pass
+                            name = message.get("author", {}).get("name")
+                            if role == "user":
+                                # Discard messages sent by the user
+                                continue
+                            try:
+                                conversation_id = data_json.get("conversation_id")
+                                print(f"conversation_id: {conversation_id}")
+                            except:
+                                pass
+                            # Only take the new part
+                            new_text = ""
+                            is_img_message = False
+                            parts = content.get("parts", [])
+                            for part in parts:
+                                try:
+                                    if part.get('content_type') == 'image_asset_pointer':
+                                        print(f"content_type: {content_type}")
+                                        is_img_message = True
+                                        asset_pointer = part.get('asset_pointer').replace('file-service://', '')
+                                        print(f"asset_pointer: {asset_pointer}")
+                                        image_url = f"{BASE_URL}/{PROXY_API_PREFIX}/backend-api/files/{asset_pointer}/download"
+
+                                        headers = {
+                                            "Authorization": f"Bearer {api_key}"
+                                        }
+                                        image_response = requests.get(image_url, headers=headers)
+
+                                        if image_response.status_code == 200:
+                                            download_url = image_response.json().get('download_url')
+                                            print(f"download_url: {download_url}")
+                                            # Download the image from the URL
+                                            # image_data = requests.get(download_url).content
+                                            image_download_response = requests.get(download_url)
+                                            # print(f"image_download_response: {image_download_response.text}")
+                                            if image_download_response.status_code == 200:
+                                                print("Image downloaded successfully")
+                                                image_data = image_download_response.content
+                                                today_image_url = save_image(image_data)  # Save the image and get its path
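+                                                # The generated image is re-hosted locally so that clients can fetch it
+                                                # through the uploader service at UPLOAD_BASE_URL (see upload.py).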
new_text = f"\n![image]({UPLOAD_BASE_URL}/{today_image_url})\n[下载链接]({UPLOAD_BASE_URL}/{today_image_url})\n" + else: + print(f"下载图片失败: {image_download_response.text}") + if last_content_type == "code": + new_text = "\n```\n" + new_text + print(f"new_text: {new_text}") + is_img_message = True + else: + print(f"获取图片下载链接失败: {image_response.text}") + except: + pass + + + if is_img_message == False: + if content_type == "multimodal_text" and last_content_type == "code": + new_text = "\n```\n" + content.get("text", "") + elif role == "tool" and name == "dalle.text2im": + print(f"无视消息: {content.get('text', '')}") + continue + # 代码块特殊处理 + if content_type == "code" and last_content_type != "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = "\n```\n" + full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = full_code # 更新完整代码以备下次比较 + + elif last_content_type == "code" and content_type != "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = "\n```\n" + full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = "" # 更新完整代码以备下次比较 + + elif content_type == "code" and last_content_type == "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = full_code # 更新完整代码以备下次比较 + + else: + # 只获取新的 parts + parts = content.get("parts", []) + full_text = ''.join(parts) + new_text = full_text[len(last_full_text):] + last_full_text = full_text # 更新完整文本以备下次比较 + if "\u3010" in new_text and not citation_accumulating: + citation_accumulating = True + citation_buffer = citation_buffer + new_text + print(f"开始积累引用: {citation_buffer}") + elif citation_accumulating: + citation_buffer += new_text + print(f"积累引用: {citation_buffer}") + if citation_accumulating: + if is_valid_citation_format(citation_buffer): + print(f"合法格式: {citation_buffer}") + # 继续积累 + if is_complete_citation_format(citation_buffer): + + # 替换完整的引用格式 + replaced_text, remaining_text, is_potential_citation = replace_complete_citation(citation_buffer, citations) + # print(replaced_text) # 输出替换后的文本 + new_text = replaced_text + + if(is_potential_citation): + citation_buffer = remaining_text + else: + citation_accumulating = False + citation_buffer = "" + print(f"替换完整的引用格式: {new_text}") + else: + continue + else: + # 不是合法格式,放弃积累并响应 + print(f"不合法格式: {citation_buffer}") + new_text = citation_buffer + citation_accumulating = False + citation_buffer = "" + + + # Python 工具执行输出特殊处理 + if role == "tool" and name == "python" and last_content_type != "execution_output" and content_type != None: + + + full_code_result = ''.join(content.get("text", "")) + new_text = "`Result:` \n```\n" + full_code_result[len(last_full_code_result):] + if last_content_type == "code": + new_text = "\n```\n" + new_text + # print(f"full_code_result: {full_code_result}") + # print(f"last_full_code_result: {last_full_code_result}") + # print(f"new_text: {new_text}") + last_full_code_result = full_code_result # 更新完整代码以备下次比较 + elif last_content_type == "execution_output" and (role != "tool" or name != "python") and content_type != None: + # new_text = content.get("text", "") + "\n```" + 
+                                full_code_result = ''.join(content.get("text", ""))
+                                new_text = full_code_result[len(last_full_code_result):] + "\n```\n"
+                                if content_type == "code":
+                                    new_text = new_text + "\n```\n"
+                                # print(f"full_code_result: {full_code_result}")
+                                # print(f"last_full_code_result: {last_full_code_result}")
+                                # print(f"new_text: {new_text}")
+                                last_full_code_result = ""  # Reset the stored result for the next comparison
+                            elif last_content_type == "execution_output" and role == "tool" and name == "python" and content_type != None:
+                                full_code_result = ''.join(content.get("text", ""))
+                                new_text = full_code_result[len(last_full_code_result):]
+                                # print(f"full_code_result: {full_code_result}")
+                                # print(f"last_full_code_result: {last_full_code_result}")
+                                # print(f"new_text: {new_text}")
+                                last_full_code_result = full_code_result
+
+                            # print(f"[{datetime.now()}] Received data: {data_json}")
+                            # print(f"[{datetime.now()}] Full text received: {full_text}")
+                            # print(f"[{datetime.now()}] Previous full text: {last_full_text}")
+                            # print(f"[{datetime.now()}] New text: {new_text}")
+
+                            # Update last_content_type
+                            if content_type != None:
+                                last_content_type = content_type if role != "user" else last_content_type
+
+                            new_data = {
+                                "id": message.get("id"),
+                                "object": "chat.completion.chunk",
+                                "created": message.get("create_time"),
+                                "model": message.get("metadata", {}).get("model_slug"),
+                                "choices": [
+                                    {
+                                        "index": 0,
+                                        "delta": {
+                                            "content": ''.join(new_text)
+                                        },
+                                        "finish_reason": None
+                                    }
+                                ]
+                            }
+                            print(f"[{datetime.now()}] Sending message: {new_text}")
+                            tmp = 'data: ' + json.dumps(new_data) + '\n\n'
+                            # print(f"[{datetime.now()}] Sending data: {tmp}")
+                            yield 'data: ' + json.dumps(new_data) + '\n\n'
+                        except json.JSONDecodeError:
+                            print("JSON decode error")
+                            print(f"[{datetime.now()}] Sending data: {complete_data}")
+                            if complete_data == 'data: [DONE]\n\n':
+                                print(f"[{datetime.now()}] Conversation finished")
+                            yield complete_data
+            if citation_buffer != "":
+                new_data = {
+                    "id": message.get("id"),
+                    "object": "chat.completion.chunk",
+                    "created": message.get("create_time"),
+                    "model": message.get("metadata", {}).get("model_slug"),
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {
+                                "content": ''.join(citation_buffer)
+                            },
+                            "finish_reason": None
+                        }
+                    ]
+                }
+                tmp = 'data: ' + json.dumps(new_data) + '\n\n'
+                # print(f"[{datetime.now()}] Sending data: {tmp}")
+                yield 'data: ' + json.dumps(new_data) + '\n\n'
+            if buffer:
+                # print(f"[{datetime.now()}] Remaining data: {buffer}")
+                delete_conversation(conversation_id, api_key)
+                try:
+                    buffer_json = json.loads(buffer)
+                    error_message = buffer_json.get("detail", {}).get("message", "Unknown error")
+                    error_data = {
+                        "id": str(uuid.uuid4()),
+                        "object": "chat.completion.chunk",
+                        "created": datetime.now().isoformat(),
+                        "model": "error",
+                        "choices": [
+                            {
+                                "index": 0,
+                                "delta": {
+                                    "content": ''.join("```\n" + error_message + "\n```")
+                                },
+                                "finish_reason": None
+                            }
+                        ]
+                    }
+                    tmp = 'data: ' + json.dumps(error_data) + '\n\n'
+                    print(f"[{datetime.now()}] Sending final data: {tmp}")
+                    yield 'data: ' + json.dumps(error_data) + '\n\n'
+                except json.JSONDecodeError:
+                    print("JSON decode error")
+                    print(f"[{datetime.now()}] Sending final data: {buffer}")
+                    yield buffer
+
+            delete_conversation(conversation_id, api_key)
+
+
+        return Response(generate(), mimetype='text/event-stream')
+
+
+# Run the Flask app
+if __name__ == '__main__':
+    app.run(host='0.0.0.0')
+
diff --git a/start.sh b/start.sh
new file mode 100755
index 0000000..ef160c1
--- /dev/null
+++ b/start.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+# Record the current date and time
+NOW=$(date +"%Y-%m-%d-%H-%M")
+
+# Make sure the log directory exists, otherwise gunicorn cannot open its log files
+mkdir -p ./log
+
+# Start Gunicorn (2 workers x 2 threads, listening on port 33333 inside the container)
+exec gunicorn -w 2 --threads 2 --bind 0.0.0.0:33333 main:app --access-logfile "./log/access-${NOW}.log" --error-logfile "./log/error-${NOW}.log"
"./log/access-${NOW}.log" --error-logfile "./log/error-${NOW}.log" + +# python3 ./main.py + diff --git a/upload.py b/upload.py new file mode 100644 index 0000000..2c3a71c --- /dev/null +++ b/upload.py @@ -0,0 +1,17 @@ +from flask import Flask, send_from_directory +import os + +app = Flask(__name__) + +# 确保 images 文件夹存在 +os.makedirs('images', exist_ok=True) + +@app.route('/images/') +def get_image(filename): + # 检查文件是否存在 + if not os.path.isfile(os.path.join('images', filename)): + return "文件不存在哦!", 404 + return send_from_directory('images', filename) + +if __name__ == '__main__': + app.run(host='0.0.0.0', port=23333)