From 5b61e0f2ddfa1215a2cb4ea00e8a530b0ef7742e Mon Sep 17 00:00:00 2001 From: Wizerd Date: Sat, 3 Feb 2024 10:45:15 +0800 Subject: [PATCH] =?UTF-8?q?[feat]=20=E5=90=8C=E6=97=B6=E6=94=AF=E6=8C=81ws?= =?UTF-8?q?s=E5=92=8Csse=E5=93=8D=E5=BA=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 212 ++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 199 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index ce381f9..d35e3ca 100644 --- a/main.py +++ b/main.py @@ -32,7 +32,7 @@ def load_config(file_path): CONFIG = load_config('./data/config.json') -LOG_LEVEL = CONFIG.get('log_level', 'DEBUG').upper() +LOG_LEVEL = CONFIG.get('log_level', 'INFO').upper() NEED_LOG_TO_FILE = CONFIG.get('need_log_to_file', 'true').lower() == 'true' # 使用 get 方法获取配置项,同时提供默认值 @@ -197,9 +197,9 @@ CORS(app, resources={r"/images/*": {"origins": "*"}}) # PANDORA_UPLOAD_URL = 'files.pandoranext.com' -VERSION = '0.7.1' +VERSION = '0.7.2' # VERSION = 'test' -UPDATE_INFO = '修复了流式响应绘图时的错误,以及上游接口出错时内存溢出的问题' +UPDATE_INFO = '同时支持wss和sse响应' # UPDATE_INFO = '【仅供临时测试使用】 ' with app.app_context(): @@ -1628,10 +1628,14 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m ws.close() break upstream_response = send_text_prompt_and_get_response(messages, api_key, True, model) - upstream_response_json = upstream_response.json() - upstream_wss_url = upstream_response_json.get("wss_url", None) - upstream_response_id = upstream_response_json.get("response_id", None) - context["response_id"] = upstream_response_id + upstream_wss_url = None + try: + upstream_response_json = upstream_response.json() + upstream_wss_url = upstream_response_json.get("wss_url", None) + upstream_response_id = upstream_response_json.get("response_id", None) + context["response_id"] = upstream_response_id + except json.JSONDecodeError: + pass if upstream_wss_url is not None: logger.debug(f"start wss...") ws = websocket.WebSocketApp(wss_url, @@ -1642,9 +1646,8 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m ws.run_forever() logger.debug(f"end wss...") - - else: - logger.error(f"注册 wss 失败") + elif upstream_response.status_code != 200: + logger.error(f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}") complete_data = 'data: [DONE]\n\n' timestamp = context["timestamp"] @@ -1657,7 +1660,7 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m { "index": 0, "delta": { - "content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```") + "content": ''.join("```json\n{\n\"error\": \"Upstream error...\"\n}\n```") }, "finish_reason": None } @@ -1667,10 +1670,193 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m data_queue.put(q_data) q_data = complete_data - data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")) + data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Upstream error...\"\n}\n```")) data_queue.put(q_data) stop_event.set() + else: + logger.error(f"检测到非wss响应...") + old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format) + # complete_data = 'data: [DONE]\n\n' + # timestamp = context["timestamp"] + + # new_data = { + # "id": chat_message_id, + # "object": "chat.completion.chunk", + # "created": timestamp, + # "model": model, + # "choices": [ + # { + # "index": 0, + # "delta": { + # "content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```") + # }, + # "finish_reason": None + # } + # ] + # } + # q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + # data_queue.put(q_data) + + # q_data = complete_data + # data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")) + # data_queue.put(q_data) + # stop_event.set() + + +def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format): + all_new_text = "" + + first_output = True + + # 当前时间戳 + timestamp = int(time.time()) + + buffer = "" + last_full_text = "" # 用于存储之前所有出现过的 parts 组成的完整文本 + last_full_code = "" + last_full_code_result = "" + last_content_type = None # 用于记录上一个消息的内容类型 + conversation_id = '' + citation_buffer = "" + citation_accumulating = False + file_output_buffer = "" + file_output_accumulating = False + execution_output_image_url_buffer = "" + execution_output_image_id_buffer = "" + try: + for chunk in upstream_response.iter_content(chunk_size=1024): + if stop_event.is_set(): + logger.info(f"接受到停止信号,停止数据处理线程") + break + if chunk: + buffer += chunk.decode('utf-8') + # 检查是否存在 "event: ping",如果存在,则只保留 "data:" 后面的内容 + if "event: ping" in buffer: + if "data:" in buffer: + buffer = buffer.split("data:", 1)[1] + buffer = "data:" + buffer + # 使用正则表达式移除特定格式的字符串 + # print("应用正则表达式之前的 buffer:", buffer.replace('\n', '\\n')) + buffer = re.sub(r'data: \d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{6}(\r\n|\r|\n){2}', '', buffer) + # print("应用正则表达式之后的 buffer:", buffer.replace('\n', '\\n')) + + + + while 'data:' in buffer and '\n\n' in buffer: + end_index = buffer.index('\n\n') + 2 + complete_data, buffer = buffer[:end_index], buffer[end_index:] + # 解析 data 块 + try: + data_json = json.loads(complete_data.replace('data: ', '')) + logger.debug(f"data_json: {data_json}") + # print(f"data_json: {data_json}") + all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, allow_id = process_data_json(data_json, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, timestamp, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, all_new_text) + except json.JSONDecodeError: + # print("JSON 解析错误") + logger.info(f"发送数据: {complete_data}") + if complete_data == 'data: [DONE]\n\n': + logger.info(f"会话结束") + q_data = complete_data + data_queue.put(('all_new_text', all_new_text)) + data_queue.put(q_data) + last_data_time[0] = time.time() + if stop_event.is_set(): + break + if citation_buffer != "": + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": message.get("metadata", {}).get("model_slug"), + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join(citation_buffer) + }, + "finish_reason": None + } + ] + } + tmp = 'data: ' + json.dumps(new_data) + '\n\n' + # print(f"发送数据: {tmp}") + # 累积 new_text + all_new_text += citation_buffer + q_data = 'data: ' + json.dumps(new_data) + '\n\n' + data_queue.put(q_data) + last_data_time[0] = time.time() + if buffer: + # print(f"最后的数据: {buffer}") + # delete_conversation(conversation_id, api_key) + try: + buffer_json = json.loads(buffer) + logger.info(f"最后的缓存数据: {buffer_json}") + error_message = buffer_json.get("detail", {}).get("message", "未知错误") + error_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": "error", + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join("```\n" + error_message + "\n```") + }, + "finish_reason": None + } + ] + } + tmp = 'data: ' + json.dumps(error_data) + '\n\n' + logger.info(f"发送最后的数据: {tmp}") + # 累积 new_text + all_new_text += ''.join("```\n" + error_message + "\n```") + q_data = 'data: ' + json.dumps(error_data) + '\n\n' + data_queue.put(q_data) + last_data_time[0] = time.time() + complete_data = 'data: [DONE]\n\n' + logger.info(f"会话结束") + q_data = complete_data + data_queue.put(('all_new_text', all_new_text)) + data_queue.put(q_data) + last_data_time[0] = time.time() + except: + # print("JSON 解析错误") + logger.info(f"发送最后的数据: {buffer}") + error_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": "error", + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join("```\n" + buffer + "\n```") + }, + "finish_reason": None + } + ] + } + tmp = 'data: ' + json.dumps(error_data) + '\n\n' + q_data = tmp + data_queue.put(q_data) + last_data_time[0] = time.time() + complete_data = 'data: [DONE]\n\n' + logger.info(f"会话结束") + q_data = complete_data + data_queue.put(('all_new_text', all_new_text)) + data_queue.put(q_data) + last_data_time[0] = time.time() + except Exception as e: + logger.error(f"Exception: {e}") + complete_data = 'data: [DONE]\n\n' + logger.info(f"会话结束") + q_data = complete_data + data_queue.put(('all_new_text', all_new_text)) + data_queue.put(q_data) + last_data_time[0] = time.time() def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages): all_new_text = "" @@ -1715,7 +1901,7 @@ def register_websocket(api_key): } response = requests.post(url, headers=headers) response_json = response.json() - # logger.debug(f"register_websocket response: {response_json}") + logger.debug(f"register_websocket response: {response_json}") wss_url = response_json.get("wss_url", None) return wss_url