diff --git a/main.py b/main.py index f7be398..ce381f9 100644 --- a/main.py +++ b/main.py @@ -197,9 +197,9 @@ CORS(app, resources={r"/images/*": {"origins": "*"}}) # PANDORA_UPLOAD_URL = 'files.pandoranext.com' -VERSION = '0.7.0' +VERSION = '0.7.1' # VERSION = 'test' -UPDATE_INFO = '适配WSS输出' +UPDATE_INFO = '修复了流式响应绘图时的错误,以及上游接口出错时内存溢出的问题' # UPDATE_INFO = '【仅供临时测试使用】 ' with app.app_context(): @@ -1578,12 +1578,13 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m def on_message(ws, message): # logger.debug(f"on_message: {message}") if stop_event.is_set(): - logger.info(f"接受到停止信号,停止数据处理线程") + logger.info(f"接受到停止信号,停止 Websocket 处理线程") + ws.close() return result_json = json.loads(message) result_id = result_json.get('response_id', '') if result_id != context["response_id"]: - logger.info(f"response_id 不匹配,忽略") + logger.debug(f"response_id 不匹配,忽略") return body = result_json.get('body', '') # logger.debug("wss result: " + str(result_json)) @@ -1642,6 +1643,34 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m logger.debug(f"end wss...") + else: + logger.error(f"注册 wss 失败") + complete_data = 'data: [DONE]\n\n' + timestamp = context["timestamp"] + + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```") + }, + "finish_reason": None + } + ] + } + q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + data_queue.put(q_data) + + q_data = complete_data + data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")) + data_queue.put(q_data) + stop_event.set() + def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages): all_new_text = "" @@ -1671,12 +1700,11 @@ def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_i # 如果存在 wss_url,使用 WebSocket 连接获取数据 - if wss_url: - process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages) + process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages) while True: if stop_event.is_set(): - logger.info(f"接受到停止信号,停止数据处理线程") + logger.info(f"接受到停止信号,停止数据处理线程-外层") break @@ -1687,7 +1715,7 @@ def register_websocket(api_key): } response = requests.post(url, headers=headers) response_json = response.json() - logger.debug(f"register_websocket response: {response_json}") + # logger.debug(f"register_websocket response: {response_json}") wss_url = response_json.get("wss_url", None) return wss_url @@ -1786,6 +1814,8 @@ def chat_completions(): # 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text all_new_text = "" + image_urls = [] + # 处理流式响应 def generate(): @@ -1821,6 +1851,9 @@ def chat_completions(): # 更新 conversation_id conversation_id = data[1] # print(f"收到会话id: {conversation_id}") + elif isinstance(data, tuple) and data[0] == 'image_url': + # 更新 image_url + image_urls.append(data[1]) elif data == 'data: [DONE]\n\n': # 接收到结束信号,退出循环 timestamp = int(time.time()) @@ -1845,9 +1878,11 @@ def chat_completions(): yield data break else: + # logger.debug(f"发出数据: {data}") yield data finally: + logger.debug(f"清理资源") stop_event.set() fetcher_thread.join() keep_alive_thread.join()