From 7bab63226257e97d4dc87627000d3070fa263be1 Mon Sep 17 00:00:00 2001 From: Wizerd Date: Fri, 2 Feb 2024 22:01:26 +0800 Subject: [PATCH] =?UTF-8?q?[fix]=20=E4=BF=AE=E5=A4=8D=E4=BA=86=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E5=93=8D=E5=BA=94=E7=BB=98=E5=9B=BE=E6=97=B6=E7=9A=84?= =?UTF-8?q?=E9=94=99=E8=AF=AF=EF=BC=8C=E4=BB=A5=E5=8F=8A=E4=B8=8A=E6=B8=B8?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E5=87=BA=E9=94=99=E6=97=B6=E5=86=85=E5=AD=98?= =?UTF-8?q?=E6=BA=A2=E5=87=BA=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 51 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index f7be398..ce381f9 100644 --- a/main.py +++ b/main.py @@ -197,9 +197,9 @@ CORS(app, resources={r"/images/*": {"origins": "*"}}) # PANDORA_UPLOAD_URL = 'files.pandoranext.com' -VERSION = '0.7.0' +VERSION = '0.7.1' # VERSION = 'test' -UPDATE_INFO = '适配WSS输出' +UPDATE_INFO = '修复了流式响应绘图时的错误,以及上游接口出错时内存溢出的问题' # UPDATE_INFO = '【仅供临时测试使用】 ' with app.app_context(): @@ -1578,12 +1578,13 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m def on_message(ws, message): # logger.debug(f"on_message: {message}") if stop_event.is_set(): - logger.info(f"接受到停止信号,停止数据处理线程") + logger.info(f"接受到停止信号,停止 Websocket 处理线程") + ws.close() return result_json = json.loads(message) result_id = result_json.get('response_id', '') if result_id != context["response_id"]: - logger.info(f"response_id 不匹配,忽略") + logger.debug(f"response_id 不匹配,忽略") return body = result_json.get('body', '') # logger.debug("wss result: " + str(result_json)) @@ -1642,6 +1643,34 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m logger.debug(f"end wss...") + else: + logger.error(f"注册 wss 失败") + complete_data = 'data: [DONE]\n\n' + timestamp = context["timestamp"] + + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```") + }, + "finish_reason": None + } + ] + } + q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + data_queue.put(q_data) + + q_data = complete_data + data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")) + data_queue.put(q_data) + stop_event.set() + def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages): all_new_text = "" @@ -1671,12 +1700,11 @@ def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_i # 如果存在 wss_url,使用 WebSocket 连接获取数据 - if wss_url: - process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages) + process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages) while True: if stop_event.is_set(): - logger.info(f"接受到停止信号,停止数据处理线程") + logger.info(f"接受到停止信号,停止数据处理线程-外层") break @@ -1687,7 +1715,7 @@ def register_websocket(api_key): } response = requests.post(url, headers=headers) response_json = response.json() - logger.debug(f"register_websocket response: {response_json}") + # logger.debug(f"register_websocket response: {response_json}") wss_url = response_json.get("wss_url", None) return wss_url @@ -1786,6 +1814,8 @@ def chat_completions(): # 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text all_new_text = "" + image_urls = [] + # 处理流式响应 def generate(): @@ -1821,6 +1851,9 @@ def chat_completions(): # 更新 conversation_id conversation_id = data[1] # print(f"收到会话id: {conversation_id}") + elif isinstance(data, tuple) and data[0] == 'image_url': + # 更新 image_url + image_urls.append(data[1]) elif data == 'data: [DONE]\n\n': # 接收到结束信号,退出循环 timestamp = int(time.time()) @@ -1845,9 +1878,11 @@ def chat_completions(): yield data break else: + # logger.debug(f"发出数据: {data}") yield data finally: + logger.debug(f"清理资源") stop_event.set() fetcher_thread.join() keep_alive_thread.join()