mirror of
https://github.com/Yanyutin753/RefreshToV1Api.git
synced 2025-10-14 23:22:42 +00:00
[fix] 修复了流式响应绘图时的错误,以及上游接口出错时内存溢出的问题
This commit is contained in:
51
main.py
51
main.py
@@ -197,9 +197,9 @@ CORS(app, resources={r"/images/*": {"origins": "*"}})
|
|||||||
# PANDORA_UPLOAD_URL = 'files.pandoranext.com'
|
# PANDORA_UPLOAD_URL = 'files.pandoranext.com'
|
||||||
|
|
||||||
|
|
||||||
VERSION = '0.7.0'
|
VERSION = '0.7.1'
|
||||||
# VERSION = 'test'
|
# VERSION = 'test'
|
||||||
UPDATE_INFO = '适配WSS输出'
|
UPDATE_INFO = '修复了流式响应绘图时的错误,以及上游接口出错时内存溢出的问题'
|
||||||
# UPDATE_INFO = '【仅供临时测试使用】 '
|
# UPDATE_INFO = '【仅供临时测试使用】 '
|
||||||
|
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
@@ -1578,12 +1578,13 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
|
|||||||
def on_message(ws, message):
|
def on_message(ws, message):
|
||||||
# logger.debug(f"on_message: {message}")
|
# logger.debug(f"on_message: {message}")
|
||||||
if stop_event.is_set():
|
if stop_event.is_set():
|
||||||
logger.info(f"接受到停止信号,停止数据处理线程")
|
logger.info(f"接受到停止信号,停止 Websocket 处理线程")
|
||||||
|
ws.close()
|
||||||
return
|
return
|
||||||
result_json = json.loads(message)
|
result_json = json.loads(message)
|
||||||
result_id = result_json.get('response_id', '')
|
result_id = result_json.get('response_id', '')
|
||||||
if result_id != context["response_id"]:
|
if result_id != context["response_id"]:
|
||||||
logger.info(f"response_id 不匹配,忽略")
|
logger.debug(f"response_id 不匹配,忽略")
|
||||||
return
|
return
|
||||||
body = result_json.get('body', '')
|
body = result_json.get('body', '')
|
||||||
# logger.debug("wss result: " + str(result_json))
|
# logger.debug("wss result: " + str(result_json))
|
||||||
@@ -1642,6 +1643,34 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
|
|||||||
|
|
||||||
logger.debug(f"end wss...")
|
logger.debug(f"end wss...")
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.error(f"注册 wss 失败")
|
||||||
|
complete_data = 'data: [DONE]\n\n'
|
||||||
|
timestamp = context["timestamp"]
|
||||||
|
|
||||||
|
new_data = {
|
||||||
|
"id": chat_message_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": timestamp,
|
||||||
|
"model": model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")
|
||||||
|
},
|
||||||
|
"finish_reason": None
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n'
|
||||||
|
data_queue.put(q_data)
|
||||||
|
|
||||||
|
q_data = complete_data
|
||||||
|
data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```"))
|
||||||
|
data_queue.put(q_data)
|
||||||
|
stop_event.set()
|
||||||
|
|
||||||
|
|
||||||
def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages):
|
def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages):
|
||||||
all_new_text = ""
|
all_new_text = ""
|
||||||
@@ -1671,12 +1700,11 @@ def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_i
|
|||||||
|
|
||||||
# 如果存在 wss_url,使用 WebSocket 连接获取数据
|
# 如果存在 wss_url,使用 WebSocket 连接获取数据
|
||||||
|
|
||||||
if wss_url:
|
process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages)
|
||||||
process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages)
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if stop_event.is_set():
|
if stop_event.is_set():
|
||||||
logger.info(f"接受到停止信号,停止数据处理线程")
|
logger.info(f"接受到停止信号,停止数据处理线程-外层")
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -1687,7 +1715,7 @@ def register_websocket(api_key):
|
|||||||
}
|
}
|
||||||
response = requests.post(url, headers=headers)
|
response = requests.post(url, headers=headers)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
logger.debug(f"register_websocket response: {response_json}")
|
# logger.debug(f"register_websocket response: {response_json}")
|
||||||
wss_url = response_json.get("wss_url", None)
|
wss_url = response_json.get("wss_url", None)
|
||||||
return wss_url
|
return wss_url
|
||||||
|
|
||||||
@@ -1786,6 +1814,8 @@ def chat_completions():
|
|||||||
|
|
||||||
# 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text
|
# 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text
|
||||||
all_new_text = ""
|
all_new_text = ""
|
||||||
|
image_urls = []
|
||||||
|
|
||||||
|
|
||||||
# 处理流式响应
|
# 处理流式响应
|
||||||
def generate():
|
def generate():
|
||||||
@@ -1821,6 +1851,9 @@ def chat_completions():
|
|||||||
# 更新 conversation_id
|
# 更新 conversation_id
|
||||||
conversation_id = data[1]
|
conversation_id = data[1]
|
||||||
# print(f"收到会话id: {conversation_id}")
|
# print(f"收到会话id: {conversation_id}")
|
||||||
|
elif isinstance(data, tuple) and data[0] == 'image_url':
|
||||||
|
# 更新 image_url
|
||||||
|
image_urls.append(data[1])
|
||||||
elif data == 'data: [DONE]\n\n':
|
elif data == 'data: [DONE]\n\n':
|
||||||
# 接收到结束信号,退出循环
|
# 接收到结束信号,退出循环
|
||||||
timestamp = int(time.time())
|
timestamp = int(time.time())
|
||||||
@@ -1845,9 +1878,11 @@ def chat_completions():
|
|||||||
yield data
|
yield data
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
# logger.debug(f"发出数据: {data}")
|
||||||
yield data
|
yield data
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
logger.debug(f"清理资源")
|
||||||
stop_event.set()
|
stop_event.set()
|
||||||
fetcher_thread.join()
|
fetcher_thread.join()
|
||||||
keep_alive_thread.join()
|
keep_alive_thread.join()
|
||||||
|
Reference in New Issue
Block a user