mirror of
https://github.com/Yanyutin753/RefreshToV1Api.git
synced 2025-10-19 01:43:53 +00:00
[Fix] 重构代码,优化输出稳定性,增加接口保活输出
This commit is contained in:
@@ -20,10 +20,12 @@
|
|||||||
|
|
||||||
- [ ] 支持 日志等级划分
|
- [ ] 支持 日志等级划分
|
||||||
|
|
||||||
- [ ] 支持 接口保活
|
- [x] 支持 接口保活
|
||||||
|
|
||||||
- [ ] 支持 自定义接口前缀
|
- [ ] 支持 自定义接口前缀
|
||||||
|
|
||||||
|
- [ ] 优化 偶现的【0†source】引用bug
|
||||||
|
|
||||||
## 注意
|
## 注意
|
||||||
|
|
||||||
> [!CAUTION]
|
> [!CAUTION]
|
||||||
|
195
main.py
195
main.py
@@ -10,6 +10,8 @@ from datetime import datetime
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
import re
|
import re
|
||||||
|
import threading
|
||||||
|
from queue import Queue, Empty
|
||||||
|
|
||||||
def generate_unique_id(prefix):
|
def generate_unique_id(prefix):
|
||||||
# 生成一个随机的 UUID
|
# 生成一个随机的 UUID
|
||||||
@@ -103,9 +105,9 @@ PROXY_API_PREFIX = os.getenv('PROXY_API_PREFIX', '')
|
|||||||
UPLOAD_BASE_URL = os.getenv('UPLOAD_BASE_URL', '')
|
UPLOAD_BASE_URL = os.getenv('UPLOAD_BASE_URL', '')
|
||||||
KEY_FOR_GPTS_INFO = os.getenv('KEY_FOR_GPTS_INFO', '')
|
KEY_FOR_GPTS_INFO = os.getenv('KEY_FOR_GPTS_INFO', '')
|
||||||
|
|
||||||
VERSION = '0.1.6'
|
VERSION = '0.1.7'
|
||||||
# VERSION = 'test'
|
# VERSION = 'test'
|
||||||
UPDATE_INFO = '对于自带的三个模型,支持多个名字映射到同一个模型'
|
UPDATE_INFO = '重构代码,优化输出稳定性'
|
||||||
# UPDATE_INFO = '【仅供临时测试使用】 '
|
# UPDATE_INFO = '【仅供临时测试使用】 '
|
||||||
|
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
@@ -382,35 +384,9 @@ def replace_complete_citation(text, citations):
|
|||||||
|
|
||||||
return replaced_text, remaining_text, is_potential_citation
|
return replaced_text, remaining_text, is_potential_citation
|
||||||
|
|
||||||
# 定义 Flask 路由
|
def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model):
|
||||||
@app.route('/v1/chat/completions', methods=['POST'])
|
|
||||||
def chat_completions():
|
|
||||||
print(f"[{datetime.now()}] New Request")
|
|
||||||
data = request.json
|
|
||||||
messages = data.get('messages')
|
|
||||||
model = data.get('model')
|
|
||||||
accessible_model_list = get_accessible_model_list()
|
|
||||||
if model not in accessible_model_list:
|
|
||||||
return jsonify({"error": "model is not accessible"}), 401
|
|
||||||
|
|
||||||
stream = data.get('stream', False)
|
|
||||||
|
|
||||||
auth_header = request.headers.get('Authorization')
|
|
||||||
if not auth_header or not auth_header.startswith('Bearer '):
|
|
||||||
return jsonify({"error": "Authorization header is missing or invalid"}), 401
|
|
||||||
api_key = auth_header.split(' ')[1]
|
|
||||||
print(f"api_key: {api_key}")
|
|
||||||
|
|
||||||
|
|
||||||
upstream_response = send_text_prompt_and_get_response(messages, api_key, stream, model)
|
|
||||||
|
|
||||||
# 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text
|
|
||||||
all_new_text = ""
|
all_new_text = ""
|
||||||
|
|
||||||
# 处理流式响应
|
|
||||||
def generate():
|
|
||||||
nonlocal all_new_text # 引用外部变量
|
|
||||||
chat_message_id = generate_unique_id("chatcmpl")
|
|
||||||
# 当前时间戳
|
# 当前时间戳
|
||||||
timestamp = int(time.time())
|
timestamp = int(time.time())
|
||||||
|
|
||||||
@@ -453,8 +429,8 @@ def chat_completions():
|
|||||||
content = message.get("content", {})
|
content = message.get("content", {})
|
||||||
role = message.get("author", {}).get("role")
|
role = message.get("author", {}).get("role")
|
||||||
content_type = content.get("content_type")
|
content_type = content.get("content_type")
|
||||||
print(f"content_type: {content_type}")
|
# print(f"content_type: {content_type}")
|
||||||
print(f"last_content_type: {last_content_type}")
|
# print(f"last_content_type: {last_content_type}")
|
||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
citations = []
|
citations = []
|
||||||
@@ -469,7 +445,9 @@ def chat_completions():
|
|||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
conversation_id = data_json.get("conversation_id")
|
conversation_id = data_json.get("conversation_id")
|
||||||
print(f"conversation_id: {conversation_id}")
|
# print(f"conversation_id: {conversation_id}")
|
||||||
|
if conversation_id:
|
||||||
|
data_queue.put(('conversation_id', conversation_id))
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
# 只获取新的部分
|
# 只获取新的部分
|
||||||
@@ -557,13 +535,13 @@ def chat_completions():
|
|||||||
if "\u3010" in new_text and not citation_accumulating:
|
if "\u3010" in new_text and not citation_accumulating:
|
||||||
citation_accumulating = True
|
citation_accumulating = True
|
||||||
citation_buffer = citation_buffer + new_text
|
citation_buffer = citation_buffer + new_text
|
||||||
print(f"开始积累引用: {citation_buffer}")
|
# print(f"开始积累引用: {citation_buffer}")
|
||||||
elif citation_accumulating:
|
elif citation_accumulating:
|
||||||
citation_buffer += new_text
|
citation_buffer += new_text
|
||||||
print(f"积累引用: {citation_buffer}")
|
# print(f"积累引用: {citation_buffer}")
|
||||||
if citation_accumulating:
|
if citation_accumulating:
|
||||||
if is_valid_citation_format(citation_buffer):
|
if is_valid_citation_format(citation_buffer):
|
||||||
print(f"合法格式: {citation_buffer}")
|
# print(f"合法格式: {citation_buffer}")
|
||||||
# 继续积累
|
# 继续积累
|
||||||
if is_complete_citation_format(citation_buffer):
|
if is_complete_citation_format(citation_buffer):
|
||||||
|
|
||||||
@@ -577,12 +555,12 @@ def chat_completions():
|
|||||||
else:
|
else:
|
||||||
citation_accumulating = False
|
citation_accumulating = False
|
||||||
citation_buffer = ""
|
citation_buffer = ""
|
||||||
print(f"替换完整的引用格式: {new_text}")
|
# print(f"替换完整的引用格式: {new_text}")
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
# 不是合法格式,放弃积累并响应
|
# 不是合法格式,放弃积累并响应
|
||||||
print(f"不合法格式: {citation_buffer}")
|
# print(f"不合法格式: {citation_buffer}")
|
||||||
new_text = citation_buffer
|
new_text = citation_buffer
|
||||||
citation_accumulating = False
|
citation_accumulating = False
|
||||||
citation_buffer = ""
|
citation_buffer = ""
|
||||||
@@ -627,12 +605,12 @@ def chat_completions():
|
|||||||
if content_type != None:
|
if content_type != None:
|
||||||
last_content_type = content_type if role != "user" else last_content_type
|
last_content_type = content_type if role != "user" else last_content_type
|
||||||
|
|
||||||
|
model_slug = message.get("metadata", {}).get("model_slug") or model
|
||||||
new_data = {
|
new_data = {
|
||||||
"id": chat_message_id,
|
"id": chat_message_id,
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"created": timestamp,
|
"created": timestamp,
|
||||||
"model": message.get("metadata", {}).get("model_slug"),
|
"model": model_slug,
|
||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
@@ -649,13 +627,22 @@ def chat_completions():
|
|||||||
# print(f"[{datetime.now()}] 发送数据: {tmp}")
|
# print(f"[{datetime.now()}] 发送数据: {tmp}")
|
||||||
# 累积 new_text
|
# 累积 new_text
|
||||||
all_new_text += new_text
|
all_new_text += new_text
|
||||||
yield 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n'
|
q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n'
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
|
if stop_event.is_set():
|
||||||
|
break
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# print("JSON 解析错误")
|
# print("JSON 解析错误")
|
||||||
print(f"[{datetime.now()}] 发送数据: {complete_data}")
|
print(f"[{datetime.now()}] 发送数据: {complete_data}")
|
||||||
if complete_data == 'data: [DONE]\n\n':
|
if complete_data == 'data: [DONE]\n\n':
|
||||||
print(f"[{datetime.now()}] 会话结束")
|
print(f"[{datetime.now()}] 会话结束")
|
||||||
yield complete_data
|
q_data = complete_data
|
||||||
|
data_queue.put(('all_new_text', all_new_text))
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
|
if stop_event.is_set():
|
||||||
|
break
|
||||||
if citation_buffer != "":
|
if citation_buffer != "":
|
||||||
new_data = {
|
new_data = {
|
||||||
"id": chat_message_id,
|
"id": chat_message_id,
|
||||||
@@ -676,10 +663,12 @@ def chat_completions():
|
|||||||
# print(f"[{datetime.now()}] 发送数据: {tmp}")
|
# print(f"[{datetime.now()}] 发送数据: {tmp}")
|
||||||
# 累积 new_text
|
# 累积 new_text
|
||||||
all_new_text += citation_buffer
|
all_new_text += citation_buffer
|
||||||
yield 'data: ' + json.dumps(new_data) + '\n\n'
|
q_data = 'data: ' + json.dumps(new_data) + '\n\n'
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
if buffer:
|
if buffer:
|
||||||
# print(f"[{datetime.now()}] 最后的数据: {buffer}")
|
# print(f"[{datetime.now()}] 最后的数据: {buffer}")
|
||||||
delete_conversation(conversation_id, api_key)
|
# delete_conversation(conversation_id, api_key)
|
||||||
try:
|
try:
|
||||||
buffer_json = json.loads(buffer)
|
buffer_json = json.loads(buffer)
|
||||||
error_message = buffer_json.get("detail", {}).get("message", "未知错误")
|
error_message = buffer_json.get("detail", {}).get("message", "未知错误")
|
||||||
@@ -702,12 +691,128 @@ def chat_completions():
|
|||||||
print(f"[{datetime.now()}] 发送最后的数据: {tmp}")
|
print(f"[{datetime.now()}] 发送最后的数据: {tmp}")
|
||||||
# 累积 new_text
|
# 累积 new_text
|
||||||
all_new_text += ''.join("```\n" + error_message + "\n```")
|
all_new_text += ''.join("```\n" + error_message + "\n```")
|
||||||
yield 'data: ' + json.dumps(error_data) + '\n\n'
|
q_data = 'data: ' + json.dumps(error_data) + '\n\n'
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
|
complete_data = 'data: [DONE]\n\n'
|
||||||
|
print(f"[{datetime.now()}] 会话结束")
|
||||||
|
q_data = complete_data
|
||||||
|
data_queue.put(('all_new_text', all_new_text))
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
except:
|
except:
|
||||||
# print("JSON 解析错误")
|
# print("JSON 解析错误")
|
||||||
print(f"[{datetime.now()}] 发送最后的数据: {buffer}")
|
print(f"[{datetime.now()}] 发送最后的数据: {buffer}")
|
||||||
yield buffer
|
q_data = buffer
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
|
complete_data = 'data: [DONE]\n\n'
|
||||||
|
print(f"[{datetime.now()}] 会话结束")
|
||||||
|
q_data = complete_data
|
||||||
|
data_queue.put(('all_new_text', all_new_text))
|
||||||
|
data_queue.put(q_data)
|
||||||
|
last_data_time[0] = time.time()
|
||||||
|
|
||||||
|
def keep_alive(last_data_time, stop_event, queue, model, chat_message_id):
|
||||||
|
while not stop_event.is_set():
|
||||||
|
if time.time() - last_data_time[0] >=1:
|
||||||
|
print(f"[{datetime.now()}] 发送保活消息")
|
||||||
|
# 当前时间戳
|
||||||
|
timestamp = int(time.time())
|
||||||
|
new_data = {
|
||||||
|
"id": chat_message_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": timestamp,
|
||||||
|
"model": model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"content": ''
|
||||||
|
},
|
||||||
|
"finish_reason": None
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
queue.put(f'data: {json.dumps(new_data)}\n\n') # 发送保活消息
|
||||||
|
last_data_time[0] = time.time()
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
# 定义 Flask 路由
|
||||||
|
@app.route('/v1/chat/completions', methods=['POST'])
|
||||||
|
def chat_completions():
|
||||||
|
print(f"[{datetime.now()}] New Request")
|
||||||
|
data = request.json
|
||||||
|
messages = data.get('messages')
|
||||||
|
model = data.get('model')
|
||||||
|
accessible_model_list = get_accessible_model_list()
|
||||||
|
if model not in accessible_model_list:
|
||||||
|
return jsonify({"error": "model is not accessible"}), 401
|
||||||
|
|
||||||
|
stream = data.get('stream', False)
|
||||||
|
|
||||||
|
auth_header = request.headers.get('Authorization')
|
||||||
|
if not auth_header or not auth_header.startswith('Bearer '):
|
||||||
|
return jsonify({"error": "Authorization header is missing or invalid"}), 401
|
||||||
|
api_key = auth_header.split(' ')[1]
|
||||||
|
print(f"api_key: {api_key}")
|
||||||
|
|
||||||
|
|
||||||
|
upstream_response = send_text_prompt_and_get_response(messages, api_key, stream, model)
|
||||||
|
|
||||||
|
# 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text
|
||||||
|
all_new_text = ""
|
||||||
|
|
||||||
|
# 处理流式响应
|
||||||
|
def generate():
|
||||||
|
nonlocal all_new_text # 引用外部变量
|
||||||
|
data_queue = Queue()
|
||||||
|
stop_event = threading.Event()
|
||||||
|
last_data_time = [time.time()]
|
||||||
|
chat_message_id = generate_unique_id("chatcmpl")
|
||||||
|
|
||||||
|
conversation_id_print_tag = False
|
||||||
|
|
||||||
|
conversation_id = ''
|
||||||
|
|
||||||
|
# 启动数据处理线程
|
||||||
|
fetcher_thread = threading.Thread(target=data_fetcher, args=(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model))
|
||||||
|
fetcher_thread.start()
|
||||||
|
|
||||||
|
# 启动保活线程
|
||||||
|
keep_alive_thread = threading.Thread(target=keep_alive, args=(last_data_time, stop_event, data_queue, model, chat_message_id))
|
||||||
|
keep_alive_thread.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
data = data_queue.get()
|
||||||
|
if isinstance(data, tuple) and data[0] == 'all_new_text':
|
||||||
|
# 更新 all_new_text
|
||||||
|
print(f"[{datetime.now()}] 完整消息: {data[1]}")
|
||||||
|
all_new_text += data[1]
|
||||||
|
elif isinstance(data, tuple) and data[0] == 'conversation_id':
|
||||||
|
if conversation_id_print_tag == False:
|
||||||
|
print(f"[{datetime.now()}] 当前会话id: {data[1]}")
|
||||||
|
conversation_id_print_tag = True
|
||||||
|
# 更新 conversation_id
|
||||||
|
conversation_id = data[1]
|
||||||
|
# print(f"[{datetime.now()}] 收到会话id: {conversation_id}")
|
||||||
|
elif data == 'data: [DONE]\n\n':
|
||||||
|
# 接收到结束信号,退出循环
|
||||||
|
print(f"[{datetime.now()}] 会话结束-外层")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
yield data
|
||||||
|
|
||||||
|
finally:
|
||||||
|
stop_event.set()
|
||||||
|
fetcher_thread.join()
|
||||||
|
keep_alive_thread.join()
|
||||||
|
|
||||||
|
if conversation_id:
|
||||||
|
# print(f"[{datetime.now()}] 准备删除的会话id: {conversation_id}")
|
||||||
delete_conversation(conversation_id, api_key)
|
delete_conversation(conversation_id, api_key)
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user