[fix] 修复了流式响应绘图时的错误,以及上游接口出错时内存溢出的问题

This commit is contained in:
Wizerd
2024-02-02 22:01:26 +08:00
parent f314ce9506
commit 7bab632262

49
main.py
View File

@@ -197,9 +197,9 @@ CORS(app, resources={r"/images/*": {"origins": "*"}})
# PANDORA_UPLOAD_URL = 'files.pandoranext.com' # PANDORA_UPLOAD_URL = 'files.pandoranext.com'
VERSION = '0.7.0' VERSION = '0.7.1'
# VERSION = 'test' # VERSION = 'test'
UPDATE_INFO = '适配WSS输出' UPDATE_INFO = '修复了流式响应绘图时的错误,以及上游接口出错时内存溢出的问题'
# UPDATE_INFO = '【仅供临时测试使用】 ' # UPDATE_INFO = '【仅供临时测试使用】 '
with app.app_context(): with app.app_context():
@@ -1578,12 +1578,13 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
def on_message(ws, message): def on_message(ws, message):
# logger.debug(f"on_message: {message}") # logger.debug(f"on_message: {message}")
if stop_event.is_set(): if stop_event.is_set():
logger.info(f"接受到停止信号,停止数据处理线程") logger.info(f"接受到停止信号,停止 Websocket 处理线程")
ws.close()
return return
result_json = json.loads(message) result_json = json.loads(message)
result_id = result_json.get('response_id', '') result_id = result_json.get('response_id', '')
if result_id != context["response_id"]: if result_id != context["response_id"]:
logger.info(f"response_id 不匹配,忽略") logger.debug(f"response_id 不匹配,忽略")
return return
body = result_json.get('body', '') body = result_json.get('body', '')
# logger.debug("wss result: " + str(result_json)) # logger.debug("wss result: " + str(result_json))
@@ -1642,6 +1643,34 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
logger.debug(f"end wss...") logger.debug(f"end wss...")
else:
logger.error(f"注册 wss 失败")
complete_data = 'data: [DONE]\n\n'
timestamp = context["timestamp"]
new_data = {
"id": chat_message_id,
"object": "chat.completion.chunk",
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")
},
"finish_reason": None
}
]
}
q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n'
data_queue.put(q_data)
q_data = complete_data
data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```"))
data_queue.put(q_data)
stop_event.set()
def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages): def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages):
all_new_text = "" all_new_text = ""
@@ -1671,12 +1700,11 @@ def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_i
# 如果存在 wss_url使用 WebSocket 连接获取数据 # 如果存在 wss_url使用 WebSocket 连接获取数据
if wss_url:
process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages) process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages)
while True: while True:
if stop_event.is_set(): if stop_event.is_set():
logger.info(f"接受到停止信号,停止数据处理线程") logger.info(f"接受到停止信号,停止数据处理线程-外层")
break break
@@ -1687,7 +1715,7 @@ def register_websocket(api_key):
} }
response = requests.post(url, headers=headers) response = requests.post(url, headers=headers)
response_json = response.json() response_json = response.json()
logger.debug(f"register_websocket response: {response_json}") # logger.debug(f"register_websocket response: {response_json}")
wss_url = response_json.get("wss_url", None) wss_url = response_json.get("wss_url", None)
return wss_url return wss_url
@@ -1786,6 +1814,8 @@ def chat_completions():
# 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text # 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text
all_new_text = "" all_new_text = ""
image_urls = []
# 处理流式响应 # 处理流式响应
def generate(): def generate():
@@ -1821,6 +1851,9 @@ def chat_completions():
# 更新 conversation_id # 更新 conversation_id
conversation_id = data[1] conversation_id = data[1]
# print(f"收到会话id: {conversation_id}") # print(f"收到会话id: {conversation_id}")
elif isinstance(data, tuple) and data[0] == 'image_url':
# 更新 image_url
image_urls.append(data[1])
elif data == 'data: [DONE]\n\n': elif data == 'data: [DONE]\n\n':
# 接收到结束信号,退出循环 # 接收到结束信号,退出循环
timestamp = int(time.time()) timestamp = int(time.time())
@@ -1845,9 +1878,11 @@ def chat_completions():
yield data yield data
break break
else: else:
# logger.debug(f"发出数据: {data}")
yield data yield data
finally: finally:
logger.debug(f"清理资源")
stop_event.set() stop_event.set()
fetcher_thread.join() fetcher_thread.join()
keep_alive_thread.join() keep_alive_thread.join()