From f314ce95062ec8f5eb8fbb144e857b79a6417494 Mon Sep 17 00:00:00 2001 From: Wizerd Date: Fri, 2 Feb 2024 20:49:37 +0800 Subject: [PATCH] =?UTF-8?q?[feat]=20=E9=80=82=E9=85=8DWSS=E8=BE=93?= =?UTF-8?q?=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Dockerfile | 2 +- main.py | 1584 +++++++++++++++++++++++++--------------------------- 2 files changed, 773 insertions(+), 813 deletions(-) diff --git a/Dockerfile b/Dockerfile index 83e845e..a22d120 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,7 +18,7 @@ RUN apt update && apt install -y jq # RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple # 安装任何所需的依赖项 -RUN pip install --no-cache-dir flask gunicorn requests Pillow flask-cors tiktoken fake_useragent redis +RUN pip install --no-cache-dir flask gunicorn requests Pillow flask-cors tiktoken fake_useragent redis websocket-client # 在容器启动时运行 Flask 应用 CMD ["/app/start.sh"] diff --git a/main.py b/main.py index af8550b..f7be398 100644 --- a/main.py +++ b/main.py @@ -197,9 +197,9 @@ CORS(app, resources={r"/images/*": {"origins": "*"}}) # PANDORA_UPLOAD_URL = 'files.pandoranext.com' -VERSION = '0.6.0' +VERSION = '0.7.0' # VERSION = 'test' -UPDATE_INFO = '去除PandoraNext相关服务依赖选项,并修改部分配置名,Respect Pandora!' +UPDATE_INFO = '适配WSS输出' # UPDATE_INFO = '【仅供临时测试使用】 ' with app.app_context(): @@ -947,7 +947,703 @@ def replace_sandbox(text, conversation_id, message_id, api_key): replaced_text = re.sub(r'\(sandbox:([^)]+)\)', replace_match, text) return replaced_text -def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model): + + +def generate_actions_allow_payload(author_role, author_name, target_message_id, operation_hash, conversation_id, message_id, model): + model_config = find_model_config(model) + if model_config: + gizmo_info = model_config['config'] + gizmo_id = gizmo_info['gizmo']['id'] + payload = { + "action": "next", + "messages": [ + { + "id": generate_custom_uuid_v4(), + "author": { + "role": author_role, + "name": author_name + }, + "content": { + "content_type": "text", + "parts": [ + "" + ] + }, + "recipient": "all", + "metadata": { + "jit_plugin_data": { + "from_client": { + "user_action": { + "data": { + "type": "always_allow", + "operation_hash": operation_hash + }, + "target_message_id": target_message_id + } + } + } + } + } + ], + "conversation_id": conversation_id, + "parent_message_id": message_id, + "model": "gpt-4-gizmo", + "timezone_offset_min": -480, + "history_and_training_disabled": False, + "arkose_token": None, + "conversation_mode": { + "gizmo": gizmo_info, + "kind": "gizmo_interaction", + "gizmo_id": gizmo_id + }, + "force_paragen": False, + "force_rate_limit": False + } + return payload + else: + return None + +# 定义发送请求的函数 +def send_allow_prompt_and_get_response(message_id, author_role, author_name, target_message_id, operation_hash, conversation_id, model, api_key): + url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/conversation" + headers = { + "Authorization": f"Bearer {api_key}" + } + # 查找模型配置 + model_config = find_model_config(model) + ori_model_name = '' + if model_config: + # 检查是否有 ori_name + ori_model_name = model_config.get('ori_name', model) + payload = generate_actions_allow_payload(author_role, author_name, target_message_id, operation_hash, conversation_id, message_id, model) + token = None + payload['arkose_token'] = token + logger.debug(f"payload: {payload}") + if NEED_DELETE_CONVERSATION_AFTER_RESPONSE: + logger.info(f"是否保留会话: {NEED_DELETE_CONVERSATION_AFTER_RESPONSE == False}") + payload['history_and_training_disabled'] = True + logger.debug(f"request headers: {headers}") + logger.debug(f"payload: {payload}") + logger.info(f"继续请求上游接口") + try: + response = requests.post(url, headers=headers, json=payload, stream=True, verify=False, timeout=30) + logger.info(f"成功与上游接口建立连接") + # print(response) + return response + except requests.exceptions.Timeout: + # 处理超时情况 + logger.error("请求超时") + +def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, timestamp, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, all_new_text): + # print(f"data_json: {data_json}") + message = data_json.get("message", {}) + + if message == {} or message == None: + logger.debug(f"message 为空: data_json: {data_json}") + + message_id = message.get("id") + + message_status = message.get("status") + content = message.get("content", {}) + role = message.get("author", {}).get("role") + content_type = content.get("content_type") + # print(f"content_type: {content_type}") + # print(f"last_content_type: {last_content_type}") + + metadata = {} + citations = [] + try: + metadata = message.get("metadata", {}) + citations = metadata.get("citations", []) + except: + pass + name = message.get("author", {}).get("name") + + # 开始处理action确认事件 + jit_plugin_data = metadata.get("jit_plugin_data", {}) + from_server = jit_plugin_data.get("from_server", {}) + action_type = from_server.get("type", "") + if message_status == "finished_successfully" and action_type == "confirm_action": + logger.info(f"监测到action确认事件") + # 提取所需信息 + message_id = message.get("id", "") + author_role = message.get("author", {}).get("role", "") + author_name = message.get("author", {}).get("name", "") + actions = from_server.get("body", {}).get("actions", []) + target_message_id = "" + operation_hash = "" + + for action in actions: + if action.get("type") == "always_allow": + target_message_id = action.get("always_allow", {}).get("target_message_id", "") + operation_hash = action.get("always_allow", {}).get("operation_hash", "") + break + + conversation_id = data_json.get("conversation_id", "") + upstream_response = send_allow_prompt_and_get_response(message_id, author_role, author_name, target_message_id, operation_hash, conversation_id, model, api_key) + if upstream_response == None: + complete_data = 'data: [DONE]\n\n' + logger.info(f"会话超时") + + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join("{\n\"error\": \"Something went wrong...\"\n}") + }, + "finish_reason": None + } + ] + } + q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + data_queue.put(q_data) + + q_data = complete_data + data_queue.put(('all_new_text', "{\n\"error\": \"Something went wrong...\"\n}")) + data_queue.put(q_data) + last_data_time[0] = time.time() + return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None + + if upstream_response.status_code != 200: + complete_data = 'data: [DONE]\n\n' + logger.info(f"会话出错") + logger.error(f"upstream_response status code: {upstream_response.status_code}") + logger.error(f"upstream_response: {upstream_response.text}") + tmp_message = "Something went wrong..." + + try: + upstream_response_text = upstream_response.text + # 解析 JSON 字符串 + parsed_response = json.loads(upstream_response_text) + + # 尝试提取 message 字段 + tmp_message = parsed_response.get("detail", {}).get("message", None) + tmp_code = parsed_response.get("detail", {}).get("code", None) + if tmp_code == "account_deactivated" or tmp_code == "model_cap_exceeded": + logger.error(f"账号被封禁或超限,异常代码: {tmp_code}") + + except json.JSONDecodeError: + # 如果 JSON 解析失败,则记录错误 + logger.error("Failed to parse the upstream response as JSON") + + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join("```\n{\n\"error\": \""+ tmp_message +"\"\n}\n```") + }, + "finish_reason": None + } + ] + } + q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + data_queue.put(q_data) + + + q_data = complete_data + data_queue.put(('all_new_text', "```\n{\n\"error\": \""+ tmp_message +"\"\n}```")) + data_queue.put(q_data) + last_data_time[0] = time.time() + return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None + + logger.info(f"action确认事件处理成功, 上游响应数据结构类型: {type(upstream_response)}") + + upstream_response_json = upstream_response.json() + upstream_response_id = upstream_response_json.get("response_id", "") + + buffer = "" + last_full_text = "" # 用于存储之前所有出现过的 parts 组成的完整文本 + last_full_code = "" + last_full_code_result = "" + last_content_type = None # 用于记录上一个消息的内容类型 + conversation_id = '' + citation_buffer = "" + citation_accumulating = False + file_output_buffer = "" + file_output_accumulating = False + execution_output_image_url_buffer = "" + execution_output_image_id_buffer = "" + + return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, upstream_response_id + + + if (role == "user" or message_status == "finished_successfully" or role == "system") and role != "tool": + # 如果是用户发来的消息,直接舍弃 + return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None + try: + conversation_id = data_json.get("conversation_id") + # print(f"conversation_id: {conversation_id}") + if conversation_id: + data_queue.put(('conversation_id', conversation_id)) + except: + pass + # 只获取新的部分 + new_text = "" + is_img_message = False + parts = content.get("parts", []) + for part in parts: + try: + # print(f"part: {part}") + # print(f"part type: {part.get('content_type')}") + if part.get('content_type') == 'image_asset_pointer': + logger.debug(f"find img message~") + is_img_message = True + asset_pointer = part.get('asset_pointer').replace('file-service://', '') + logger.debug(f"asset_pointer: {asset_pointer}") + image_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/files/{asset_pointer}/download" + + headers = { + "Authorization": f"Bearer {api_key}" + } + image_response = requests.get(image_url, headers=headers) + + if image_response.status_code == 200: + download_url = image_response.json().get('download_url') + logger.debug(f"download_url: {download_url}") + if USE_OAIUSERCONTENT_URL == True: + if ((BOT_MODE_ENABLED == False) or (BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): + new_text = f"\n![image]({download_url})\n[下载链接]({download_url})\n" + if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: + if all_new_text != "": + new_text = f"\n图片链接:{download_url}\n" + else: + new_text = f"图片链接:{download_url}\n" + if response_format == "url": + data_queue.put(('image_url', f"{download_url}")) + else: + image_download_response = requests.get(download_url) + if image_download_response.status_code == 200: + logger.debug(f"下载图片成功") + image_data = image_download_response.content + # 使用base64编码图片 + image_base64 = base64.b64encode(image_data).decode('utf-8') + data_queue.put(('image_url', image_base64)) + else: + # 从URL下载图片 + # image_data = requests.get(download_url).content + image_download_response = requests.get(download_url) + # print(f"image_download_response: {image_download_response.text}") + if image_download_response.status_code == 200: + logger.debug(f"下载图片成功") + image_data = image_download_response.content + today_image_url = save_image(image_data) # 保存图片,并获取文件名 + if response_format == "url": + data_queue.put(('image_url', f"{UPLOAD_BASE_URL}/{today_image_url}")) + else: + # 使用base64编码图片 + image_base64 = base64.b64encode(image_data).decode('utf-8') + data_queue.put(('image_url', image_base64)) + if ((BOT_MODE_ENABLED == False) or (BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): + new_text = f"\n![image]({UPLOAD_BASE_URL}/{today_image_url})\n[下载链接]({UPLOAD_BASE_URL}/{today_image_url})\n" + if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: + if all_new_text != "": + new_text = f"\n图片链接:{UPLOAD_BASE_URL}/{today_image_url}\n" + else: + new_text = f"图片链接:{UPLOAD_BASE_URL}/{today_image_url}\n" + else: + logger.error(f"下载图片失败: {image_download_response.text}") + if last_content_type == "code": + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = new_text + else: + new_text = "\n```\n" + new_text + logger.debug(f"new_text: {new_text}") + is_img_message = True + else: + logger.error(f"获取图片下载链接失败: {image_response.text}") + except: + pass + + + if is_img_message == False: + # print(f"data_json: {data_json}") + if content_type == "multimodal_text" and last_content_type == "code": + new_text = "\n```\n" + content.get("text", "") + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = content.get("text", "") + elif role == "tool" and name == "dalle.text2im": + logger.debug(f"无视消息: {content.get('text', '')}") + return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None + # 代码块特殊处理 + if content_type == "code" and last_content_type != "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = "\n```\n" + full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = full_code # 更新完整代码以备下次比较 + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + + elif last_content_type == "code" and content_type != "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = "\n```\n" + full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = "" # 更新完整代码以备下次比较 + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + + elif content_type == "code" and last_content_type == "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = full_code # 更新完整代码以备下次比较 + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + + else: + # 只获取新的 parts + parts = content.get("parts", []) + full_text = ''.join(parts) + logger.debug(f"last_full_text: {last_full_text}") + new_text = full_text[len(last_full_text):] + if full_text != '': + last_full_text = full_text # 更新完整文本以备下次比较 + logger.debug(f"full_text: {full_text}") + logger.debug(f"new_text: {new_text}") + if "\u3010" in new_text and not citation_accumulating: + citation_accumulating = True + citation_buffer = citation_buffer + new_text + logger.debug(f"开始积累引用: {citation_buffer}") + elif citation_accumulating: + citation_buffer += new_text + logger.debug(f"积累引用: {citation_buffer}") + if citation_accumulating: + if is_valid_citation_format(citation_buffer): + logger.debug(f"合法格式: {citation_buffer}") + # 继续积累 + if is_complete_citation_format(citation_buffer): + + # 替换完整的引用格式 + replaced_text, remaining_text, is_potential_citation = replace_complete_citation(citation_buffer, citations) + # print(replaced_text) # 输出替换后的文本 + + new_text = replaced_text + + if(is_potential_citation): + citation_buffer = remaining_text + else: + citation_accumulating = False + citation_buffer = "" + logger.debug(f"替换完整的引用格式: {new_text}") + else: + return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None + else: + # 不是合法格式,放弃积累并响应 + logger.debug(f"不合法格式: {citation_buffer}") + new_text = citation_buffer + citation_accumulating = False + citation_buffer = "" + + if "(" in new_text and not file_output_accumulating and not citation_accumulating: + file_output_accumulating = True + file_output_buffer = file_output_buffer + new_text + + logger.debug(f"开始积累文件输出: {file_output_buffer}") + logger.debug(f"file_output_buffer: {file_output_buffer}") + logger.debug(f"new_text: {new_text}") + elif file_output_accumulating: + file_output_buffer += new_text + logger.debug(f"积累文件输出: {file_output_buffer}") + if file_output_accumulating: + if is_valid_sandbox_combined_corrected_final_v2(file_output_buffer): + logger.debug(f"合法文件输出格式: {file_output_buffer}") + # 继续积累 + if is_complete_sandbox_format(file_output_buffer): + # 替换完整的引用格式 + logger.info(f'complete_sandbox data_json {data_json}') + replaced_text = replace_sandbox(file_output_buffer, conversation_id, message_id, api_key) + # print(replaced_text) # 输出替换后的文本 + new_text = replaced_text + file_output_accumulating = False + file_output_buffer = "" + logger.debug(f"替换完整的文件输出格式: {new_text}") + else: + return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None + else: + # 不是合法格式,放弃积累并响应 + logger.debug(f"不合法格式: {file_output_buffer}") + new_text = file_output_buffer + file_output_accumulating = False + file_output_buffer = "" + + # Python 工具执行输出特殊处理 + if role == "tool" and name == "python" and last_content_type != "execution_output" and content_type != None: + + + full_code_result = ''.join(content.get("text", "")) + new_text = "`Result:` \n```\n" + full_code_result[len(last_full_code_result):] + if last_content_type == "code": + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + else: + new_text = "\n```\n" + new_text + # print(f"full_code_result: {full_code_result}") + # print(f"last_full_code_result: {last_full_code_result}") + # print(f"new_text: {new_text}") + last_full_code_result = full_code_result # 更新完整代码以备下次比较 + elif last_content_type == "execution_output" and (role != "tool" or name != "python") and content_type != None: + # new_text = content.get("text", "") + "\n```" + full_code_result = ''.join(content.get("text", "")) + new_text = full_code_result[len(last_full_code_result):] + "\n```\n" + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + tmp_new_text = new_text + if execution_output_image_url_buffer != "": + if ((BOT_MODE_ENABLED == False) or (BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): + logger.debug(f"BOT_MODE_ENABLED: {BOT_MODE_ENABLED}") + logger.debug(f"BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT: {BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT}") + new_text = tmp_new_text + f"![image]({execution_output_image_url_buffer})\n[下载链接]({execution_output_image_url_buffer})\n" + if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: + logger.debug(f"BOT_MODE_ENABLED: {BOT_MODE_ENABLED}") + logger.debug(f"BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT: {BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT}") + new_text = tmp_new_text + f"图片链接:{execution_output_image_url_buffer}\n" + execution_output_image_url_buffer = "" + + if content_type == "code": + new_text = new_text + "\n```\n" + # print(f"full_code_result: {full_code_result}") + # print(f"last_full_code_result: {last_full_code_result}") + # print(f"new_text: {new_text}") + last_full_code_result = "" # 更新完整代码以备下次比较 + elif last_content_type == "execution_output" and role == "tool" and name == "python" and content_type != None: + full_code_result = ''.join(content.get("text", "")) + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + else: + new_text = full_code_result[len(last_full_code_result):] + # print(f"full_code_result: {full_code_result}") + # print(f"last_full_code_result: {last_full_code_result}") + # print(f"new_text: {new_text}") + last_full_code_result = full_code_result + + # 其余Action执行输出特殊处理 + if role == "tool" and name != "python" and name != "dalle.text2im" and last_content_type != "execution_output" and content_type != None: + new_text = "" + if last_content_type == "code": + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + else: + new_text = "\n```\n" + new_text + + # 检查 new_text 中是否包含 <> + if "<>" in last_full_code_result: + # 进行提取操作 + aggregate_result = message.get("metadata", {}).get("aggregate_result", {}) + if aggregate_result: + messages = aggregate_result.get("messages", []) + for msg in messages: + if msg.get("message_type") == "image": + image_url = msg.get("image_url") + if image_url: + # 从 image_url 提取所需的字段 + image_file_id = image_url.split('://')[-1] + logger.info(f"提取到的图片文件ID: {image_file_id}") + if image_file_id != execution_output_image_id_buffer: + image_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/files/{image_file_id}/download" + + headers = { + "Authorization": f"Bearer {api_key}" + } + image_response = requests.get(image_url, headers=headers) + + if image_response.status_code == 200: + download_url = image_response.json().get('download_url') + logger.debug(f"download_url: {download_url}") + if USE_OAIUSERCONTENT_URL == True: + execution_output_image_url_buffer = download_url + + else: + # 从URL下载图片 + # image_data = requests.get(download_url).content + image_download_response = requests.get(download_url) + # print(f"image_download_response: {image_download_response.text}") + if image_download_response.status_code == 200: + logger.debug(f"下载图片成功") + image_data = image_download_response.content + today_image_url = save_image(image_data) # 保存图片,并获取文件名 + execution_output_image_url_buffer = f"{UPLOAD_BASE_URL}/{today_image_url}" + + else: + logger.error(f"下载图片失败: {image_download_response.text}") + + execution_output_image_id_buffer = image_file_id + + # 从 new_text 中移除 <> + new_text = new_text.replace("<>", "图片生成中,请稍后\n") + + # print(f"收到数据: {data_json}") + # print(f"收到的完整文本: {full_text}") + # print(f"上次收到的完整文本: {last_full_text}") + # print(f"新的文本: {new_text}") + + # 更新 last_content_type + if content_type != None: + last_content_type = content_type if role != "user" else last_content_type + + model_slug = message.get("metadata", {}).get("model_slug") or model + + if first_output: + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": model_slug, + "choices": [ + { + "index": 0, + "delta": {"role":"assistant"}, + "finish_reason": None + } + ] + } + q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + data_queue.put(q_data) + logger.info(f"开始流式响应...") + first_output = False + + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": model_slug, + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join(new_text) + }, + "finish_reason": None + } + ] + } + # print(f"Role: {role}") + # logger.info(f".") + tmp = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + # print(f"发送数据: {tmp}") + # 累积 new_text + all_new_text += new_text + tmp_t = new_text.replace('\n', '\\n') + logger.info(f"Send: {tmp_t}") + + + # if new_text != None: + q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + data_queue.put(q_data) + last_data_time[0] = time.time() + if stop_event.is_set(): + return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None + return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None + +import websocket +import base64 + +def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages): + headers = { + "Sec-Ch-Ua-Mobile": "?0", + "User-Agent": ua.random + } + context = { + "all_new_text": "", + "first_output": True, + "timestamp": int(time.time()), + "buffer": "", + "last_full_text": "", + "last_full_code": "", + "last_full_code_result": "", + "last_content_type": None, + "conversation_id": "", + "citation_buffer": "", + "citation_accumulating": False, + "file_output_buffer": "", + "file_output_accumulating": False, + "execution_output_image_url_buffer": "", + "execution_output_image_id_buffer": "" + } + + def on_message(ws, message): + # logger.debug(f"on_message: {message}") + if stop_event.is_set(): + logger.info(f"接受到停止信号,停止数据处理线程") + return + result_json = json.loads(message) + result_id = result_json.get('response_id', '') + if result_id != context["response_id"]: + logger.info(f"response_id 不匹配,忽略") + return + body = result_json.get('body', '') + # logger.debug("wss result: " + str(result_json)) + if body: + buffer_data = base64.b64decode(body).decode('utf-8') + end_index = buffer_data.index('\n\n') + 2 + complete_data, _ = buffer_data[:end_index], buffer_data[end_index:] + # logger.debug(f"complete_data: {complete_data}") + try: + data_json = json.loads(complete_data.replace('data: ', '')) + logger.debug(f"data_json: {data_json}") + + context["all_new_text"], context["first_output"], context["last_full_text"], context["last_full_code"], context["last_full_code_result"], context["last_content_type"], context["conversation_id"], context["citation_buffer"], context["citation_accumulating"], context["file_output_buffer"], context["file_output_accumulating"], context["execution_output_image_url_buffer"], context["execution_output_image_id_buffer"], allow_id = process_data_json(data_json, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, context["timestamp"], context["first_output"], context["last_full_text"], context["last_full_code"], context["last_full_code_result"], context["last_content_type"], context["conversation_id"], context["citation_buffer"], context["citation_accumulating"], context["file_output_buffer"], context["file_output_accumulating"], context["execution_output_image_url_buffer"], context["execution_output_image_id_buffer"], context["all_new_text"]) + + if allow_id: + context["response_id"] = allow_id + except json.JSONDecodeError: + logger.error(f"Failed to parse the response as JSON: {complete_data}") + if complete_data == 'data: [DONE]\n\n': + logger.info(f"会话结束") + q_data = complete_data + data_queue.put(('all_new_text', context["all_new_text"])) + data_queue.put(q_data) + q_data = complete_data + data_queue.put(q_data) + stop_event.set() + ws.close() + + + def on_error(ws, error): + logger.error(error) + + def on_close(ws, b, c): + logger.debug("wss closed") + + def on_open(ws): + def run(*args): + logger.debug(f"on_open: wss") + while True: + if stop_event.is_set(): + ws.close() + break + upstream_response = send_text_prompt_and_get_response(messages, api_key, True, model) + upstream_response_json = upstream_response.json() + upstream_wss_url = upstream_response_json.get("wss_url", None) + upstream_response_id = upstream_response_json.get("response_id", None) + context["response_id"] = upstream_response_id + if upstream_wss_url is not None: + logger.debug(f"start wss...") + ws = websocket.WebSocketApp(wss_url, + on_message = on_message, + on_error = on_error, + on_close = on_close) + ws.on_open = on_open + ws.run_forever() + + logger.debug(f"end wss...") + + +def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages): all_new_text = "" first_output = True @@ -967,501 +1663,34 @@ def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_ file_output_accumulating = False execution_output_image_url_buffer = "" execution_output_image_id_buffer = "" - try: - for chunk in upstream_response.iter_content(chunk_size=1024): - if stop_event.is_set(): - logger.info(f"接受到停止信号,停止数据处理线程") - break - if chunk: - buffer += chunk.decode('utf-8') - # 检查是否存在 "event: ping",如果存在,则只保留 "data:" 后面的内容 - if "event: ping" in buffer: - if "data:" in buffer: - buffer = buffer.split("data:", 1)[1] - buffer = "data:" + buffer - # 使用正则表达式移除特定格式的字符串 - # print("应用正则表达式之前的 buffer:", buffer.replace('\n', '\\n')) - buffer = re.sub(r'data: \d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{6}(\r\n|\r|\n){2}', '', buffer) - # print("应用正则表达式之后的 buffer:", buffer.replace('\n', '\\n')) + + wss_url = register_websocket(api_key) + # response_json = upstream_response.json() + # wss_url = response_json.get("wss_url", None) + # logger.info(f"wss_url: {wss_url}") + # 如果存在 wss_url,使用 WebSocket 连接获取数据 + if wss_url: + process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages) + + while True: + if stop_event.is_set(): + logger.info(f"接受到停止信号,停止数据处理线程") + + break - while 'data:' in buffer and '\n\n' in buffer: - end_index = buffer.index('\n\n') + 2 - complete_data, buffer = buffer[:end_index], buffer[end_index:] - # 解析 data 块 - try: - data_json = json.loads(complete_data.replace('data: ', '')) - # print(f"data_json: {data_json}") - message = data_json.get("message", {}) +def register_websocket(api_key): + url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/register-websocket" + headers = { + "Authorization": f"Bearer {api_key}" + } + response = requests.post(url, headers=headers) + response_json = response.json() + logger.debug(f"register_websocket response: {response_json}") + wss_url = response_json.get("wss_url", None) + return wss_url - if message == {} or message == None: - logger.debug(f"message 为空: data_json: {data_json}") - - message_id = message.get("id") - message_status = message.get("status") - content = message.get("content", {}) - role = message.get("author", {}).get("role") - content_type = content.get("content_type") - print(f"content_type: {content_type}") - print(f"last_content_type: {last_content_type}") - - metadata = {} - citations = [] - try: - metadata = message.get("metadata", {}) - citations = metadata.get("citations", []) - except: - pass - name = message.get("author", {}).get("name") - if (role == "user" or message_status == "finished_successfully" or role == "system") and role != "tool": - # 如果是用户发来的消息,直接舍弃 - continue - try: - conversation_id = data_json.get("conversation_id") - # print(f"conversation_id: {conversation_id}") - if conversation_id: - data_queue.put(('conversation_id', conversation_id)) - except: - pass - # 只获取新的部分 - new_text = "" - is_img_message = False - parts = content.get("parts", []) - for part in parts: - try: - # print(f"part: {part}") - # print(f"part type: {part.get('content_type')}") - if part.get('content_type') == 'image_asset_pointer': - logger.debug(f"find img message~") - is_img_message = True - asset_pointer = part.get('asset_pointer').replace('file-service://', '') - logger.debug(f"asset_pointer: {asset_pointer}") - image_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/files/{asset_pointer}/download" - - headers = { - "Authorization": f"Bearer {api_key}" - } - image_response = requests.get(image_url, headers=headers) - - if image_response.status_code == 200: - download_url = image_response.json().get('download_url') - logger.debug(f"download_url: {download_url}") - if USE_OAIUSERCONTENT_URL == True: - if ((BOT_MODE_ENABLED == False) or (BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): - new_text = f"\n![image]({download_url})\n[下载链接]({download_url})\n" - if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: - if all_new_text != "": - new_text = f"\n图片链接:{download_url}\n" - else: - new_text = f"图片链接:{download_url}\n" - else: - # 从URL下载图片 - # image_data = requests.get(download_url).content - image_download_response = requests.get(download_url) - # print(f"image_download_response: {image_download_response.text}") - if image_download_response.status_code == 200: - logger.debug(f"下载图片成功") - image_data = image_download_response.content - today_image_url = save_image(image_data) # 保存图片,并获取文件名 - if ((BOT_MODE_ENABLED == False) or (BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): - new_text = f"\n![image]({UPLOAD_BASE_URL}/{today_image_url})\n[下载链接]({UPLOAD_BASE_URL}/{today_image_url})\n" - if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: - if all_new_text != "": - new_text = f"\n图片链接:{UPLOAD_BASE_URL}/{today_image_url}\n" - else: - new_text = f"图片链接:{UPLOAD_BASE_URL}/{today_image_url}\n" - else: - logger.error(f"下载图片失败: {image_download_response.text}") - if last_content_type == "code": - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = new_text - else: - new_text = "\n```\n" + new_text - - logger.debug(f"new_text: {new_text}") - is_img_message = True - else: - logger.error(f"获取图片下载链接失败: {image_response.text}") - except: - pass - - - if is_img_message == False: - # print(f"data_json: {data_json}") - if content_type == "multimodal_text" and last_content_type == "code": - new_text = "\n```\n" + content.get("text", "") - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = content.get("text", "") - elif role == "tool" and name == "dalle.text2im": - logger.debug(f"无视消息: {content.get('text', '')}") - continue - # 代码块特殊处理 - if content_type == "code" and last_content_type != "code" and content_type != None: - full_code = ''.join(content.get("text", "")) - new_text = "\n```\n" + full_code[len(last_full_code):] - # print(f"full_code: {full_code}") - # print(f"last_full_code: {last_full_code}") - # print(f"new_text: {new_text}") - last_full_code = full_code # 更新完整代码以备下次比较 - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - - elif last_content_type == "code" and content_type != "code" and content_type != None: - full_code = ''.join(content.get("text", "")) - new_text = "\n```\n" + full_code[len(last_full_code):] - # print(f"full_code: {full_code}") - # print(f"last_full_code: {last_full_code}") - # print(f"new_text: {new_text}") - last_full_code = "" # 更新完整代码以备下次比较 - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - - elif content_type == "code" and last_content_type == "code" and content_type != None: - full_code = ''.join(content.get("text", "")) - new_text = full_code[len(last_full_code):] - # print(f"full_code: {full_code}") - # print(f"last_full_code: {last_full_code}") - # print(f"new_text: {new_text}") - last_full_code = full_code # 更新完整代码以备下次比较 - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - - else: - # 只获取新的 parts - parts = content.get("parts", []) - full_text = ''.join(parts) - new_text = full_text[len(last_full_text):] - if full_text != '': - last_full_text = full_text # 更新完整文本以备下次比较 - if "\u3010" in new_text and not citation_accumulating: - citation_accumulating = True - citation_buffer = citation_buffer + new_text - # print(f"开始积累引用: {citation_buffer}") - elif citation_accumulating: - citation_buffer += new_text - # print(f"积累引用: {citation_buffer}") - if citation_accumulating: - if is_valid_citation_format(citation_buffer): - # print(f"合法格式: {citation_buffer}") - # 继续积累 - if is_complete_citation_format(citation_buffer): - - # 替换完整的引用格式 - replaced_text, remaining_text, is_potential_citation = replace_complete_citation(citation_buffer, citations) - # print(replaced_text) # 输出替换后的文本 - new_text = replaced_text - - if(is_potential_citation): - citation_buffer = remaining_text - else: - citation_accumulating = False - citation_buffer = "" - # print(f"替换完整的引用格式: {new_text}") - else: - continue - else: - # 不是合法格式,放弃积累并响应 - # print(f"不合法格式: {citation_buffer}") - new_text = citation_buffer - citation_accumulating = False - citation_buffer = "" - - if "(" in new_text and not file_output_accumulating and not citation_accumulating: - file_output_accumulating = True - file_output_buffer = file_output_buffer + new_text - logger.debug(f"开始积累文件输出: {file_output_buffer}") - elif file_output_accumulating: - file_output_buffer += new_text - logger.debug(f"积累文件输出: {file_output_buffer}") - if file_output_accumulating: - if is_valid_sandbox_combined_corrected_final_v2(file_output_buffer): - logger.debug(f"合法文件输出格式: {file_output_buffer}") - # 继续积累 - if is_complete_sandbox_format(file_output_buffer): - # 替换完整的引用格式 - replaced_text = replace_sandbox(file_output_buffer, conversation_id, message_id, api_key) - # print(replaced_text) # 输出替换后的文本 - new_text = replaced_text - file_output_accumulating = False - file_output_buffer = "" - logger.debug(f"替换完整的文件输出格式: {new_text}") - else: - continue - else: - # 不是合法格式,放弃积累并响应 - logger.debug(f"不合法格式: {file_output_buffer}") - new_text = file_output_buffer - file_output_accumulating = False - file_output_buffer = "" - - - # Python 工具执行输出特殊处理 - if role == "tool" and name == "python" and last_content_type != "execution_output" and content_type != None: - full_code_result = ''.join(content.get("text", "")) - new_text = "`Result:` \n```\n" + full_code_result[len(last_full_code_result):] - if last_content_type == "code": - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - else: - new_text = "\n```\n" + new_text - # print(f"full_code_result: {full_code_result}") - # print(f"last_full_code_result: {last_full_code_result}") - # print(f"new_text: {new_text}") - last_full_code_result = full_code_result # 更新完整代码以备下次比较 - elif last_content_type == "execution_output" and (role != "tool" or name != "python") and content_type != None: - # new_text = content.get("text", "") + "\n```" - full_code_result = ''.join(content.get("text", "")) - new_text = full_code_result[len(last_full_code_result):] + "\n```\n" - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - tmp_new_text = new_text - if execution_output_image_url_buffer != "": - if ((BOT_MODE_ENABLED == False) or (BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): - logger.debug(f"BOT_MODE_ENABLED: {BOT_MODE_ENABLED}") - logger.debug(f"BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT: {BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT}") - new_text = tmp_new_text + f"![image]({execution_output_image_url_buffer})\n[下载链接]({execution_output_image_url_buffer})\n" - if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: - logger.debug(f"BOT_MODE_ENABLED: {BOT_MODE_ENABLED}") - logger.debug(f"BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT: {BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT}") - new_text = tmp_new_text + f"图片链接:{execution_output_image_url_buffer}\n" - execution_output_image_url_buffer = "" - - if content_type == "code": - new_text = new_text + "\n```\n" - # print(f"full_code_result: {full_code_result}") - # print(f"last_full_code_result: {last_full_code_result}") - # print(f"new_text: {new_text}") - last_full_code_result = "" # 更新完整代码以备下次比较 - elif last_content_type == "execution_output" and role == "tool" and name == "python" and content_type != None: - full_code_result = ''.join(content.get("text", "")) - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - else: - new_text = full_code_result[len(last_full_code_result):] - # print(f"full_code_result: {full_code_result}") - # print(f"last_full_code_result: {last_full_code_result}") - # print(f"new_text: {new_text}") - last_full_code_result = full_code_result - - # 其余Action执行输出特殊处理 - if role == "tool" and name != "python" and name != "dalle.text2im" and last_content_type != "execution_output" and content_type != None: - new_text = "" - if last_content_type == "code": - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - else: - new_text = "\n```\n" + new_text - - - # 检查 new_text 中是否包含 <> - if "<>" in last_full_code_result: - # 进行提取操作 - aggregate_result = message.get("metadata", {}).get("aggregate_result", {}) - if aggregate_result: - messages = aggregate_result.get("messages", []) - for msg in messages: - if msg.get("message_type") == "image": - image_url = msg.get("image_url") - if image_url: - # 从 image_url 提取所需的字段 - image_file_id = image_url.split('://')[-1] - logger.info(f"提取到的图片文件ID: {image_file_id}") - if image_file_id != execution_output_image_id_buffer: - image_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/files/{image_file_id}/download" - - headers = { - "Authorization": f"Bearer {api_key}" - } - image_response = requests.get(image_url, headers=headers) - - if image_response.status_code == 200: - download_url = image_response.json().get('download_url') - logger.debug(f"download_url: {download_url}") - if USE_OAIUSERCONTENT_URL == True: - execution_output_image_url_buffer = download_url - - else: - # 从URL下载图片 - # image_data = requests.get(download_url).content - image_download_response = requests.get(download_url) - # print(f"image_download_response: {image_download_response.text}") - if image_download_response.status_code == 200: - logger.debug(f"下载图片成功") - image_data = image_download_response.content - today_image_url = save_image(image_data) # 保存图片,并获取文件名 - execution_output_image_url_buffer = f"{UPLOAD_BASE_URL}/{today_image_url}" - - else: - logger.error(f"下载图片失败: {image_download_response.text}") - - execution_output_image_id_buffer = image_file_id - - # 从 new_text 中移除 <> - new_text = new_text.replace("<>", "图片生成中,请稍后\n") - - # print(f"收到数据: {data_json}") - # print(f"收到的完整文本: {full_text}") - # print(f"上次收到的完整文本: {last_full_text}") - # print(f"新的文本: {new_text}") - - # 更新 last_content_type - if content_type != None: - last_content_type = content_type if role != "user" else last_content_type - - model_slug = message.get("metadata", {}).get("model_slug") or model - - - if first_output: - new_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": model_slug, - "choices": [ - { - "index": 0, - "delta": {"role":"assistant"}, - "finish_reason": None - } - ] - } - q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - data_queue.put(q_data) - first_output = False - - - new_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": model_slug, - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join(new_text) - }, - "finish_reason": None - } - ] - } - # print(f"Role: {role}") - logger.info(f"发送消息: {new_text}") - tmp = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - # print(f"发送数据: {tmp}") - # 累积 new_text - all_new_text += new_text - q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - data_queue.put(q_data) - last_data_time[0] = time.time() - if stop_event.is_set(): - break - except json.JSONDecodeError: - # print("JSON 解析错误") - logger.info(f"发送数据: {complete_data}") - if complete_data == 'data: [DONE]\n\n': - logger.info(f"会话结束") - q_data = complete_data - data_queue.put(('all_new_text', all_new_text)) - data_queue.put(q_data) - last_data_time[0] = time.time() - if stop_event.is_set(): - break - if citation_buffer != "": - new_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": message.get("metadata", {}).get("model_slug"), - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join(citation_buffer) - }, - "finish_reason": None - } - ] - } - tmp = 'data: ' + json.dumps(new_data) + '\n\n' - # print(f"发送数据: {tmp}") - # 累积 new_text - all_new_text += citation_buffer - q_data = 'data: ' + json.dumps(new_data) + '\n\n' - data_queue.put(q_data) - last_data_time[0] = time.time() - if buffer: - # print(f"最后的数据: {buffer}") - # delete_conversation(conversation_id, api_key) - try: - buffer_json = json.loads(buffer) - logger.info(f"最后的缓存数据: {buffer_json}") - error_message = buffer_json.get("detail", {}).get("message", "未知错误") - error_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": "error", - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join("```\n" + error_message + "\n```") - }, - "finish_reason": None - } - ] - } - tmp = 'data: ' + json.dumps(error_data) + '\n\n' - logger.info(f"发送最后的数据: {tmp}") - # 累积 new_text - all_new_text += ''.join("```\n" + error_message + "\n```") - q_data = 'data: ' + json.dumps(error_data) + '\n\n' - data_queue.put(q_data) - last_data_time[0] = time.time() - complete_data = 'data: [DONE]\n\n' - logger.info(f"会话结束") - q_data = complete_data - data_queue.put(('all_new_text', all_new_text)) - data_queue.put(q_data) - last_data_time[0] = time.time() - except: - # print("JSON 解析错误") - logger.info(f"发送最后的数据: {buffer}") - error_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": "error", - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join("```\n" + buffer + "\n```") - }, - "finish_reason": None - } - ] - } - tmp = 'data: ' + json.dumps(error_data) + '\n\n' - q_data = tmp - data_queue.put(q_data) - last_data_time[0] = time.time() - complete_data = 'data: [DONE]\n\n' - logger.info(f"会话结束") - q_data = complete_data - data_queue.put(('all_new_text', all_new_text)) - data_queue.put(q_data) - last_data_time[0] = time.time() - except Exception as e: - logger.error(f"Exception: {e}") - complete_data = 'data: [DONE]\n\n' - logger.info(f"会话结束") - q_data = complete_data - data_queue.put(('all_new_text', all_new_text)) - data_queue.put(q_data) - last_data_time[0] = time.time() def keep_alive(last_data_time, stop_event, queue, model, chat_message_id): while not stop_event.is_set(): @@ -1553,7 +1782,7 @@ def chat_completions(): logger.info(f"api_key: {api_key}") - upstream_response = send_text_prompt_and_get_response(messages, api_key, stream, model) + # upstream_response = send_text_prompt_and_get_response(messages, api_key, stream, model) # 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text all_new_text = "" @@ -1571,7 +1800,7 @@ def chat_completions(): conversation_id = '' # 启动数据处理线程 - fetcher_thread = threading.Thread(target=data_fetcher, args=(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model)) + fetcher_thread = threading.Thread(target=data_fetcher, args=(data_queue, stop_event, last_data_time, api_key, chat_message_id, model,"url", messages)) fetcher_thread.start() # 启动保活线程 @@ -1706,7 +1935,7 @@ def images_generations(): } ] - upstream_response = send_text_prompt_and_get_response(messages, api_key, False, model) + # upstream_response = send_text_prompt_and_get_response(messages, api_key, False, model) # 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text all_new_text = "" @@ -1714,328 +1943,59 @@ def images_generations(): # 处理流式响应 def generate(): nonlocal all_new_text # 引用外部变量 + data_queue = Queue() + stop_event = threading.Event() + last_data_time = [time.time()] chat_message_id = generate_unique_id("chatcmpl") - # 当前时间戳 - timestamp = int(time.time()) - buffer = "" - last_full_text = "" # 用于存储之前所有出现过的 parts 组成的完整文本 - last_full_code = "" - last_full_code_result = "" - last_content_type = None # 用于记录上一个消息的内容类型 + conversation_id_print_tag = False + conversation_id = '' - citation_buffer = "" - citation_accumulating = False - for chunk in upstream_response.iter_content(chunk_size=1024): - if chunk: - buffer += chunk.decode('utf-8') - # 检查是否存在 "event: ping",如果存在,则只保留 "data:" 后面的内容 - if "event: ping" in buffer: - if "data:" in buffer: - buffer = buffer.split("data:", 1)[1] - buffer = "data:" + buffer - # 使用正则表达式移除特定格式的字符串 - # print("应用正则表达式之前的 buffer:", buffer.replace('\n', '\\n')) - buffer = re.sub(r'data: \d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{6}(\r\n|\r|\n){2}', '', buffer) - # print("应用正则表达式之后的 buffer:", buffer.replace('\n', '\\n')) + # 启动数据处理线程 + fetcher_thread = threading.Thread(target=data_fetcher, args=(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages)) + fetcher_thread.start() + # 启动保活线程 + keep_alive_thread = threading.Thread(target=keep_alive, args=(last_data_time, stop_event, data_queue, model, chat_message_id)) + keep_alive_thread.start() - while 'data:' in buffer and '\n\n' in buffer: - end_index = buffer.index('\n\n') + 2 - complete_data, buffer = buffer[:end_index], buffer[end_index:] - # 解析 data 块 - try: - data_json = json.loads(complete_data.replace('data: ', '')) - # print(f"data_json: {data_json}") - message = data_json.get("message", {}) + try: + while True: + data = data_queue.get() + if isinstance(data, tuple) and data[0] == 'all_new_text': + # 更新 all_new_text + logger.info(f"完整消息: {data[1]}") + all_new_text += data[1] + elif isinstance(data, tuple) and data[0] == 'conversation_id': + if conversation_id_print_tag == False: + logger.info(f"当前会话id: {data[1]}") + conversation_id_print_tag = True + # 更新 conversation_id + conversation_id = data[1] + # print(f"收到会话id: {conversation_id}") + elif isinstance(data, tuple) and data[0] == 'image_url': + # 更新 image_url + image_urls.append(data[1]) + logger.debug(f"收到图片链接: {data[1]}") + elif data == 'data: [DONE]\n\n': + # 接收到结束信号,退出循环 + logger.debug(f"会话结束-外层") + yield data + break + else: + yield data - if message == None: - logger.error(f"message 为空: data_json: {data_json}") + finally: + logger.critical(f"准备结束会话") + stop_event.set() + fetcher_thread.join() + keep_alive_thread.join() - message_status = message.get("status") - content = message.get("content", {}) - role = message.get("author", {}).get("role") - content_type = content.get("content_type") - # logger.debug(f"content_type: {content_type}") - # logger.debug(f"last_content_type: {last_content_type}") + # if conversation_id: + # # print(f"准备删除的会话id: {conversation_id}") + # delete_conversation(conversation_id, cookie, x_authorization) - metadata = {} - citations = [] - try: - metadata = message.get("metadata", {}) - citations = metadata.get("citations", []) - except: - pass - name = message.get("author", {}).get("name") - if (role == "user" or message_status == "finished_successfully" or role == "system") and role != "tool": - # 如果是用户发来的消息,直接舍弃 - continue - try: - conversation_id = data_json.get("conversation_id") - logger.debug(f"conversation_id: {conversation_id}") - except: - pass - # 只获取新的部分 - new_text = "" - is_img_message = False - parts = content.get("parts", []) - for part in parts: - try: - # print(f"part: {part}") - # print(f"part type: {part.get('content_type')}") - if part.get('content_type') == 'image_asset_pointer': - logger.debug(f"find img message~") - is_img_message = True - asset_pointer = part.get('asset_pointer').replace('file-service://', '') - logger.debug(f"asset_pointer: {asset_pointer}") - image_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/files/{asset_pointer}/download" - - headers = { - "Authorization": f"Bearer {api_key}" - } - image_response = requests.get(image_url, headers=headers) - - if image_response.status_code == 200: - download_url = image_response.json().get('download_url') - logger.debug(f"download_url: {download_url}") - if USE_OAIUSERCONTENT_URL == True and response_format == "url": - image_link = f"{download_url}" - image_urls.append(image_link) # 将图片链接保存到列表中 - new_text = "" - else: - if response_format == "url": - # 从URL下载图片 - # image_data = requests.get(download_url).content - image_download_response = requests.get(download_url) - # print(f"image_download_response: {image_download_response.text}") - if image_download_response.status_code == 200: - logger.debug(f"下载图片成功") - image_data = image_download_response.content - today_image_url = save_image(image_data) # 保存图片,并获取文件名 - # new_text = f"\n![image]({UPLOAD_BASE_URL}/{today_image_url})\n[下载链接]({UPLOAD_BASE_URL}/{today_image_url})\n" - image_link = f"{UPLOAD_BASE_URL}/{today_image_url}" - image_urls.append(image_link) # 将图片链接保存到列表中 - new_text = "" - else: - logger.error(f"下载图片失败: {image_download_response.text}") - else: - # 使用base64编码图片 - # image_data = requests.get(download_url).content - image_download_response = requests.get(download_url) - if image_download_response.status_code == 200: - logger.debug(f"下载图片成功") - image_data = image_download_response.content - image_base64 = base64.b64encode(image_data).decode('utf-8') - image_urls.append(image_base64) - new_text = "" - else: - logger.error(f"下载图片失败: {image_download_response.text}") - if last_content_type == "code": - new_text = new_text - # new_text = "\n```\n" + new_text - logger.debug(f"new_text: {new_text}") - is_img_message = True - else: - logger.error(f"获取图片下载链接失败: {image_response.text}") - except: - pass - - - if is_img_message == False: - # print(f"data_json: {data_json}") - if content_type == "multimodal_text" and last_content_type == "code": - new_text = "\n```\n" + content.get("text", "") - elif role == "tool" and name == "dalle.text2im": - logger.debug(f"无视消息: {content.get('text', '')}") - continue - # 代码块特殊处理 - if content_type == "code" and last_content_type != "code" and content_type != None: - full_code = ''.join(content.get("text", "")) - new_text = "\n```\n" + full_code[len(last_full_code):] - # print(f"full_code: {full_code}") - # print(f"last_full_code: {last_full_code}") - # print(f"new_text: {new_text}") - last_full_code = full_code # 更新完整代码以备下次比较 - - elif last_content_type == "code" and content_type != "code" and content_type != None: - full_code = ''.join(content.get("text", "")) - new_text = "\n```\n" + full_code[len(last_full_code):] - # print(f"full_code: {full_code}") - # print(f"last_full_code: {last_full_code}") - # print(f"new_text: {new_text}") - last_full_code = "" # 更新完整代码以备下次比较 - - elif content_type == "code" and last_content_type == "code" and content_type != None: - full_code = ''.join(content.get("text", "")) - new_text = full_code[len(last_full_code):] - # print(f"full_code: {full_code}") - # print(f"last_full_code: {last_full_code}") - # print(f"new_text: {new_text}") - last_full_code = full_code # 更新完整代码以备下次比较 - - else: - # 只获取新的 parts - parts = content.get("parts", []) - full_text = ''.join(parts) - new_text = full_text[len(last_full_text):] - last_full_text = full_text # 更新完整文本以备下次比较 - if "\u3010" in new_text and not citation_accumulating: - citation_accumulating = True - citation_buffer = citation_buffer + new_text - logger.debug(f"开始积累引用: {citation_buffer}") - elif citation_accumulating: - citation_buffer += new_text - logger.debug(f"积累引用: {citation_buffer}") - if citation_accumulating: - if is_valid_citation_format(citation_buffer): - logger.debug(f"合法格式: {citation_buffer}") - # 继续积累 - if is_complete_citation_format(citation_buffer): - - # 替换完整的引用格式 - replaced_text, remaining_text, is_potential_citation = replace_complete_citation(citation_buffer, citations) - # print(replaced_text) # 输出替换后的文本 - new_text = replaced_text - - if(is_potential_citation): - citation_buffer = remaining_text - else: - citation_accumulating = False - citation_buffer = "" - logger.debug(f"替换完整的引用格式: {new_text}") - else: - continue - else: - # 不是合法格式,放弃积累并响应 - logger.debug(f"不合法格式: {citation_buffer}") - new_text = citation_buffer - citation_accumulating = False - citation_buffer = "" - - - # Python 工具执行输出特殊处理 - if role == "tool" and name == "python" and last_content_type != "execution_output" and content_type != None: - - - full_code_result = ''.join(content.get("text", "")) - new_text = "`Result:` \n```\n" + full_code_result[len(last_full_code_result):] - if last_content_type == "code": - new_text = "\n```\n" + new_text - # print(f"full_code_result: {full_code_result}") - # print(f"last_full_code_result: {last_full_code_result}") - # print(f"new_text: {new_text}") - last_full_code_result = full_code_result # 更新完整代码以备下次比较 - elif last_content_type == "execution_output" and (role != "tool" or name != "python") and content_type != None: - # new_text = content.get("text", "") + "\n```" - full_code_result = ''.join(content.get("text", "")) - new_text = full_code_result[len(last_full_code_result):] + "\n```\n" - if content_type == "code": - new_text = new_text + "\n```\n" - # print(f"full_code_result: {full_code_result}") - # print(f"last_full_code_result: {last_full_code_result}") - # print(f"new_text: {new_text}") - last_full_code_result = "" # 更新完整代码以备下次比较 - elif last_content_type == "execution_output" and role == "tool" and name == "python" and content_type != None: - full_code_result = ''.join(content.get("text", "")) - new_text = full_code_result[len(last_full_code_result):] - # print(f"full_code_result: {full_code_result}") - # print(f"last_full_code_result: {last_full_code_result}") - # print(f"new_text: {new_text}") - last_full_code_result = full_code_result - - # print(f"收到数据: {data_json}") - # print(f"收到的完整文本: {full_text}") - # print(f"上次收到的完整文本: {last_full_text}") - # print(f"新的文本: {new_text}") - - # 更新 last_content_type - if content_type != None: - last_content_type = content_type if role != "user" else last_content_type - - - new_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": message.get("metadata", {}).get("model_slug"), - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join(new_text) - }, - "finish_reason": None - } - ] - } - # print(f"Role: {role}") - logger.info(f"发送消息: {new_text}") - tmp = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - # print(f"发送数据: {tmp}") - # 累积 new_text - all_new_text += new_text - yield 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - except json.JSONDecodeError: - # print("JSON 解析错误") - logger.info(f"发送数据: {complete_data}") - if complete_data == 'data: [DONE]\n\n': - logger.info(f"会话结束") - yield complete_data - if citation_buffer != "": - new_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": message.get("metadata", {}).get("model_slug"), - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join(citation_buffer) - }, - "finish_reason": None - } - ] - } - tmp = 'data: ' + json.dumps(new_data) + '\n\n' - # print(f"发送数据: {tmp}") - # 累积 new_text - all_new_text += citation_buffer - yield 'data: ' + json.dumps(new_data) + '\n\n' - if buffer: - # print(f"最后的数据: {buffer}") - # delete_conversation(conversation_id, api_key) - try: - buffer_json = json.loads(buffer) - error_message = buffer_json.get("detail", {}).get("message", "未知错误") - error_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": "error", - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join("```\n" + error_message + "\n```") - }, - "finish_reason": None - } - ] - } - tmp = 'data: ' + json.dumps(error_data) + '\n\n' - logger.info(f"发送最后的数据: {tmp}") - # 累积 new_text - all_new_text += ''.join("```\n" + error_message + "\n```") - yield 'data: ' + json.dumps(error_data) + '\n\n' - except: - # print("JSON 解析错误") - logger.info(f"发送最后的数据: {buffer}") - yield buffer - - # delete_conversation(conversation_id, api_key) - # 执行流式响应的生成函数来累积 all_new_text # 迭代生成器对象以执行其内部逻辑 for _ in generate(): @@ -2049,7 +2009,7 @@ def images_generations(): "message": all_new_text, # 使用累积的文本作为错误信息 "type": "invalid_request_error", "param": "", - "code": "content_policy_violation" + "code": "image_generate_fail" } } else: @@ -2075,7 +2035,7 @@ def images_generations(): } for base64 in image_urls ] # 将图片链接列表转换为所需格式 } - logger.debug(f"response_json: {response_json}") + # logger.critical(f"response_json: {response_json}") # 返回 JSON 响应 return jsonify(response_json)