mirror of
https://github.com/Yanyutin753/RefreshToV1Api.git
synced 2025-10-16 08:10:21 +00:00
[feat] 同时支持wss和sse响应
This commit is contained in:
212
main.py
212
main.py
@@ -32,7 +32,7 @@ def load_config(file_path):
|
||||
|
||||
CONFIG = load_config('./data/config.json')
|
||||
|
||||
LOG_LEVEL = CONFIG.get('log_level', 'DEBUG').upper()
|
||||
LOG_LEVEL = CONFIG.get('log_level', 'INFO').upper()
|
||||
NEED_LOG_TO_FILE = CONFIG.get('need_log_to_file', 'true').lower() == 'true'
|
||||
|
||||
# 使用 get 方法获取配置项,同时提供默认值
|
||||
@@ -197,9 +197,9 @@ CORS(app, resources={r"/images/*": {"origins": "*"}})
|
||||
# PANDORA_UPLOAD_URL = 'files.pandoranext.com'
|
||||
|
||||
|
||||
VERSION = '0.7.1'
|
||||
VERSION = '0.7.2'
|
||||
# VERSION = 'test'
|
||||
UPDATE_INFO = '修复了流式响应绘图时的错误,以及上游接口出错时内存溢出的问题'
|
||||
UPDATE_INFO = '同时支持wss和sse响应'
|
||||
# UPDATE_INFO = '【仅供临时测试使用】 '
|
||||
|
||||
with app.app_context():
|
||||
@@ -1628,10 +1628,14 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
|
||||
ws.close()
|
||||
break
|
||||
upstream_response = send_text_prompt_and_get_response(messages, api_key, True, model)
|
||||
upstream_response_json = upstream_response.json()
|
||||
upstream_wss_url = upstream_response_json.get("wss_url", None)
|
||||
upstream_response_id = upstream_response_json.get("response_id", None)
|
||||
context["response_id"] = upstream_response_id
|
||||
upstream_wss_url = None
|
||||
try:
|
||||
upstream_response_json = upstream_response.json()
|
||||
upstream_wss_url = upstream_response_json.get("wss_url", None)
|
||||
upstream_response_id = upstream_response_json.get("response_id", None)
|
||||
context["response_id"] = upstream_response_id
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if upstream_wss_url is not None:
|
||||
logger.debug(f"start wss...")
|
||||
ws = websocket.WebSocketApp(wss_url,
|
||||
@@ -1642,9 +1646,8 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
|
||||
ws.run_forever()
|
||||
|
||||
logger.debug(f"end wss...")
|
||||
|
||||
else:
|
||||
logger.error(f"注册 wss 失败")
|
||||
elif upstream_response.status_code != 200:
|
||||
logger.error(f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}")
|
||||
complete_data = 'data: [DONE]\n\n'
|
||||
timestamp = context["timestamp"]
|
||||
|
||||
@@ -1657,7 +1660,7 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")
|
||||
"content": ''.join("```json\n{\n\"error\": \"Upstream error...\"\n}\n```")
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
@@ -1667,10 +1670,193 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
|
||||
data_queue.put(q_data)
|
||||
|
||||
q_data = complete_data
|
||||
data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```"))
|
||||
data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Upstream error...\"\n}\n```"))
|
||||
data_queue.put(q_data)
|
||||
stop_event.set()
|
||||
|
||||
else:
|
||||
logger.error(f"检测到非wss响应...")
|
||||
old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format)
|
||||
# complete_data = 'data: [DONE]\n\n'
|
||||
# timestamp = context["timestamp"]
|
||||
|
||||
# new_data = {
|
||||
# "id": chat_message_id,
|
||||
# "object": "chat.completion.chunk",
|
||||
# "created": timestamp,
|
||||
# "model": model,
|
||||
# "choices": [
|
||||
# {
|
||||
# "index": 0,
|
||||
# "delta": {
|
||||
# "content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")
|
||||
# },
|
||||
# "finish_reason": None
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n'
|
||||
# data_queue.put(q_data)
|
||||
|
||||
# q_data = complete_data
|
||||
# data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```"))
|
||||
# data_queue.put(q_data)
|
||||
# stop_event.set()
|
||||
|
||||
|
||||
def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format):
|
||||
all_new_text = ""
|
||||
|
||||
first_output = True
|
||||
|
||||
# 当前时间戳
|
||||
timestamp = int(time.time())
|
||||
|
||||
buffer = ""
|
||||
last_full_text = "" # 用于存储之前所有出现过的 parts 组成的完整文本
|
||||
last_full_code = ""
|
||||
last_full_code_result = ""
|
||||
last_content_type = None # 用于记录上一个消息的内容类型
|
||||
conversation_id = ''
|
||||
citation_buffer = ""
|
||||
citation_accumulating = False
|
||||
file_output_buffer = ""
|
||||
file_output_accumulating = False
|
||||
execution_output_image_url_buffer = ""
|
||||
execution_output_image_id_buffer = ""
|
||||
try:
|
||||
for chunk in upstream_response.iter_content(chunk_size=1024):
|
||||
if stop_event.is_set():
|
||||
logger.info(f"接受到停止信号,停止数据处理线程")
|
||||
break
|
||||
if chunk:
|
||||
buffer += chunk.decode('utf-8')
|
||||
# 检查是否存在 "event: ping",如果存在,则只保留 "data:" 后面的内容
|
||||
if "event: ping" in buffer:
|
||||
if "data:" in buffer:
|
||||
buffer = buffer.split("data:", 1)[1]
|
||||
buffer = "data:" + buffer
|
||||
# 使用正则表达式移除特定格式的字符串
|
||||
# print("应用正则表达式之前的 buffer:", buffer.replace('\n', '\\n'))
|
||||
buffer = re.sub(r'data: \d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{6}(\r\n|\r|\n){2}', '', buffer)
|
||||
# print("应用正则表达式之后的 buffer:", buffer.replace('\n', '\\n'))
|
||||
|
||||
|
||||
|
||||
while 'data:' in buffer and '\n\n' in buffer:
|
||||
end_index = buffer.index('\n\n') + 2
|
||||
complete_data, buffer = buffer[:end_index], buffer[end_index:]
|
||||
# 解析 data 块
|
||||
try:
|
||||
data_json = json.loads(complete_data.replace('data: ', ''))
|
||||
logger.debug(f"data_json: {data_json}")
|
||||
# print(f"data_json: {data_json}")
|
||||
all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, allow_id = process_data_json(data_json, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, timestamp, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, all_new_text)
|
||||
except json.JSONDecodeError:
|
||||
# print("JSON 解析错误")
|
||||
logger.info(f"发送数据: {complete_data}")
|
||||
if complete_data == 'data: [DONE]\n\n':
|
||||
logger.info(f"会话结束")
|
||||
q_data = complete_data
|
||||
data_queue.put(('all_new_text', all_new_text))
|
||||
data_queue.put(q_data)
|
||||
last_data_time[0] = time.time()
|
||||
if stop_event.is_set():
|
||||
break
|
||||
if citation_buffer != "":
|
||||
new_data = {
|
||||
"id": chat_message_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": timestamp,
|
||||
"model": message.get("metadata", {}).get("model_slug"),
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": ''.join(citation_buffer)
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
tmp = 'data: ' + json.dumps(new_data) + '\n\n'
|
||||
# print(f"发送数据: {tmp}")
|
||||
# 累积 new_text
|
||||
all_new_text += citation_buffer
|
||||
q_data = 'data: ' + json.dumps(new_data) + '\n\n'
|
||||
data_queue.put(q_data)
|
||||
last_data_time[0] = time.time()
|
||||
if buffer:
|
||||
# print(f"最后的数据: {buffer}")
|
||||
# delete_conversation(conversation_id, api_key)
|
||||
try:
|
||||
buffer_json = json.loads(buffer)
|
||||
logger.info(f"最后的缓存数据: {buffer_json}")
|
||||
error_message = buffer_json.get("detail", {}).get("message", "未知错误")
|
||||
error_data = {
|
||||
"id": chat_message_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": timestamp,
|
||||
"model": "error",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": ''.join("```\n" + error_message + "\n```")
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
tmp = 'data: ' + json.dumps(error_data) + '\n\n'
|
||||
logger.info(f"发送最后的数据: {tmp}")
|
||||
# 累积 new_text
|
||||
all_new_text += ''.join("```\n" + error_message + "\n```")
|
||||
q_data = 'data: ' + json.dumps(error_data) + '\n\n'
|
||||
data_queue.put(q_data)
|
||||
last_data_time[0] = time.time()
|
||||
complete_data = 'data: [DONE]\n\n'
|
||||
logger.info(f"会话结束")
|
||||
q_data = complete_data
|
||||
data_queue.put(('all_new_text', all_new_text))
|
||||
data_queue.put(q_data)
|
||||
last_data_time[0] = time.time()
|
||||
except:
|
||||
# print("JSON 解析错误")
|
||||
logger.info(f"发送最后的数据: {buffer}")
|
||||
error_data = {
|
||||
"id": chat_message_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": timestamp,
|
||||
"model": "error",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": ''.join("```\n" + buffer + "\n```")
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
tmp = 'data: ' + json.dumps(error_data) + '\n\n'
|
||||
q_data = tmp
|
||||
data_queue.put(q_data)
|
||||
last_data_time[0] = time.time()
|
||||
complete_data = 'data: [DONE]\n\n'
|
||||
logger.info(f"会话结束")
|
||||
q_data = complete_data
|
||||
data_queue.put(('all_new_text', all_new_text))
|
||||
data_queue.put(q_data)
|
||||
last_data_time[0] = time.time()
|
||||
except Exception as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
complete_data = 'data: [DONE]\n\n'
|
||||
logger.info(f"会话结束")
|
||||
q_data = complete_data
|
||||
data_queue.put(('all_new_text', all_new_text))
|
||||
data_queue.put(q_data)
|
||||
last_data_time[0] = time.time()
|
||||
|
||||
def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages):
|
||||
all_new_text = ""
|
||||
@@ -1715,7 +1901,7 @@ def register_websocket(api_key):
|
||||
}
|
||||
response = requests.post(url, headers=headers)
|
||||
response_json = response.json()
|
||||
# logger.debug(f"register_websocket response: {response_json}")
|
||||
logger.debug(f"register_websocket response: {response_json}")
|
||||
wss_url = response_json.get("wss_url", None)
|
||||
return wss_url
|
||||
|
||||
|
Reference in New Issue
Block a user