mirror of
https://github.com/Yanyutin753/RefreshToV1Api.git
synced 2025-10-15 15:41:21 +00:00
[feat] 支持 GPT-4 文件上传
This commit is contained in:
@@ -38,7 +38,9 @@
|
|||||||
|
|
||||||
- [x] 支持 日志等级划分
|
- [x] 支持 日志等级划分
|
||||||
|
|
||||||
- [ ] 支持 gpt-4-vision
|
- [x] 支持 gpt-4-vision
|
||||||
|
|
||||||
|
- [ ] 支持 指定进程、线程数
|
||||||
|
|
||||||
- [ ] 优化 偶现的【0†source】引用bug
|
- [ ] 优化 偶现的【0†source】引用bug
|
||||||
|
|
||||||
|
@@ -21,4 +21,11 @@ services:
|
|||||||
- ./log:/app/log
|
- ./log:/app/log
|
||||||
- ./images:/app/images
|
- ./images:/app/images
|
||||||
- ./gpts.json:/app/gpts.json
|
- ./gpts.json:/app/gpts.json
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: "redis:alpine"
|
||||||
|
command: redis-server --appendonly yes
|
||||||
|
ports:
|
||||||
|
- "46379:6379"
|
||||||
|
volumes:
|
||||||
|
- ./redis-data:/data
|
286
main.py
286
main.py
@@ -14,6 +14,15 @@ import threading
|
|||||||
from queue import Queue, Empty
|
from queue import Queue, Empty
|
||||||
import logging
|
import logging
|
||||||
from logging.handlers import TimedRotatingFileHandler
|
from logging.handlers import TimedRotatingFileHandler
|
||||||
|
import uuid
|
||||||
|
import hashlib
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import hashlib
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
from urllib.parse import urlparse, urlunparse
|
||||||
|
import base64
|
||||||
|
|
||||||
|
|
||||||
NEED_LOG_TO_FILE = os.getenv('NEED_LOG_TO_FILE', 'true').lower() == 'true'
|
NEED_LOG_TO_FILE = os.getenv('NEED_LOG_TO_FILE', 'true').lower() == 'true'
|
||||||
@@ -141,10 +150,12 @@ KEY_FOR_GPTS_INFO = os.getenv('KEY_FOR_GPTS_INFO', '')
|
|||||||
# 添加环境变量配置
|
# 添加环境变量配置
|
||||||
API_PREFIX = os.getenv('API_PREFIX', '')
|
API_PREFIX = os.getenv('API_PREFIX', '')
|
||||||
|
|
||||||
|
PANDORA_UPLOAD_URL = 'files.pandoranext.com'
|
||||||
|
|
||||||
VERSION = '0.1.10'
|
|
||||||
|
VERSION = '0.2.0'
|
||||||
# VERSION = 'test'
|
# VERSION = 'test'
|
||||||
UPDATE_INFO = '修复SSE输出收尾标志未正常输出的BUG'
|
UPDATE_INFO = '支持 GPT-4 文件上传'
|
||||||
# UPDATE_INFO = '【仅供临时测试使用】 '
|
# UPDATE_INFO = '【仅供临时测试使用】 '
|
||||||
|
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
@@ -247,7 +258,169 @@ def get_token():
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
def get_image_dimensions(file_content):
    """Return the ``(width, height)`` of an image given its raw bytes.

    The bytes are wrapped in an in-memory stream and opened with
    ``Image.open``; whatever that call raises for unreadable data
    propagates to the caller.
    """
    image_stream = BytesIO(file_content)
    with Image.open(image_stream) as image:
        dimensions = (image.width, image.height)
    return dimensions
