From beedabb96560bdecf37cab67c7f71508a7dafd88 Mon Sep 17 00:00:00 2001 From: Wizerd Date: Sun, 17 Dec 2023 12:55:38 +0800 Subject: [PATCH] =?UTF-8?q?[feat]=20=E6=94=AF=E6=8C=81/v1/images/generatio?= =?UTF-8?q?ns=E7=BB=98=E5=9B=BE=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Readme.md | 94 +++++++------- main.py | 362 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 408 insertions(+), 48 deletions(-) diff --git a/Readme.md b/Readme.md index c546414..c4fbd1b 100644 --- a/Readme.md +++ b/Readme.md @@ -2,56 +2,17 @@ 为了方便大家将 [Pandora-Next](https://github.com/pandora-next/deploy) 项目与各种其他项目结合完成了本项目。 -本项目支持将 Pandora-Next `proxy` 模式下的 `backend-api` 转为 `/v1/chat/completions` 接口,支持流式和非流式响应。 +本项目支持: + +1. 将 Pandora-Next `proxy` 模式下的 `backend-api` 转为 `/v1/chat/completions` 接口,支持流式和非流式响应。 + +2. 将 Pandora-Next `proxy` 模式下的 `backend-api` 转为 `/v1/images/generations` 接口 如果本项目对你有帮助的话,请点个小星星吧~ ## 更新日志 -### 0.1.2 - -- 紧急修复:GPTS 未携带消息的问题 - -- 请使用 `0.1.0` 和 `0.1.1` 版本的服务尽快升级! - -### 0.1.1 - -- 支持 `gpt-3.5-turbo` 模型 - -### 0.1.0 - -- 重磅更新 - -- 已支持访问大部分的GPTS - -- 注意:本次更新需要更新 `docker-compose.yml` 文件以及 `gpts.json` 文件 - -### 0.0.11 - -- 修复一些偶现的bug - -### 0.0.10 - -- 已支持非流式响应 - -- 更新latest版本镜像 - -### 0.0.9 - -- 修复在 ChatGPT-Next-Web 网页端修改请求接口后出现 `Failed to fetch` 报错的问题 - -### 0.0.8 - -- 增加了对 GPT-4-Mobile 模型的支持,模型名为 `gpt-4-mobile` - -### 0.0.7 - -- 一定程度上修复图片无法正常生成的问题 -- 注意:`docker-compsoe.yml`有更新 - -### 0.0.6 - -- 修复接入ChatGPT-Next-Web后回复会携带上次的回复的Bug +见 `Release` 页面。 ## 注意 @@ -85,7 +46,11 @@ - UPLOAD_BASE_URL:用于dalle模型生成图片的时候展示所用,需要设置为使用如 [ChatGPT-Next-Web](https://github.com/ChatGPTNextWebTeam/ChatGPT-Next-Web) 的用户可以访问到的 Uploader 容器地址,如:http://127.0.0.1:50012 -- KEY_FOR_GPTS_INFO:仅获取 GPTS 信息的 key,需要该 key 能够访问所有配置的 GPTS。后续发送消息仍需要在请求头携带请求所用的 key。 +- KEY_FOR_GPTS_INFO:仅获取 GPTS 信息的 key,需要该 key 能够访问所有配置的 GPTS。后续发送消息仍需要在请求头携带请求所用的 key,如果未配置该项,请将 `gpts.json` 文件修改为: + +```json +{} +``` ## GPTS配置说明 @@ -108,6 +73,43 @@ 注意:使用该配置的时候需要保证正确填写 `docker-compose.yml` 的环境变量 `KEY_FOR_GPTS_INFO`,同时该变量设置的 `key` 允许访问所有配置的 GPTS。 +## 绘图接口使用说明 + +接口URI:`/v1/images/generations` + +请求方式:`POST` + +请求头:正常携带 `Authorization` 和 `Content-Type` 即可,`Authorization` 的值为 `Bearer `,`Content-Type` 的值为 `application/json` + +请求体格式示例: + +```json +{ + "model": "gpt-4-s", + "prompt": "A cute baby sea otter" +} +``` + +请求体参数说明: + +- model:模型名称,需要支持绘图功能,否则绘图结果将为空 + +- prompt:绘图的 Prompt + +响应体格式示例: + +```json +{ + "created": 1702788293, + "data": [ + { + "url": "http://:50012/images/image_20231217044452.png" + } + ], + "reply": "\n```\n{\"size\":\"1024x1024\",\"prompt\":\"A cute baby sea otter floating on its back in calm, clear waters. The otter has soft, fluffy brown fur, and its small, round eyes are shining brightly. It's holding a small starfish in its tiny paws. The sun is setting in the background, casting a golden glow over the scene. The water reflects the colors of the sunset, with gentle ripples around the otter. There are a few seagulls flying in the distance under the pastel-colored sky.\"}Here is the image of a cute baby sea otter floating on its back." +} +``` + ## 示例 以ChatGPT-Next-Web项目的docker-compose部署为例,这里提供一个简单的部署配置文件示例: diff --git a/main.py b/main.py index bffb165..8a5cf5b 100644 --- a/main.py +++ b/main.py @@ -109,8 +109,8 @@ PROXY_API_PREFIX = os.getenv('PROXY_API_PREFIX', '') UPLOAD_BASE_URL = os.getenv('UPLOAD_BASE_URL', '') KEY_FOR_GPTS_INFO = os.getenv('KEY_FOR_GPTS_INFO', '') -VERSION = '0.1.2' -UPDATE_INFO = '紧急修复:GPTS未携带发送消息的问题' +VERSION = '0.1.3' +UPDATE_INFO = '增加对官方Dalle接口的兼容' with app.app_context(): # 输出版本信息 @@ -711,6 +711,364 @@ def chat_completions(): else: return Response(generate(), mimetype='text/event-stream') + +@app.route('/v1/images/generations', methods=['POST']) +def images_generations(): + print(f"[{datetime.now()}] New Img Request") + data = request.json + # messages = data.get('messages') + model = data.get('model') + accessible_model_list = get_accessible_model_list() + if model not in accessible_model_list: + return jsonify({"error": "model is not accessible"}), 401 + + prompt = data.get('prompt', '') + + # stream = data.get('stream', False) + + auth_header = request.headers.get('Authorization') + if not auth_header or not auth_header.startswith('Bearer '): + return jsonify({"error": "Authorization header is missing or invalid"}), 401 + api_key = auth_header.split(' ')[1] + print(f"api_key: {api_key}") + + image_urls = [] + + messages = [ + { + "role": "user", + "content": prompt, + "hasName": False + } + ] + + upstream_response = send_text_prompt_and_get_response(messages, api_key, False, model) + + # 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text + all_new_text = "" + + # 处理流式响应 + def generate(): + nonlocal all_new_text # 引用外部变量 + chat_message_id = generate_unique_id("chatcmpl") + # 当前时间戳 + timestamp = int(time.time()) + + buffer = "" + last_full_text = "" # 用于存储之前所有出现过的 parts 组成的完整文本 + last_full_code = "" + last_full_code_result = "" + last_content_type = None # 用于记录上一个消息的内容类型 + conversation_id = '' + citation_buffer = "" + citation_accumulating = False + for chunk in upstream_response.iter_content(chunk_size=1024): + if chunk: + buffer += chunk.decode('utf-8') + # 检查是否存在 "event: ping",如果存在,则只保留 "data:" 后面的内容 + if "event: ping" in buffer: + if "data:" in buffer: + buffer = buffer.split("data:", 1)[1] + buffer = "data:" + buffer + # 使用正则表达式移除特定格式的字符串 + # print("应用正则表达式之前的 buffer:", buffer.replace('\n', '\\n')) + buffer = re.sub(r'data: \d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{6}(\r\n|\r|\n){2}', '', buffer) + # print("应用正则表达式之后的 buffer:", buffer.replace('\n', '\\n')) + + + + while 'data:' in buffer and '\n\n' in buffer: + end_index = buffer.index('\n\n') + 2 + complete_data, buffer = buffer[:end_index], buffer[end_index:] + # 解析 data 块 + try: + data_json = json.loads(complete_data.replace('data: ', '')) + # print(f"data_json: {data_json}") + message = data_json.get("message", {}) + + if message == {} or message == None: + print(f"message 为空: data_json: {data_json}") + + message_status = message.get("status") + content = message.get("content", {}) + role = message.get("author", {}).get("role") + content_type = content.get("content_type") + print(f"content_type: {content_type}") + print(f"last_content_type: {last_content_type}") + + metadata = {} + citations = [] + try: + metadata = message.get("metadata", {}) + citations = metadata.get("citations", []) + except: + pass + name = message.get("author", {}).get("name") + if (role == "user" or message_status == "finished_successfully" or role == "system") and role != "tool": + # 如果是用户发来的消息,直接舍弃 + continue + try: + conversation_id = data_json.get("conversation_id") + print(f"conversation_id: {conversation_id}") + except: + pass + # 只获取新的部分 + new_text = "" + is_img_message = False + parts = content.get("parts", []) + for part in parts: + try: + # print(f"part: {part}") + # print(f"part type: {part.get('content_type')}") + if part.get('content_type') == 'image_asset_pointer': + print(f"find img message~") + is_img_message = True + asset_pointer = part.get('asset_pointer').replace('file-service://', '') + print(f"asset_pointer: {asset_pointer}") + image_url = f"{BASE_URL}/{PROXY_API_PREFIX}/backend-api/files/{asset_pointer}/download" + + headers = { + "Authorization": f"Bearer {api_key}" + } + image_response = requests.get(image_url, headers=headers) + + if image_response.status_code == 200: + download_url = image_response.json().get('download_url') + print(f"download_url: {download_url}") + # 从URL下载图片 + # image_data = requests.get(download_url).content + image_download_response = requests.get(download_url) + # print(f"image_download_response: {image_download_response.text}") + if image_download_response.status_code == 200: + print(f"下载图片成功") + image_data = image_download_response.content + today_image_url = save_image(image_data) # 保存图片,并获取文件名 + # new_text = f"\n![image]({UPLOAD_BASE_URL}/{today_image_url})\n[下载链接]({UPLOAD_BASE_URL}/{today_image_url})\n" + image_link = f"{UPLOAD_BASE_URL}/{today_image_url}" + image_urls.append(image_link) # 将图片链接保存到列表中 + new_text = "" + else: + print(f"下载图片失败: {image_download_response.text}") + if last_content_type == "code": + new_text = new_text + # new_text = "\n```\n" + new_text + print(f"new_text: {new_text}") + is_img_message = True + else: + print(f"获取图片下载链接失败: {image_response.text}") + except: + pass + + + if is_img_message == False: + # print(f"data_json: {data_json}") + if content_type == "multimodal_text" and last_content_type == "code": + new_text = "\n```\n" + content.get("text", "") + elif role == "tool" and name == "dalle.text2im": + print(f"无视消息: {content.get('text', '')}") + continue + # 代码块特殊处理 + if content_type == "code" and last_content_type != "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = "\n```\n" + full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = full_code # 更新完整代码以备下次比较 + + elif last_content_type == "code" and content_type != "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = "\n```\n" + full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = "" # 更新完整代码以备下次比较 + + elif content_type == "code" and last_content_type == "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = full_code # 更新完整代码以备下次比较 + + else: + # 只获取新的 parts + parts = content.get("parts", []) + full_text = ''.join(parts) + new_text = full_text[len(last_full_text):] + last_full_text = full_text # 更新完整文本以备下次比较 + if "\u3010" in new_text and not citation_accumulating: + citation_accumulating = True + citation_buffer = citation_buffer + new_text + print(f"开始积累引用: {citation_buffer}") + elif citation_accumulating: + citation_buffer += new_text + print(f"积累引用: {citation_buffer}") + if citation_accumulating: + if is_valid_citation_format(citation_buffer): + print(f"合法格式: {citation_buffer}") + # 继续积累 + if is_complete_citation_format(citation_buffer): + + # 替换完整的引用格式 + replaced_text, remaining_text, is_potential_citation = replace_complete_citation(citation_buffer, citations) + # print(replaced_text) # 输出替换后的文本 + new_text = replaced_text + + if(is_potential_citation): + citation_buffer = remaining_text + else: + citation_accumulating = False + citation_buffer = "" + print(f"替换完整的引用格式: {new_text}") + else: + continue + else: + # 不是合法格式,放弃积累并响应 + print(f"不合法格式: {citation_buffer}") + new_text = citation_buffer + citation_accumulating = False + citation_buffer = "" + + + # Python 工具执行输出特殊处理 + if role == "tool" and name == "python" and last_content_type != "execution_output" and content_type != None: + + + full_code_result = ''.join(content.get("text", "")) + new_text = "`Result:` \n```\n" + full_code_result[len(last_full_code_result):] + if last_content_type == "code": + new_text = "\n```\n" + new_text + # print(f"full_code_result: {full_code_result}") + # print(f"last_full_code_result: {last_full_code_result}") + # print(f"new_text: {new_text}") + last_full_code_result = full_code_result # 更新完整代码以备下次比较 + elif last_content_type == "execution_output" and (role != "tool" or name != "python") and content_type != None: + # new_text = content.get("text", "") + "\n```" + full_code_result = ''.join(content.get("text", "")) + new_text = full_code_result[len(last_full_code_result):] + "\n```\n" + if content_type == "code": + new_text = new_text + "\n```\n" + # print(f"full_code_result: {full_code_result}") + # print(f"last_full_code_result: {last_full_code_result}") + # print(f"new_text: {new_text}") + last_full_code_result = "" # 更新完整代码以备下次比较 + elif last_content_type == "execution_output" and role == "tool" and name == "python" and content_type != None: + full_code_result = ''.join(content.get("text", "")) + new_text = full_code_result[len(last_full_code_result):] + # print(f"full_code_result: {full_code_result}") + # print(f"last_full_code_result: {last_full_code_result}") + # print(f"new_text: {new_text}") + last_full_code_result = full_code_result + + # print(f"[{datetime.now()}] 收到数据: {data_json}") + # print(f"[{datetime.now()}] 收到的完整文本: {full_text}") + # print(f"[{datetime.now()}] 上次收到的完整文本: {last_full_text}") + # print(f"[{datetime.now()}] 新的文本: {new_text}") + + # 更新 last_content_type + if content_type != None: + last_content_type = content_type if role != "user" else last_content_type + + + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": message.get("metadata", {}).get("model_slug"), + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join(new_text) + }, + "finish_reason": None + } + ] + } + # print(f"Role: {role}") + print(f"[{datetime.now()}] 发送消息: {new_text}") + tmp = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + # print(f"[{datetime.now()}] 发送数据: {tmp}") + # 累积 new_text + all_new_text += new_text + yield 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + except json.JSONDecodeError: + # print("JSON 解析错误") + print(f"[{datetime.now()}] 发送数据: {complete_data}") + if complete_data == 'data: [DONE]\n\n': + print(f"[{datetime.now()}] 会话结束") + yield complete_data + if citation_buffer != "": + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": message.get("metadata", {}).get("model_slug"), + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join(citation_buffer) + }, + "finish_reason": None + } + ] + } + tmp = 'data: ' + json.dumps(new_data) + '\n\n' + # print(f"[{datetime.now()}] 发送数据: {tmp}") + # 累积 new_text + all_new_text += citation_buffer + yield 'data: ' + json.dumps(new_data) + '\n\n' + if buffer: + # print(f"[{datetime.now()}] 最后的数据: {buffer}") + delete_conversation(conversation_id, api_key) + try: + buffer_json = json.loads(buffer) + error_message = buffer_json.get("detail", {}).get("message", "未知错误") + error_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": "error", + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join("```\n" + error_message + "\n```") + }, + "finish_reason": None + } + ] + } + tmp = 'data: ' + json.dumps(error_data) + '\n\n' + print(f"[{datetime.now()}] 发送最后的数据: {tmp}") + # 累积 new_text + all_new_text += ''.join("```\n" + error_message + "\n```") + yield 'data: ' + json.dumps(error_data) + '\n\n' + except: + # print("JSON 解析错误") + print(f"[{datetime.now()}] 发送最后的数据: {buffer}") + yield buffer + + delete_conversation(conversation_id, api_key) + + # 执行流式响应的生成函数来累积 all_new_text + # 迭代生成器对象以执行其内部逻辑 + for _ in generate(): + pass + # 构造响应的 JSON 结构 + response_json = { + "created": int(time.time()), # 使用当前时间戳 + "reply": all_new_text, # 使用累积的文本 + "data": [{"url": url} for url in image_urls] # 将图片链接列表转换为所需格式 + } + + # 返回 JSON 响应 + return jsonify(response_json) + + @app.after_request def after_request(response): response.headers.add('Access-Control-Allow-Origin', '*')