[Fix] 重构代码,优化输出稳定性,增加接口保活输出

This commit is contained in:
Wizerd
2023-12-19 18:04:51 +08:00
parent 25177392c7
commit e0256d6607
2 changed files with 401 additions and 294 deletions

View File

@@ -20,10 +20,12 @@
- [ ] 支持 日志等级划分 - [ ] 支持 日志等级划分
- [ ] 支持 接口保活 - [x] 支持 接口保活
- [ ] 支持 自定义接口前缀 - [ ] 支持 自定义接口前缀
- [ ] 优化 偶现的【0†source】引用bug
## 注意 ## 注意
> [!CAUTION] > [!CAUTION]

195
main.py
View File

@@ -10,6 +10,8 @@ from datetime import datetime
from PIL import Image from PIL import Image
import io import io
import re import re
import threading
from queue import Queue, Empty
def generate_unique_id(prefix): def generate_unique_id(prefix):
# 生成一个随机的 UUID # 生成一个随机的 UUID
@@ -103,9 +105,9 @@ PROXY_API_PREFIX = os.getenv('PROXY_API_PREFIX', '')
UPLOAD_BASE_URL = os.getenv('UPLOAD_BASE_URL', '') UPLOAD_BASE_URL = os.getenv('UPLOAD_BASE_URL', '')
KEY_FOR_GPTS_INFO = os.getenv('KEY_FOR_GPTS_INFO', '') KEY_FOR_GPTS_INFO = os.getenv('KEY_FOR_GPTS_INFO', '')
VERSION = '0.1.6' VERSION = '0.1.7'
# VERSION = 'test' # VERSION = 'test'
UPDATE_INFO = '对于自带的三个模型,支持多个名字映射到同一个模型' UPDATE_INFO = '重构代码,优化输出稳定性'
# UPDATE_INFO = '【仅供临时测试使用】 ' # UPDATE_INFO = '【仅供临时测试使用】 '
with app.app_context(): with app.app_context():
@@ -382,35 +384,9 @@ def replace_complete_citation(text, citations):
return replaced_text, remaining_text, is_potential_citation return replaced_text, remaining_text, is_potential_citation
# 定义 Flask 路由 def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model):
@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
print(f"[{datetime.now()}] New Request")
data = request.json
messages = data.get('messages')
model = data.get('model')
accessible_model_list = get_accessible_model_list()
if model not in accessible_model_list:
return jsonify({"error": "model is not accessible"}), 401
stream = data.get('stream', False)
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer '):
return jsonify({"error": "Authorization header is missing or invalid"}), 401
api_key = auth_header.split(' ')[1]
print(f"api_key: {api_key}")
upstream_response = send_text_prompt_and_get_response(messages, api_key, stream, model)
# 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text
all_new_text = "" all_new_text = ""
# 处理流式响应
def generate():
nonlocal all_new_text # 引用外部变量
chat_message_id = generate_unique_id("chatcmpl")
# 当前时间戳 # 当前时间戳
timestamp = int(time.time()) timestamp = int(time.time())
@@ -453,8 +429,8 @@ def chat_completions():
content = message.get("content", {}) content = message.get("content", {})
role = message.get("author", {}).get("role") role = message.get("author", {}).get("role")
content_type = content.get("content_type") content_type = content.get("content_type")
print(f"content_type: {content_type}") # print(f"content_type: {content_type}")
print(f"last_content_type: {last_content_type}") # print(f"last_content_type: {last_content_type}")
metadata = {} metadata = {}
citations = [] citations = []
@@ -469,7 +445,9 @@ def chat_completions():
continue continue
try: try:
conversation_id = data_json.get("conversation_id") conversation_id = data_json.get("conversation_id")
print(f"conversation_id: {conversation_id}") # print(f"conversation_id: {conversation_id}")
if conversation_id:
data_queue.put(('conversation_id', conversation_id))
except: except:
pass pass
# 只获取新的部分 # 只获取新的部分
@@ -557,13 +535,13 @@ def chat_completions():
if "\u3010" in new_text and not citation_accumulating: if "\u3010" in new_text and not citation_accumulating:
citation_accumulating = True citation_accumulating = True
citation_buffer = citation_buffer + new_text citation_buffer = citation_buffer + new_text
print(f"开始积累引用: {citation_buffer}") # print(f"开始积累引用: {citation_buffer}")
elif citation_accumulating: elif citation_accumulating:
citation_buffer += new_text citation_buffer += new_text
print(f"积累引用: {citation_buffer}") # print(f"积累引用: {citation_buffer}")
if citation_accumulating: if citation_accumulating:
if is_valid_citation_format(citation_buffer): if is_valid_citation_format(citation_buffer):
print(f"合法格式: {citation_buffer}") # print(f"合法格式: {citation_buffer}")
# 继续积累 # 继续积累
if is_complete_citation_format(citation_buffer): if is_complete_citation_format(citation_buffer):
@@ -577,12 +555,12 @@ def chat_completions():
else: else:
citation_accumulating = False citation_accumulating = False
citation_buffer = "" citation_buffer = ""
print(f"替换完整的引用格式: {new_text}") # print(f"替换完整的引用格式: {new_text}")
else: else:
continue continue
else: else:
# 不是合法格式,放弃积累并响应 # 不是合法格式,放弃积累并响应
print(f"不合法格式: {citation_buffer}") # print(f"不合法格式: {citation_buffer}")
new_text = citation_buffer new_text = citation_buffer
citation_accumulating = False citation_accumulating = False
citation_buffer = "" citation_buffer = ""
@@ -627,12 +605,12 @@ def chat_completions():
if content_type != None: if content_type != None:
last_content_type = content_type if role != "user" else last_content_type last_content_type = content_type if role != "user" else last_content_type
model_slug = message.get("metadata", {}).get("model_slug") or model
new_data = { new_data = {
"id": chat_message_id, "id": chat_message_id,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"created": timestamp, "created": timestamp,
"model": message.get("metadata", {}).get("model_slug"), "model": model_slug,
"choices": [ "choices": [
{ {
"index": 0, "index": 0,
@@ -649,13 +627,22 @@ def chat_completions():
# print(f"[{datetime.now()}] 发送数据: {tmp}") # print(f"[{datetime.now()}] 发送数据: {tmp}")
# 累积 new_text # 累积 new_text
all_new_text += new_text all_new_text += new_text
yield 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n'
data_queue.put(q_data)
last_data_time[0] = time.time()
if stop_event.is_set():
break
except json.JSONDecodeError: except json.JSONDecodeError:
# print("JSON 解析错误") # print("JSON 解析错误")
print(f"[{datetime.now()}] 发送数据: {complete_data}") print(f"[{datetime.now()}] 发送数据: {complete_data}")
if complete_data == 'data: [DONE]\n\n': if complete_data == 'data: [DONE]\n\n':
print(f"[{datetime.now()}] 会话结束") print(f"[{datetime.now()}] 会话结束")
yield complete_data q_data = complete_data
data_queue.put(('all_new_text', all_new_text))
data_queue.put(q_data)
last_data_time[0] = time.time()
if stop_event.is_set():
break
if citation_buffer != "": if citation_buffer != "":
new_data = { new_data = {
"id": chat_message_id, "id": chat_message_id,
@@ -676,10 +663,12 @@ def chat_completions():
# print(f"[{datetime.now()}] 发送数据: {tmp}") # print(f"[{datetime.now()}] 发送数据: {tmp}")
# 累积 new_text # 累积 new_text
all_new_text += citation_buffer all_new_text += citation_buffer
yield 'data: ' + json.dumps(new_data) + '\n\n' q_data = 'data: ' + json.dumps(new_data) + '\n\n'
data_queue.put(q_data)
last_data_time[0] = time.time()
if buffer: if buffer:
# print(f"[{datetime.now()}] 最后的数据: {buffer}") # print(f"[{datetime.now()}] 最后的数据: {buffer}")
delete_conversation(conversation_id, api_key) # delete_conversation(conversation_id, api_key)
try: try:
buffer_json = json.loads(buffer) buffer_json = json.loads(buffer)
error_message = buffer_json.get("detail", {}).get("message", "未知错误") error_message = buffer_json.get("detail", {}).get("message", "未知错误")
@@ -702,12 +691,128 @@ def chat_completions():
print(f"[{datetime.now()}] 发送最后的数据: {tmp}") print(f"[{datetime.now()}] 发送最后的数据: {tmp}")
# 累积 new_text # 累积 new_text
all_new_text += ''.join("```\n" + error_message + "\n```") all_new_text += ''.join("```\n" + error_message + "\n```")
yield 'data: ' + json.dumps(error_data) + '\n\n' q_data = 'data: ' + json.dumps(error_data) + '\n\n'
data_queue.put(q_data)
last_data_time[0] = time.time()
complete_data = 'data: [DONE]\n\n'
print(f"[{datetime.now()}] 会话结束")
q_data = complete_data
data_queue.put(('all_new_text', all_new_text))
data_queue.put(q_data)
last_data_time[0] = time.time()
except: except:
# print("JSON 解析错误") # print("JSON 解析错误")
print(f"[{datetime.now()}] 发送最后的数据: {buffer}") print(f"[{datetime.now()}] 发送最后的数据: {buffer}")
yield buffer q_data = buffer
data_queue.put(q_data)
last_data_time[0] = time.time()
complete_data = 'data: [DONE]\n\n'
print(f"[{datetime.now()}] 会话结束")
q_data = complete_data
data_queue.put(('all_new_text', all_new_text))
data_queue.put(q_data)
last_data_time[0] = time.time()
def keep_alive(last_data_time, stop_event, queue, model, chat_message_id):
    """Push empty keep-alive chunks onto the queue while the stream is idle.

    Runs until ``stop_event`` is set. Roughly once per second it checks how
    long ago data was last emitted; if at least one second has passed, it
    enqueues a ``chat.completion.chunk`` SSE message with an empty ``content``
    delta and refreshes the shared timestamp.

    Args:
        last_data_time: Single-element list whose item 0 is the ``time.time()``
            value of the most recent emission. It is read and updated in
            place, so the caller sees the refreshed timestamp.
        stop_event: ``threading.Event`` that tells this loop to terminate.
        queue: Queue-like object; keep-alive messages are added with ``put``.
        model: Model name reported in the ``model`` field of each chunk.
        chat_message_id: Identifier reported in the ``id`` field of each chunk.
    """
    while not stop_event.is_set():
        if time.time() - last_data_time[0] >= 1:
            print(f"[{datetime.now()}] 发送保活消息")
            # Current timestamp
            timestamp = int(time.time())
            new_data = {
                "id": chat_message_id,
                "object": "chat.completion.chunk",
                "created": timestamp,
                "model": model,
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "content": ''
                        },
                        "finish_reason": None
                    }
                ]
            }
            queue.put(f'data: {json.dumps(new_data)}\n\n')  # Send the keep-alive message
            last_data_time[0] = time.time()
        # Wait on the event instead of time.sleep(1): the pause is still one
        # second while running, but the thread wakes up immediately once
        # stop_event is set, so a join() on this thread does not stall.
        stop_event.wait(1)
import threading
import time
# 定义 Flask 路由
@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
print(f"[{datetime.now()}] New Request")
data = request.json
messages = data.get('messages')
model = data.get('model')
accessible_model_list = get_accessible_model_list()
if model not in accessible_model_list:
return jsonify({"error": "model is not accessible"}), 401
stream = data.get('stream', False)
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer '):
return jsonify({"error": "Authorization header is missing or invalid"}), 401
api_key = auth_header.split(' ')[1]
print(f"api_key: {api_key}")
upstream_response = send_text_prompt_and_get_response(messages, api_key, stream, model)
# 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text
all_new_text = ""
# 处理流式响应
def generate():
nonlocal all_new_text # 引用外部变量
data_queue = Queue()
stop_event = threading.Event()
last_data_time = [time.time()]
chat_message_id = generate_unique_id("chatcmpl")
conversation_id_print_tag = False
conversation_id = ''
# 启动数据处理线程
fetcher_thread = threading.Thread(target=data_fetcher, args=(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model))
fetcher_thread.start()
# 启动保活线程
keep_alive_thread = threading.Thread(target=keep_alive, args=(last_data_time, stop_event, data_queue, model, chat_message_id))
keep_alive_thread.start()
try:
while True:
data = data_queue.get()
if isinstance(data, tuple) and data[0] == 'all_new_text':
# 更新 all_new_text
print(f"[{datetime.now()}] 完整消息: {data[1]}")
all_new_text += data[1]
elif isinstance(data, tuple) and data[0] == 'conversation_id':
if conversation_id_print_tag == False:
print(f"[{datetime.now()}] 当前会话id: {data[1]}")
conversation_id_print_tag = True
# 更新 conversation_id
conversation_id = data[1]
# print(f"[{datetime.now()}] 收到会话id: {conversation_id}")
elif data == 'data: [DONE]\n\n':
# 接收到结束信号,退出循环
print(f"[{datetime.now()}] 会话结束-外层")
break
else:
yield data
finally:
stop_event.set()
fetcher_thread.join()
keep_alive_thread.join()
if conversation_id:
# print(f"[{datetime.now()}] 准备删除的会话id {conversation_id}")
delete_conversation(conversation_id, api_key) delete_conversation(conversation_id, api_key)