|
||||||
|
|
||||||
|
def determine_file_use_case(mime_type):
    """Map a MIME type to the ``use_case`` value sent with an upload request.

    Returns:
        ``"multimodal"`` for the listed image types, ``"my_files"`` for the
        listed document/source-code types, and ``"ace_upload"`` for anything
        else (including an empty string).
    """
    image_mime_types = (
        "image/jpeg",
        "image/webp",
        "image/png",
        "image/gif",
    )
    document_mime_types = (
        "text/x-php",
        "application/msword",
        "text/x-c",
        "text/html",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/json",
        "text/javascript",
        "application/pdf",
        "text/x-java",
        "text/x-tex",
        "text/x-typescript",
        "text/x-sh",
        "text/x-csharp",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        "text/x-c++",
        "application/x-latext",
        "text/markdown",
        "text/plain",
        "text/x-ruby",
        "text/x-script.python",
    )

    # Guard clauses instead of an if/elif/else ladder; the fallthrough is
    # the catch-all use case.
    if mime_type in image_mime_types:
        return "multimodal"
    if mime_type in document_mime_types:
        return "my_files"
    return "ace_upload"
|
||||||
|
|
||||||
|
def upload_file(file_content, mime_type, api_key):
    """Upload a file via the three-step ``backend-api/files`` flow.

    1. POST to ``/backend-api/files`` to obtain an upload URL and a file id.
    2. PUT the bytes to that URL, with its host replaced by
       ``PANDORA_UPLOAD_URL``.
    3. POST to ``/backend-api/files/{file_id}/uploaded`` to confirm.

    Args:
        file_content: Raw bytes of the file.
        mime_type: MIME type reported for the file.
        api_key: Token sent as the ``Bearer`` credential.

    Returns:
        A dict with the keys ``file_id``, ``file_name``, ``size_bytes``,
        ``mimeType``, ``width`` and ``height``. ``width``/``height`` are
        ``None`` unless the image dimensions could be read. ``mimeType`` may
        differ from the argument: it becomes ``'text/plain'`` when an
        ``image/*`` payload cannot be opened, and ``''`` when the type falls
        into the ``ace_upload`` use case.

    Raises:
        Exception: If any step answers with an unexpected status code, the
            first response lacks ``upload_url`` or ``file_id``, or the
            completion check does not report ``"success"``.
    """
    logger.debug("文件上传开始")

    width = None
    height = None
    if mime_type.startswith('image/'):
        try:
            width, height = get_image_dimensions(file_content)
        except Exception as e:
            # Best-effort: an unreadable image is uploaded as plain text
            # instead of aborting the whole request.
            logger.error(f"图片信息获取异常, 切换为text/plain: {e}")
            mime_type = 'text/plain'

    file_size = len(file_content)
    logger.debug(f"文件大小: {file_size}")
    file_extension = get_file_extension(mime_type)
    logger.debug(f"文件扩展名: {file_extension}")
    sha256_hash = hashlib.sha256(file_content).hexdigest()
    logger.debug(f"sha256_hash: {sha256_hash}")
    # The uploaded file is named after its content hash.
    file_name = f"{sha256_hash}{file_extension}"
    logger.debug(f"文件名: {file_name}")

    # Classify once. Blanking the MIME type below cannot change the result,
    # because an empty MIME type is itself classified as "ace_upload".
    use_case = determine_file_use_case(mime_type)
    logger.debug(f"Use Case: {use_case}")

    if use_case == "ace_upload":
        mime_type = ''
        logger.debug("非已知文件类型,MINE置空")

    # Step 1: ask /backend-api/files for an upload URL.
    upload_api_url = f"{BASE_URL}/{PROXY_API_PREFIX}/backend-api/files"
    upload_request_payload = {
        "file_name": file_name,
        "file_size": file_size,
        "use_case": use_case
    }
    headers = {
        "Authorization": f"Bearer {api_key}"
    }
    upload_response = requests.post(upload_api_url, json=upload_request_payload, headers=headers)
    logger.debug(f"upload_response: {upload_response.text}")
    if upload_response.status_code != 200:
        raise Exception("Failed to get upload URL")

    upload_data = upload_response.json()
    raw_upload_url = upload_data.get("upload_url")
    file_id = upload_data.get("file_id")
    # Fail early with a clear message; otherwise a missing field surfaces
    # later as an obscure URL-handling error or a request for file "None".
    if not raw_upload_url or not file_id:
        raise Exception("Upload URL response is missing upload_url or file_id")

    # Keep the path/query of the returned URL but swap in the upload host.
    parsed_url = urlparse(raw_upload_url)
    upload_url = urlunparse(parsed_url._replace(netloc=PANDORA_UPLOAD_URL))
    logger.debug(f"upload_url: {upload_url}")
    logger.debug(f"file_id: {file_id}")

    # Step 2: upload the bytes.
    put_headers = {
        'Content-Type': mime_type,
        # NOTE(review): presumably required by the blob storage endpoint
        # behind the upload URL — confirm.
        'x-ms-blob-type': 'BlockBlob'
    }
    put_response = requests.put(upload_url, data=file_content, headers=put_headers)
    if put_response.status_code != 201:
        # Logged at error level so a failed upload is visible without
        # enabling debug logging.
        logger.error(f"put_response: {put_response.text}")
        logger.error(f"put_response status_code: {put_response.status_code}")
        raise Exception("Failed to upload file")

    # Step 3: confirm that the upload completed.
    check_url = f"{BASE_URL}/{PROXY_API_PREFIX}/backend-api/files/{file_id}/uploaded"
    check_response = requests.post(check_url, json={}, headers=headers)
    logger.debug(f"check_response: {check_response.text}")
    if check_response.status_code != 200:
        raise Exception("Failed to check file upload completion")

    check_data = check_response.json()
    if check_data.get("status") != "success":
        raise Exception("File upload completion check not successful")

    return {
        "file_id": file_id,
        "file_name": file_name,
        "size_bytes": file_size,
        "mimeType": mime_type,
        "width": width,
        "height": height
    }
