From d6583f89439c6231e5333f66dd2756866e3eaa53 Mon Sep 17 00:00:00 2001 From: Wizerd Date: Sat, 3 Feb 2024 11:18:18 +0800 Subject: [PATCH] =?UTF-8?q?[fix]=20=E4=BC=98=E5=8C=96=E6=A3=80=E6=B5=8B?= =?UTF-8?q?=E6=98=AF=E5=90=A6=E4=B8=BAsse=E7=9A=84=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 120 +++++++++++++++++++++++--------------------------------- 1 file changed, 50 insertions(+), 70 deletions(-) diff --git a/main.py b/main.py index d35e3ca..8414cc7 100644 --- a/main.py +++ b/main.py @@ -1629,79 +1629,59 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m break upstream_response = send_text_prompt_and_get_response(messages, api_key, True, model) upstream_wss_url = None - try: - upstream_response_json = upstream_response.json() - upstream_wss_url = upstream_response_json.get("wss_url", None) - upstream_response_id = upstream_response_json.get("response_id", None) - context["response_id"] = upstream_response_id - except json.JSONDecodeError: - pass - if upstream_wss_url is not None: - logger.debug(f"start wss...") - ws = websocket.WebSocketApp(wss_url, - on_message = on_message, - on_error = on_error, - on_close = on_close) - ws.on_open = on_open - ws.run_forever() - - logger.debug(f"end wss...") - elif upstream_response.status_code != 200: - logger.error(f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}") - complete_data = 'data: [DONE]\n\n' - timestamp = context["timestamp"] - - new_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join("```json\n{\n\"error\": \"Upstream error...\"\n}\n```") - }, - "finish_reason": None - } - ] - } - q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - data_queue.put(q_data) - - q_data = complete_data - data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Upstream error...\"\n}\n```")) - data_queue.put(q_data) - stop_event.set() - - else: - logger.error(f"检测到非wss响应...") + # 检查 Content-Type 是否为 SSE 响应 + content_type = upstream_response.headers.get('Content-Type') + # 判断content_type是否包含'text/event-stream' + if content_type and 'text/event-stream' in content_type: + logger.debug("上游响应为 SSE 响应") old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format) - # complete_data = 'data: [DONE]\n\n' - # timestamp = context["timestamp"] + else: + if upstream_response.status_code != 200: + logger.error(f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}") + complete_data = 'data: [DONE]\n\n' + timestamp = context["timestamp"] - # new_data = { - # "id": chat_message_id, - # "object": "chat.completion.chunk", - # "created": timestamp, - # "model": model, - # "choices": [ - # { - # "index": 0, - # "delta": { - # "content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```") - # }, - # "finish_reason": None - # } - # ] - # } - # q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - # data_queue.put(q_data) + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join("```json\n{\n\"error\": \"Upstream error...\"\n}\n```") + }, + "finish_reason": None + } + ] + } + q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + data_queue.put(q_data) + + q_data = complete_data + data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Upstream error...\"\n}\n```")) + data_queue.put(q_data) + stop_event.set() + return + try: + upstream_response_json = upstream_response.json() + upstream_wss_url = upstream_response_json.get("wss_url", None) + upstream_response_id = upstream_response_json.get("response_id", None) + context["response_id"] = upstream_response_id + except json.JSONDecodeError: + pass + if upstream_wss_url is not None: + logger.debug(f"start wss...") + ws = websocket.WebSocketApp(wss_url, + on_message = on_message, + on_error = on_error, + on_close = on_close) + ws.on_open = on_open + ws.run_forever() + + logger.debug(f"end wss...") - # q_data = complete_data - # data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")) - # data_queue.put(q_data) - # stop_event.set() def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format):