mirror of
https://github.com/Yanyutin753/RefreshToV1Api.git
synced 2025-10-17 00:24:13 +00:00
[feat] 同时支持wss和sse响应
This commit is contained in:
212
main.py
212
main.py
@@ -32,7 +32,7 @@ def load_config(file_path):
|
|||||||
|
|
||||||
CONFIG = load_config('./data/config.json')
|
CONFIG = load_config('./data/config.json')
|
||||||
|
|
||||||
LOG_LEVEL = CONFIG.get('log_level', 'DEBUG').upper()
|
LOG_LEVEL = CONFIG.get('log_level', 'INFO').upper()
|
||||||
NEED_LOG_TO_FILE = CONFIG.get('need_log_to_file', 'true').lower() == 'true'
|
NEED_LOG_TO_FILE = CONFIG.get('need_log_to_file', 'true').lower() == 'true'
|
||||||
|
|
||||||
# 使用 get 方法获取配置项,同时提供默认值
|
# 使用 get 方法获取配置项,同时提供默认值
|
||||||
@@ -197,9 +197,9 @@ CORS(app, resources={r"/images/*": {"origins": "*"}})
|
|||||||
# PANDORA_UPLOAD_URL = 'files.pandoranext.com'
|
# PANDORA_UPLOAD_URL = 'files.pandoranext.com'
|
||||||
|
|
||||||
|
|
||||||
VERSION = '0.7.1'
|
VERSION = '0.7.2'
|
||||||
# VERSION = 'test'
|
# VERSION = 'test'
|
||||||
UPDATE_INFO = '修复了流式响应绘图时的错误,以及上游接口出错时内存溢出的问题'
|
UPDATE_INFO = '同时支持wss和sse响应'
|
||||||
# UPDATE_INFO = '【仅供临时测试使用】 '
|
# UPDATE_INFO = '【仅供临时测试使用】 '
|
||||||
|
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
@@ -1628,10 +1628,14 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
|
|||||||
ws.close()
|
ws.close()
|
||||||
break
|
break
|
||||||
upstream_response = send_text_prompt_and_get_response(messages, api_key, True, model)
|
upstream_response = send_text_prompt_and_get_response(messages, api_key, True, model)
|
||||||
upstream_response_json = upstream_response.json()
|
upstream_wss_url = None
|
||||||
upstream_wss_url = upstream_response_json.get("wss_url", None)
|
try:
|
||||||
upstream_response_id = upstream_response_json.get("response_id", None)
|
upstream_response_json = upstream_response.json()
|
||||||
context["response_id"] = upstream_response_id
|
upstream_wss_url = upstream_response_json.get("wss_url", None)
|
||||||
|
upstream_response_id = upstream_response_json.get("response_id", None)
|
||||||
|
context["response_id"] = upstream_response_id
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
if upstream_wss_url is not None:
|
if upstream_wss_url is not None:
|
||||||
logger.debug(f"start wss...")
|
logger.debug(f"start wss...")
|
||||||
ws = websocket.WebSocketApp(wss_url,
|
ws = websocket.WebSocketApp(wss_url,
|
||||||
@@ -1642,9 +1646,8 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
|
|||||||
ws.run_forever()
|
ws.run_forever()
|
||||||
|
|
||||||
logger.debug(f"end wss...")
|
logger.debug(f"end wss...")
|
||||||
|
elif upstream_response.status_code != 200:
|
||||||
else:
|
logger.error(f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}")
|
||||||
logger.error(f"注册 wss 失败")
|
|
||||||
complete_data = 'data: [DONE]\n\n'
|
complete_data = 'data: [DONE]\n\n'
|
||||||
timestamp = context["timestamp"]
|
timestamp = context["timestamp"]
|
||||||
|
|
||||||
@@ -1657,7 +1660,7 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
|
|||||||
{
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"delta": {
|
"delta": {
|
||||||
"content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")
|
"content": ''.join("```json\n{\n\"error\": \"Upstream error...\"\n}\n```")
|
||||||
},
|
},
|
||||||
"finish_reason": None
|
"finish_reason": None
|
||||||
}
|
}
|
||||||
@@ -1667,10 +1670,193 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
|
|||||||
data_queue.put(q_data)
|
data_queue.put(q_data)
|
||||||
|
|
||||||
q_data = complete_data
|
q_data = complete_data
|
||||||
data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```"))
|
data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Upstream error...\"\n}\n```"))
|
||||||
data_queue.put(q_data)
|
data_queue.put(q_data)
|
||||||
stop_event.set()
|
stop_event.set()
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.error(f"检测到非wss响应...")
|
||||||
|
old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format)
|
||||||
|
# complete_data = 'data: [DONE]\n\n'
|
||||||
|
# timestamp = context["timestamp"]
|
||||||
|
|
||||||
|
# new_data = {
|
||||||
|
# "id": chat_message_id,
|
||||||
|
# "object": "chat.completion.chunk",
|
||||||
|
# "created": timestamp,
|
||||||
|
# "model": model,
|
||||||
|
# "choices": [
|
||||||
|
# {
|
||||||
|
# "index": 0,
|
||||||
|
# "delta": {
|
||||||
|
# "content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")
|
||||||
|
# },
|
||||||
|
# "finish_reason": None
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# }
|
||||||
|
# q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n'
|
||||||
|
# data_queue.put(q_data)
|
||||||
|
|
||||||
|
# q_data = complete_data
|
||||||
|
# data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```"))
|
||||||
|
# data_queue.put(q_data)
|
||||||
|
# stop_event.set()
|
||||||
|
|
||||||
|
|
||||||
|
def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format):
|
||||||
|
all_new_text = ""
|
||||||
|
|
||||||
|
first_output = True
|
||||||
|
|
||||||
|
# 当前时间戳
|
||||||
|
timestamp = int(time.time())
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
last_full_text = "" # 用于存储之前所有出现过的 parts 组成的完整文本
|
||||||
|
last_full_code = ""
|
||||||
|
last_full_code_result = ""
|
||||||
|
last_content_type = None # 用于记录上一个消息的内容类型
|
||||||
|
conversation_id = ''
|
||||||
|
citation_buffer = ""
|
||||||
|
citation_accumulating = False
|
||||||
|
file_output_buffer = ""
|
||||||
|
file_output_accumulating = False
|
||||||
|
execution_output_image_url_buffer = ""
|
||||||
|
execution_output_image_id_buffer = ""
|
||||||
|
try:
|
||||||
|
for chunk in upstream_response.iter_content(chunk_size=1024):
|
||||||
|
if stop_event.is_set():
|
||||||
|
logger.info(f"接受到停止信号,停止数据处理线程")
|
||||||
|
break
|
||||||
|
if chunk:
|
||||||
|
buffer += chunk.decode('utf-8')
|
||||||
|
# 检查是否存在 "event: ping",如果存在,则只保留 "data:" 后面的内容
|
||||||
|
if "event: ping" in buffer:
|
||||||
|
if "data:" in buffer:
|
||||||
|
buffer = buffer.split("data:", 1)[1]
|
||||||
|
buffer = "data:" + buffer
|
||||||
|
# 使用正则表达式移除特定格式的字符串
|
||||||
|
# print("应用正则表达式之前的 buffer:", buffer.replace('\n', '\\n'))
|
||||||
|
buffer = re.sub(r'data: \d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{6}(\r\n|\r|\n){2}', '', buffer)
|
||||||
|
# print("应用正则表达式之后的 buffer:", buffer.replace('\n', '\\n'))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
while 'data:' in buffer and '\n\n' in buffer:
|
||||||
|
end_index = buffer.index('\n\n') + 2
|
||||||
|
complete_data, buffer = buffer[:end_index], buffer[end_index:]
|
||||||
|
# 解析 data 块
|
||||||
|
try:
|
||||||
|
data_json = json.loads(complete_data.replace('data: ', ''))
|
||||||
|
logger.debug(f"data_json: {data_json}")
|
||||||
|
# print(f"data_json: {data_json}")
|
||||||
|
all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, allow_id = process_data_json(data_json, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, timestamp, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, all_new_text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# print("JSON 解析错误")
|
||||||
|
logger.info(f"发送数据: {complete_data}")
|
||||||
|
if complete_data == 'data: [DONE]\n\n':
|
||||||
|
logger.info(f"会话结束")
|
||||||
|
q_data = complete_data
|
||||||
|
data_queue.put(('all_new_text', all_new_text))
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
|
if stop_event.is_set():
|
||||||
|
break
|
||||||
|
if citation_buffer != "":
|
||||||
|
new_data = {
|
||||||
|
"id": chat_message_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": timestamp,
|
||||||
|
"model": message.get("metadata", {}).get("model_slug"),
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"content": ''.join(citation_buffer)
|
||||||
|
},
|
||||||
|
"finish_reason": None
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
tmp = 'data: ' + json.dumps(new_data) + '\n\n'
|
||||||
|
# print(f"发送数据: {tmp}")
|
||||||
|
# 累积 new_text
|
||||||
|
all_new_text += citation_buffer
|
||||||
|
q_data = 'data: ' + json.dumps(new_data) + '\n\n'
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
|
if buffer:
|
||||||
|
# print(f"最后的数据: {buffer}")
|
||||||
|
# delete_conversation(conversation_id, api_key)
|
||||||
|
try:
|
||||||
|
buffer_json = json.loads(buffer)
|
||||||
|
logger.info(f"最后的缓存数据: {buffer_json}")
|
||||||
|
error_message = buffer_json.get("detail", {}).get("message", "未知错误")
|
||||||
|
error_data = {
|
||||||
|
"id": chat_message_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": timestamp,
|
||||||
|
"model": "error",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"content": ''.join("```\n" + error_message + "\n```")
|
||||||
|
},
|
||||||
|
"finish_reason": None
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
tmp = 'data: ' + json.dumps(error_data) + '\n\n'
|
||||||
|
logger.info(f"发送最后的数据: {tmp}")
|
||||||
|
# 累积 new_text
|
||||||
|
all_new_text += ''.join("```\n" + error_message + "\n```")
|
||||||
|
q_data = 'data: ' + json.dumps(error_data) + '\n\n'
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
|
complete_data = 'data: [DONE]\n\n'
|
||||||
|
logger.info(f"会话结束")
|
||||||
|
q_data = complete_data
|
||||||
|
data_queue.put(('all_new_text', all_new_text))
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
|
except:
|
||||||
|
# print("JSON 解析错误")
|
||||||
|
logger.info(f"发送最后的数据: {buffer}")
|
||||||
|
error_data = {
|
||||||
|
"id": chat_message_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": timestamp,
|
||||||
|
"model": "error",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"content": ''.join("```\n" + buffer + "\n```")
|
||||||
|
},
|
||||||
|
"finish_reason": None
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
tmp = 'data: ' + json.dumps(error_data) + '\n\n'
|
||||||
|
q_data = tmp
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
|
complete_data = 'data: [DONE]\n\n'
|
||||||
|
logger.info(f"会话结束")
|
||||||
|
q_data = complete_data
|
||||||
|
data_queue.put(('all_new_text', all_new_text))
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception: {e}")
|
||||||
|
complete_data = 'data: [DONE]\n\n'
|
||||||
|
logger.info(f"会话结束")
|
||||||
|
q_data = complete_data
|
||||||
|
data_queue.put(('all_new_text', all_new_text))
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
|
|
||||||
def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages):
|
def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages):
|
||||||
all_new_text = ""
|
all_new_text = ""
|
||||||
@@ -1715,7 +1901,7 @@ def register_websocket(api_key):
|
|||||||
}
|
}
|
||||||
response = requests.post(url, headers=headers)
|
response = requests.post(url, headers=headers)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
# logger.debug(f"register_websocket response: {response_json}")
|
logger.debug(f"register_websocket response: {response_json}")
|
||||||
wss_url = response_json.get("wss_url", None)
|
wss_url = response_json.get("wss_url", None)
|
||||||
return wss_url
|
return wss_url
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user