|
||||||
|
|
||||||
|
def get_file_metadata(file_content, mime_type, api_key):
    """Upload a file and return the metadata dict produced by ``upload_file``.

    Args:
        file_content: Raw bytes of the file.
        mime_type: MIME type reported for the file.
        api_key: Token forwarded to ``upload_file``.

    Returns:
        The dict returned by ``upload_file``. For ``image/*`` results the
        ``width`` and ``height`` entries are filled in here if they are not
        already present.

    Raises:
        Exception: Whatever ``upload_file`` raises is propagated.
    """
    sha256_hash = hashlib.sha256(file_content).hexdigest()
    logger.debug(f"sha256_hash: {sha256_hash}")

    # No cache lookup happens in this function: every call uploads the file.
    new_file_data = upload_file(file_content, mime_type, api_key)

    # Use the MIME type reported back by the upload, which may differ from
    # the one passed in; tolerate a missing value instead of crashing.
    uploaded_mime_type = new_file_data.get('mimeType') or ''
    if uploaded_mime_type.startswith('image/'):
        # upload_file normally returns the dimensions already; decode the
        # image again only when they are absent.
        if new_file_data.get('width') is None or new_file_data.get('height') is None:
            width, height = get_image_dimensions(file_content)
            new_file_data['width'] = width
            new_file_data['height'] = height

    return new_file_data
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_extension(mime_type):
    """Return the file extension (with leading dot) for a MIME type.

    Unknown MIME types yield an empty string.
    """
    # Lookup table from MIME type to file extension.
    suffix_by_mime = {
        # Images
        "image/jpeg": ".jpg",
        "image/png": ".png",
        "image/gif": ".gif",
        "image/webp": ".webp",
        # Documents and source code
        "text/x-php": ".php",
        "application/msword": ".doc",
        "text/x-c": ".c",
        "text/html": ".html",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/json": ".json",
        "text/javascript": ".js",
        "application/pdf": ".pdf",
        "text/x-java": ".java",
        "text/x-tex": ".tex",
        "text/x-typescript": ".ts",
        "text/x-sh": ".sh",
        "text/x-csharp": ".cs",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
        "text/x-c++": ".cpp",
        "application/x-latext": ".latex",  # may need adjusting to the real-world value
        "text/markdown": ".md",
        "text/plain": ".txt",
        "text/x-ruby": ".rb",
        "text/x-script.python": ".py",
        # Further MIME types and extensions go here.
    }
    try:
        return suffix_by_mime[mime_type]
    except KeyError:
        return ""
