[fix] 优化检测是否为sse的方式

This commit is contained in:
Wizerd
2024-02-03 11:18:18 +08:00
parent 5b61e0f2dd
commit d6583f8943

120
main.py
View File

@@ -1629,79 +1629,59 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
break
upstream_response = send_text_prompt_and_get_response(messages, api_key, True, model)
upstream_wss_url = None
try:
upstream_response_json = upstream_response.json()
upstream_wss_url = upstream_response_json.get("wss_url", None)
upstream_response_id = upstream_response_json.get("response_id", None)
context["response_id"] = upstream_response_id
except json.JSONDecodeError:
pass
if upstream_wss_url is not None:
logger.debug(f"start wss...")
ws = websocket.WebSocketApp(wss_url,
on_message = on_message,
on_error = on_error,
on_close = on_close)
ws.on_open = on_open
ws.run_forever()
logger.debug(f"end wss...")
elif upstream_response.status_code != 200:
logger.error(f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}")
complete_data = 'data: [DONE]\n\n'
timestamp = context["timestamp"]
new_data = {
"id": chat_message_id,
"object": "chat.completion.chunk",
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": ''.join("```json\n{\n\"error\": \"Upstream error...\"\n}\n```")
},
"finish_reason": None
}
]
}
q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n'
data_queue.put(q_data)
q_data = complete_data
data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Upstream error...\"\n}\n```"))
data_queue.put(q_data)
stop_event.set()
else:
logger.error(f"检测到非wss响应...")
# 检查 Content-Type 是否为 SSE 响应
content_type = upstream_response.headers.get('Content-Type')
# 判断content_type是否包含'text/event-stream'
if content_type and 'text/event-stream' in content_type:
logger.debug("上游响应为 SSE 响应")
old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format)
# complete_data = 'data: [DONE]\n\n'
# timestamp = context["timestamp"]
else:
if upstream_response.status_code != 200:
logger.error(f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}")
complete_data = 'data: [DONE]\n\n'
timestamp = context["timestamp"]
# new_data = {
# "id": chat_message_id,
# "object": "chat.completion.chunk",
# "created": timestamp,
# "model": model,
# "choices": [
# {
# "index": 0,
# "delta": {
# "content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")
# },
# "finish_reason": None
# }
# ]
# }
# q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n'
# data_queue.put(q_data)
new_data = {
"id": chat_message_id,
"object": "chat.completion.chunk",
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": ''.join("```json\n{\n\"error\": \"Upstream error...\"\n}\n```")
},
"finish_reason": None
}
]
}
q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n'
data_queue.put(q_data)
q_data = complete_data
data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Upstream error...\"\n}\n```"))
data_queue.put(q_data)
stop_event.set()
return
try:
upstream_response_json = upstream_response.json()
upstream_wss_url = upstream_response_json.get("wss_url", None)
upstream_response_id = upstream_response_json.get("response_id", None)
context["response_id"] = upstream_response_id
except json.JSONDecodeError:
pass
if upstream_wss_url is not None:
logger.debug(f"start wss...")
ws = websocket.WebSocketApp(wss_url,
on_message = on_message,
on_error = on_error,
on_close = on_close)
ws.on_open = on_open
ws.run_forever()
logger.debug(f"end wss...")
# q_data = complete_data
# data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```"))
# data_queue.put(q_data)
# stop_event.set()
def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format):