[feat] 同时支持wss和sse响应

This commit is contained in:
Wizerd
2024-02-03 10:45:15 +08:00
parent e4fc4458a5
commit 5b61e0f2dd

212
main.py
View File

@@ -32,7 +32,7 @@ def load_config(file_path):
CONFIG = load_config('./data/config.json')
LOG_LEVEL = CONFIG.get('log_level', 'DEBUG').upper()
LOG_LEVEL = CONFIG.get('log_level', 'INFO').upper()
NEED_LOG_TO_FILE = CONFIG.get('need_log_to_file', 'true').lower() == 'true'
# 使用 get 方法获取配置项,同时提供默认值
@@ -197,9 +197,9 @@ CORS(app, resources={r"/images/*": {"origins": "*"}})
# PANDORA_UPLOAD_URL = 'files.pandoranext.com'
VERSION = '0.7.1'
VERSION = '0.7.2'
# VERSION = 'test'
UPDATE_INFO = '修复了流式响应绘图时的错误,以及上游接口出错时内存溢出的问题'
UPDATE_INFO = '同时支持wss和sse响应'
# UPDATE_INFO = '【仅供临时测试使用】 '
with app.app_context():
@@ -1628,10 +1628,14 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
ws.close()
break
upstream_response = send_text_prompt_and_get_response(messages, api_key, True, model)
upstream_response_json = upstream_response.json()
upstream_wss_url = upstream_response_json.get("wss_url", None)
upstream_response_id = upstream_response_json.get("response_id", None)
context["response_id"] = upstream_response_id
upstream_wss_url = None
try:
upstream_response_json = upstream_response.json()
upstream_wss_url = upstream_response_json.get("wss_url", None)
upstream_response_id = upstream_response_json.get("response_id", None)
context["response_id"] = upstream_response_id
except json.JSONDecodeError:
pass
if upstream_wss_url is not None:
logger.debug(f"start wss...")
ws = websocket.WebSocketApp(wss_url,
@@ -1642,9 +1646,8 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
ws.run_forever()
logger.debug(f"end wss...")
else:
logger.error(f"注册 wss 失败")
elif upstream_response.status_code != 200:
logger.error(f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}")
complete_data = 'data: [DONE]\n\n'
timestamp = context["timestamp"]
@@ -1657,7 +1660,7 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
{
"index": 0,
"delta": {
"content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")
"content": ''.join("```json\n{\n\"error\": \"Upstream error...\"\n}\n```")
},
"finish_reason": None
}
@@ -1667,10 +1670,193 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
data_queue.put(q_data)
q_data = complete_data
data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```"))
data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Upstream error...\"\n}\n```"))
data_queue.put(q_data)
stop_event.set()
else:
logger.error(f"检测到非wss响应...")
old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format)
# complete_data = 'data: [DONE]\n\n'
# timestamp = context["timestamp"]
# new_data = {
# "id": chat_message_id,
# "object": "chat.completion.chunk",
# "created": timestamp,
# "model": model,
# "choices": [
# {
# "index": 0,
# "delta": {
# "content": ''.join("```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```")
# },
# "finish_reason": None
# }
# ]
# }
# q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n'
# data_queue.put(q_data)
# q_data = complete_data
# data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Websocket register fail...\"\n}\n```"))
# data_queue.put(q_data)
# stop_event.set()
def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format):
all_new_text = ""
first_output = True
# 当前时间戳
timestamp = int(time.time())
buffer = ""
last_full_text = "" # 用于存储之前所有出现过的 parts 组成的完整文本
last_full_code = ""
last_full_code_result = ""
last_content_type = None # 用于记录上一个消息的内容类型
conversation_id = ''
citation_buffer = ""
citation_accumulating = False
file_output_buffer = ""
file_output_accumulating = False
execution_output_image_url_buffer = ""
execution_output_image_id_buffer = ""
try:
for chunk in upstream_response.iter_content(chunk_size=1024):
if stop_event.is_set():
logger.info(f"接受到停止信号,停止数据处理线程")
break
if chunk:
buffer += chunk.decode('utf-8')
# 检查是否存在 "event: ping",如果存在,则只保留 "data:" 后面的内容
if "event: ping" in buffer:
if "data:" in buffer:
buffer = buffer.split("data:", 1)[1]
buffer = "data:" + buffer
# 使用正则表达式移除特定格式的字符串
# print("应用正则表达式之前的 buffer:", buffer.replace('\n', '\\n'))
buffer = re.sub(r'data: \d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{6}(\r\n|\r|\n){2}', '', buffer)
# print("应用正则表达式之后的 buffer:", buffer.replace('\n', '\\n'))
while 'data:' in buffer and '\n\n' in buffer:
end_index = buffer.index('\n\n') + 2
complete_data, buffer = buffer[:end_index], buffer[end_index:]
# 解析 data 块
try:
data_json = json.loads(complete_data.replace('data: ', ''))
logger.debug(f"data_json: {data_json}")
# print(f"data_json: {data_json}")
all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, allow_id = process_data_json(data_json, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, timestamp, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, all_new_text)
except json.JSONDecodeError:
# print("JSON 解析错误")
logger.info(f"发送数据: {complete_data}")
if complete_data == 'data: [DONE]\n\n':
logger.info(f"会话结束")
q_data = complete_data
data_queue.put(('all_new_text', all_new_text))
data_queue.put(q_data)
last_data_time[0] = time.time()
if stop_event.is_set():
break
if citation_buffer != "":
new_data = {
"id": chat_message_id,
"object": "chat.completion.chunk",
"created": timestamp,
"model": message.get("metadata", {}).get("model_slug"),
"choices": [
{
"index": 0,
"delta": {
"content": ''.join(citation_buffer)
},
"finish_reason": None
}
]
}
tmp = 'data: ' + json.dumps(new_data) + '\n\n'
# print(f"发送数据: {tmp}")
# 累积 new_text
all_new_text += citation_buffer
q_data = 'data: ' + json.dumps(new_data) + '\n\n'
data_queue.put(q_data)
last_data_time[0] = time.time()
if buffer:
# print(f"最后的数据: {buffer}")
# delete_conversation(conversation_id, api_key)
try:
buffer_json = json.loads(buffer)
logger.info(f"最后的缓存数据: {buffer_json}")
error_message = buffer_json.get("detail", {}).get("message", "未知错误")
error_data = {
"id": chat_message_id,
"object": "chat.completion.chunk",
"created": timestamp,
"model": "error",
"choices": [
{
"index": 0,
"delta": {
"content": ''.join("```\n" + error_message + "\n```")
},
"finish_reason": None
}
]
}
tmp = 'data: ' + json.dumps(error_data) + '\n\n'
logger.info(f"发送最后的数据: {tmp}")
# 累积 new_text
all_new_text += ''.join("```\n" + error_message + "\n```")
q_data = 'data: ' + json.dumps(error_data) + '\n\n'
data_queue.put(q_data)
last_data_time[0] = time.time()
complete_data = 'data: [DONE]\n\n'
logger.info(f"会话结束")
q_data = complete_data
data_queue.put(('all_new_text', all_new_text))
data_queue.put(q_data)
last_data_time[0] = time.time()
except:
# print("JSON 解析错误")
logger.info(f"发送最后的数据: {buffer}")
error_data = {
"id": chat_message_id,
"object": "chat.completion.chunk",
"created": timestamp,
"model": "error",
"choices": [
{
"index": 0,
"delta": {
"content": ''.join("```\n" + buffer + "\n```")
},
"finish_reason": None
}
]
}
tmp = 'data: ' + json.dumps(error_data) + '\n\n'
q_data = tmp
data_queue.put(q_data)
last_data_time[0] = time.time()
complete_data = 'data: [DONE]\n\n'
logger.info(f"会话结束")
q_data = complete_data
data_queue.put(('all_new_text', all_new_text))
data_queue.put(q_data)
last_data_time[0] = time.time()
except Exception as e:
logger.error(f"Exception: {e}")
complete_data = 'data: [DONE]\n\n'
logger.info(f"会话结束")
q_data = complete_data
data_queue.put(('all_new_text', all_new_text))
data_queue.put(q_data)
last_data_time[0] = time.time()
def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages):
all_new_text = ""
@@ -1715,7 +1901,7 @@ def register_websocket(api_key):
}
response = requests.post(url, headers=headers)
response_json = response.json()
# logger.debug(f"register_websocket response: {response_json}")
logger.debug(f"register_websocket response: {response_json}")
wss_url = response_json.get("wss_url", None)
return wss_url