|
||||||
|
|
||||||
|
# MIME types handled as "my_files" documents; consulted when building
# attachment metadata for a message.
my_files_types = [
    "text/x-php",
    "application/msword",
    "text/x-c",
    "text/html",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    "application/json",
    "text/javascript",
    "application/pdf",
    "text/x-java",
    "text/x-tex",
    "text/x-typescript",
    "text/x-sh",
    "text/x-csharp",
    "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    "text/x-c++",
    "application/x-latext",
    "text/markdown",
    "text/plain",
    "text/x-ruby",
    "text/x-script.python",
]
|
||||||
|
|
||||||
# 定义发送请求的函数
|
# 定义发送请求的函数
|
||||||
def send_text_prompt_and_get_response(messages, api_key, stream, model):
|
def send_text_prompt_and_get_response(messages, api_key, stream, model):
|
||||||
@@ -255,18 +428,110 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model):
|
|||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {api_key}"
|
"Authorization": f"Bearer {api_key}"
|
||||||
}
|
}
|
||||||
|
# 查找模型配置
|
||||||
|
model_config = find_model_config(model)
|
||||||
|
ori_model_name = ''
|
||||||
|
if model_config:
|
||||||
|
# 检查是否有 ori_name
|
||||||
|
ori_model_name = model_config.get('ori_name', model)
|
||||||
|
|
||||||
formatted_messages = []
|
formatted_messages = []
|
||||||
|
# logger.debug(f"原始 messages: {messages}")
|
||||||
for message in messages:
|
for message in messages:
|
||||||
message_id = str(uuid.uuid4())
|
message_id = str(uuid.uuid4())
|
||||||
formatted_message = {
|
content = message.get("content")
|
||||||
"id": message_id,
|
|
||||||
"author": {"role": message.get("role")},
|
|
||||||
"content": {"content_type": "text", "parts": [message.get("content")]},
|
|
||||||
"metadata": {}
|
|
||||||
}
|
|
||||||
formatted_messages.append(formatted_message)
|
|
||||||
|
|
||||||
|
if isinstance(content, list) and ori_model_name != 'gpt-3.5-turbo':
|
||||||
|
logger.debug(f"gpt-vision 调用")
|
||||||
|
new_parts = []
|
||||||
|
attachments = []
|
||||||
|
contains_image = False # 标记是否包含图片
|
||||||
|
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and "type" in part:
|
||||||
|
if part["type"] == "text":
|
||||||
|
new_parts.append(part["text"])
|
||||||
|
elif part["type"] == "image_url":
|
||||||
|
# logger.debug(f"image_url: {part['image_url']}")
|
||||||
|
file_url = part["image_url"]["url"]
|
||||||
|
if file_url.startswith('data:'):
|
||||||
|
# 处理 base64 编码的文件数据
|
||||||
|
mime_type, base64_data = file_url.split(';')[0], file_url.split(',')[1]
|
||||||
|
mime_type = mime_type.split(':')[1]
|
||||||
|
try:
|
||||||
|
file_content = base64.b64decode(base64_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"类型为 {mime_type} 的 base64 编码数据解码失败: {e}")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# 处理普通的文件URL
|
||||||
|
try:
|
||||||
|
file_response = requests.get(file_url)
|
||||||
|
file_content = file_response.content
|
||||||
|
mime_type = file_response.headers.get('Content-Type', '').split(';')[0].strip()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取文件 {file_url} 失败: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.debug(f"mime_type: {mime_type}")
|
||||||
|
file_metadata = get_file_metadata(file_content, mime_type, api_key)
|
||||||
|
|
||||||
|
mime_type = file_metadata["mimeType"]
|
||||||
|
logger.debug(f"处理后 mime_type: {mime_type}")
|
||||||
|
|
||||||
|
if mime_type.startswith('image/'):
|
||||||
|
contains_image = True
|
||||||
|
new_part = {
|
||||||
|
"asset_pointer": f"file-service://{file_metadata['file_id']}",
|
||||||
|
"size_bytes": file_metadata["size_bytes"],
|
||||||
|
"width": file_metadata["width"],
|
||||||
|
"height": file_metadata["height"]
|
||||||
|
}
|
||||||
|
new_parts.append(new_part)
|
||||||
|
|
||||||
|
attachment = {
|
||||||
|
"name": file_metadata["file_name"],
|
||||||
|
"id": file_metadata["file_id"],
|
||||||
|
"mimeType": file_metadata["mimeType"],
|
||||||
|
"size": file_metadata["size_bytes"] # 添加文件大小
|
||||||
|
}
|
||||||
|
|
||||||
|
if mime_type.startswith('image/'):
|
||||||
|
attachment.update({
|
||||||
|
"width": file_metadata["width"],
|
||||||
|
"height": file_metadata["height"]
|
||||||
|
})
|
||||||
|
elif mime_type in my_files_types:
|
||||||
|
attachment.update({"fileTokenSize": len(file_metadata["file_name"])})
|
||||||
|
|
||||||
|
attachments.append(attachment)
|
||||||
|
else:
|
||||||
|
# 确保 part 是字符串
|
||||||
|
text_part = str(part) if not isinstance(part, str) else part
|
||||||
|
new_parts.append(text_part)
|
||||||
|
|
||||||
|
content_type = "multimodal_text" if contains_image else "text"
|
||||||
|
formatted_message = {
|
||||||
|
"id": message_id,
|
||||||
|
"author": {"role": message.get("role")},
|
||||||
|
"content": {"content_type": content_type, "parts": new_parts},
|
||||||
|
"metadata": {"attachments": attachments}
|
||||||
|
}
|
||||||
|
formatted_messages.append(formatted_message)
|
||||||
|
logger.critical(f"formatted_message: {formatted_message}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 处理单个文本消息的情况
|
||||||
|
formatted_message = {
|
||||||
|
"id": message_id,
|
||||||
|
"author": {"role": message.get("role")},
|
||||||
|
"content": {"content_type": "text", "parts": [content]},
|
||||||
|
"metadata": {}
|
||||||
|
}
|
||||||
|
formatted_messages.append(formatted_message)
|
||||||
|
|
||||||
|
# logger.debug(f"formatted_messages: {formatted_messages}")
|
||||||
|
# return
|
||||||
payload = {}
|
payload = {}
|
||||||
|
|
||||||
logger.info(f"model: {model}")
|
logger.info(f"model: {model}")
|
||||||
@@ -327,6 +592,7 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model):
|
|||||||
payload = generate_gpts_payload(model, formatted_messages)
|
payload = generate_gpts_payload(model, formatted_messages)
|
||||||
if not payload:
|
if not payload:
|
||||||
raise Exception('model is not accessible')
|
raise Exception('model is not accessible')
|
||||||
|
logger.debug(f"payload: {payload}")
|
||||||
response = requests.post(url, headers=headers, json=payload, stream=True)
|
response = requests.post(url, headers=headers, json=payload, stream=True)
|
||||||
# print(response)
|
# print(response)
|
||||||
return response
|
return response
|
||||||
@@ -775,7 +1041,7 @@ def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_
|
|||||||
def keep_alive(last_data_time, stop_event, queue, model, chat_message_id):
|
def keep_alive(last_data_time, stop_event, queue, model, chat_message_id):
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
if time.time() - last_data_time[0] >=1:
|
if time.time() - last_data_time[0] >=1:
|
||||||
logger.debug(f"发送保活消息")
|
# logger.debug(f"发送保活消息")
|
||||||
# 当前时间戳
|
# 当前时间戳
|
||||||
timestamp = int(time.time())
|
timestamp = int(time.time())
|
||||||
new_data = {
|
new_data = {
|
||||||
|
Reference in New Issue
Block a user