mirror of
https://github.com/Yanyutin753/RefreshToV1Api.git
synced 2025-10-15 07:31:35 +00:00
[feat] 支持 GPT-4 文件上传
This commit is contained in:
@@ -38,7 +38,9 @@
|
||||
|
||||
- [x] 支持 日志等级划分
|
||||
|
||||
- [ ] 支持 gpt-4-vision
|
||||
- [x] 支持 gpt-4-vision
|
||||
|
||||
- [ ] 支持 指定进程、线程数
|
||||
|
||||
- [ ] 优化 偶现的【0†source】引用bug
|
||||
|
||||
|
@@ -21,4 +21,11 @@ services:
|
||||
- ./log:/app/log
|
||||
- ./images:/app/images
|
||||
- ./gpts.json:/app/gpts.json
|
||||
|
||||
|
||||
redis:
|
||||
image: "redis:alpine"
|
||||
command: redis-server --appendonly yes
|
||||
ports:
|
||||
- "46379:6379"
|
||||
volumes:
|
||||
- ./redis-data:/data
|
286
main.py
286
main.py
@@ -14,6 +14,15 @@ import threading
|
||||
from queue import Queue, Empty
|
||||
import logging
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
import uuid
|
||||
import hashlib
|
||||
import requests
|
||||
import json
|
||||
import hashlib
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
import base64
|
||||
|
||||
|
||||
NEED_LOG_TO_FILE = os.getenv('NEED_LOG_TO_FILE', 'true').lower() == 'true'
|
||||
@@ -141,10 +150,12 @@ KEY_FOR_GPTS_INFO = os.getenv('KEY_FOR_GPTS_INFO', '')
|
||||
# 添加环境变量配置
|
||||
API_PREFIX = os.getenv('API_PREFIX', '')
|
||||
|
||||
PANDORA_UPLOAD_URL = 'files.pandoranext.com'
|
||||
|
||||
VERSION = '0.1.10'
|
||||
|
||||
VERSION = '0.2.0'
|
||||
# VERSION = 'test'
|
||||
UPDATE_INFO = '修复SSE输出收尾标志未正常输出的BUG'
|
||||
UPDATE_INFO = '支持 GPT-4 文件上传'
|
||||
# UPDATE_INFO = '【仅供临时测试使用】 '
|
||||
|
||||
with app.app_context():
|
||||
@@ -247,7 +258,169 @@ def get_token():
|
||||
|
||||
import os
|
||||
|
||||
def get_image_dimensions(file_content):
|
||||
with Image.open(BytesIO(file_content)) as img:
|
||||
return img.width, img.height
|
||||
|
||||
def determine_file_use_case(mime_type):
|
||||
multimodal_types = ["image/jpeg", "image/webp", "image/png", "image/gif"]
|
||||
my_files_types = ["text/x-php", "application/msword", "text/x-c", "text/html",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/json", "text/javascript", "application/pdf",
|
||||
"text/x-java", "text/x-tex", "text/x-typescript", "text/x-sh",
|
||||
"text/x-csharp", "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
"text/x-c++", "application/x-latext", "text/markdown", "text/plain",
|
||||
"text/x-ruby", "text/x-script.python"]
|
||||
|
||||
if mime_type in multimodal_types:
|
||||
return "multimodal"
|
||||
elif mime_type in my_files_types:
|
||||
return "my_files"
|
||||
else:
|
||||
return "ace_upload"
|
||||
|
||||
def upload_file(file_content, mime_type, api_key):
|
||||
logger.debug("文件上传开始")
|
||||
|
||||
width = None
|
||||
height = None
|
||||
if mime_type.startswith('image/'):
|
||||
try:
|
||||
width, height = get_image_dimensions(file_content)
|
||||
except Exception as e:
|
||||
logger.error(f"图片信息获取异常, 切换为text/plain: {e}")
|
||||
mime_type = 'text/plain'
|
||||
|
||||
# logger.debug(f"文件内容: {file_content}")
|
||||
file_size = len(file_content)
|
||||
logger.debug(f"文件大小: {file_size}")
|
||||
file_extension = get_file_extension(mime_type)
|
||||
logger.debug(f"文件扩展名: {file_extension}")
|
||||
sha256_hash = hashlib.sha256(file_content).hexdigest()
|
||||
logger.debug(f"sha256_hash: {sha256_hash}")
|
||||
file_name = f"{sha256_hash}{file_extension}"
|
||||
logger.debug(f"文件名: {file_name}")
|
||||
|
||||
|
||||
|
||||
logger.debug(f"Use Case: {determine_file_use_case(mime_type)}")
|
||||
|
||||
if determine_file_use_case(mime_type) == "ace_upload":
|
||||
mime_type = ''
|
||||
logger.debug(f"非已知文件类型,MINE置空")
|
||||
|
||||
# 第1步:调用/backend-api/files接口获取上传URL
|
||||
upload_api_url = f"{BASE_URL}/{PROXY_API_PREFIX}/backend-api/files"
|
||||
upload_request_payload = {
|
||||
"file_name": file_name,
|
||||
"file_size": file_size,
|
||||
"use_case": determine_file_use_case(mime_type)
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
upload_response = requests.post(upload_api_url, json=upload_request_payload, headers=headers)
|
||||
logger.debug(f"upload_response: {upload_response.text}")
|
||||
if upload_response.status_code != 200:
|
||||
raise Exception("Failed to get upload URL")
|
||||
|
||||
upload_data = upload_response.json()
|
||||
# 获取上传 URL 并替换域名
|
||||
parsed_url = urlparse(upload_data.get("upload_url"))
|
||||
new_netloc = PANDORA_UPLOAD_URL
|
||||
new_url = urlunparse(parsed_url._replace(netloc=new_netloc))
|
||||
upload_url = new_url
|
||||
logger.debug(f"upload_url: {upload_url}")
|
||||
file_id = upload_data.get("file_id")
|
||||
logger.debug(f"file_id: {file_id}")
|
||||
|
||||
# 第2步:上传文件
|
||||
put_headers = {
|
||||
'Content-Type': mime_type,
|
||||
'x-ms-blob-type': 'BlockBlob' # 添加这个头部
|
||||
}
|
||||
put_response = requests.put(upload_url, data=file_content, headers=put_headers)
|
||||
if put_response.status_code != 201:
|
||||
logger.debug(f"put_response: {put_response.text}")
|
||||
logger.debug(f"put_response status_code: {put_response.status_code}")
|
||||
raise Exception("Failed to upload file")
|
||||
|
||||
# 第3步:检测上传是否成功并检查响应
|
||||
check_url = f"{BASE_URL}/{PROXY_API_PREFIX}/backend-api/files/{file_id}/uploaded"
|
||||
check_response = requests.post(check_url, json={}, headers=headers)
|
||||
logger.debug(f"check_response: {check_response.text}")
|
||||
if check_response.status_code != 200:
|
||||
raise Exception("Failed to check file upload completion")
|
||||
|
||||
check_data = check_response.json()
|
||||
if check_data.get("status") != "success":
|
||||
raise Exception("File upload completion check not successful")
|
||||
|
||||
return {
|
||||
"file_id": file_id,
|
||||
"file_name": file_name,
|
||||
"size_bytes": file_size,
|
||||
"mimeType": mime_type,
|
||||
"width": width,
|
||||
"height": height
|
||||
}
|
||||
|
||||
def get_file_metadata(file_content, mime_type, api_key):
|
||||
sha256_hash = hashlib.sha256(file_content).hexdigest()
|
||||
logger.debug(f"sha256_hash: {sha256_hash}")
|
||||
|
||||
# 如果Redis中没有,上传文件并保存新数据
|
||||
new_file_data = upload_file(file_content, mime_type, api_key)
|
||||
mime_type = new_file_data.get('mimeType')
|
||||
# 为图片类型文件添加宽度和高度信息
|
||||
if mime_type.startswith('image/'):
|
||||
width, height = get_image_dimensions(file_content)
|
||||
new_file_data['width'] = width
|
||||
new_file_data['height'] = height
|
||||
|
||||
return new_file_data
|
||||
|
||||
|
||||
def get_file_extension(mime_type):
|
||||
# 基于 MIME 类型返回文件扩展名的映射表
|
||||
extension_mapping = {
|
||||
"image/jpeg": ".jpg",
|
||||
"image/png": ".png",
|
||||
"image/gif": ".gif",
|
||||
"image/webp": ".webp",
|
||||
"text/x-php": ".php",
|
||||
"application/msword": ".doc",
|
||||
"text/x-c": ".c",
|
||||
"text/html": ".html",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||
"application/json": ".json",
|
||||
"text/javascript": ".js",
|
||||
"application/pdf": ".pdf",
|
||||
"text/x-java": ".java",
|
||||
"text/x-tex": ".tex",
|
||||
"text/x-typescript": ".ts",
|
||||
"text/x-sh": ".sh",
|
||||
"text/x-csharp": ".cs",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
|
||||
"text/x-c++": ".cpp",
|
||||
"application/x-latext": ".latex", # 这里可能需要根据实际情况调整
|
||||
"text/markdown": ".md",
|
||||
"text/plain": ".txt",
|
||||
"text/x-ruby": ".rb",
|
||||
"text/x-script.python": ".py",
|
||||
# 其他 MIME 类型和扩展名...
|
||||
}
|
||||
return extension_mapping.get(mime_type, "")
|
||||
|
||||
my_files_types = [
|
||||
"text/x-php", "application/msword", "text/x-c", "text/html",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/json", "text/javascript", "application/pdf",
|
||||
"text/x-java", "text/x-tex", "text/x-typescript", "text/x-sh",
|
||||
"text/x-csharp", "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
"text/x-c++", "application/x-latext", "text/markdown", "text/plain",
|
||||
"text/x-ruby", "text/x-script.python"
|
||||
]
|
||||
|
||||
# 定义发送请求的函数
|
||||
def send_text_prompt_and_get_response(messages, api_key, stream, model):
|
||||
@@ -255,18 +428,110 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
# 查找模型配置
|
||||
model_config = find_model_config(model)
|
||||
ori_model_name = ''
|
||||
if model_config:
|
||||
# 检查是否有 ori_name
|
||||
ori_model_name = model_config.get('ori_name', model)
|
||||
|
||||
formatted_messages = []
|
||||
# logger.debug(f"原始 messages: {messages}")
|
||||
for message in messages:
|
||||
message_id = str(uuid.uuid4())
|
||||
formatted_message = {
|
||||
"id": message_id,
|
||||
"author": {"role": message.get("role")},
|
||||
"content": {"content_type": "text", "parts": [message.get("content")]},
|
||||
"metadata": {}
|
||||
}
|
||||
formatted_messages.append(formatted_message)
|
||||
content = message.get("content")
|
||||
|
||||
if isinstance(content, list) and ori_model_name != 'gpt-3.5-turbo':
|
||||
logger.debug(f"gpt-vision 调用")
|
||||
new_parts = []
|
||||
attachments = []
|
||||
contains_image = False # 标记是否包含图片
|
||||
|
||||
for part in content:
|
||||
if isinstance(part, dict) and "type" in part:
|
||||
if part["type"] == "text":
|
||||
new_parts.append(part["text"])
|
||||
elif part["type"] == "image_url":
|
||||
# logger.debug(f"image_url: {part['image_url']}")
|
||||
file_url = part["image_url"]["url"]
|
||||
if file_url.startswith('data:'):
|
||||
# 处理 base64 编码的文件数据
|
||||
mime_type, base64_data = file_url.split(';')[0], file_url.split(',')[1]
|
||||
mime_type = mime_type.split(':')[1]
|
||||
try:
|
||||
file_content = base64.b64decode(base64_data)
|
||||
except Exception as e:
|
||||
logger.error(f"类型为 {mime_type} 的 base64 编码数据解码失败: {e}")
|
||||
continue
|
||||
else:
|
||||
# 处理普通的文件URL
|
||||
try:
|
||||
file_response = requests.get(file_url)
|
||||
file_content = file_response.content
|
||||
mime_type = file_response.headers.get('Content-Type', '').split(';')[0].strip()
|
||||
except Exception as e:
|
||||
logger.error(f"获取文件 {file_url} 失败: {e}")
|
||||
continue
|
||||
|
||||
logger.debug(f"mime_type: {mime_type}")
|
||||
file_metadata = get_file_metadata(file_content, mime_type, api_key)
|
||||
|
||||
mime_type = file_metadata["mimeType"]
|
||||
logger.debug(f"处理后 mime_type: {mime_type}")
|
||||
|
||||
if mime_type.startswith('image/'):
|
||||
contains_image = True
|
||||
new_part = {
|
||||
"asset_pointer": f"file-service://{file_metadata['file_id']}",
|
||||
"size_bytes": file_metadata["size_bytes"],
|
||||
"width": file_metadata["width"],
|
||||
"height": file_metadata["height"]
|
||||
}
|
||||
new_parts.append(new_part)
|
||||
|
||||
attachment = {
|
||||
"name": file_metadata["file_name"],
|
||||
"id": file_metadata["file_id"],
|
||||
"mimeType": file_metadata["mimeType"],
|
||||
"size": file_metadata["size_bytes"] # 添加文件大小
|
||||
}
|
||||
|
||||
if mime_type.startswith('image/'):
|
||||
attachment.update({
|
||||
"width": file_metadata["width"],
|
||||
"height": file_metadata["height"]
|
||||
})
|
||||
elif mime_type in my_files_types:
|
||||
attachment.update({"fileTokenSize": len(file_metadata["file_name"])})
|
||||
|
||||
attachments.append(attachment)
|
||||
else:
|
||||
# 确保 part 是字符串
|
||||
text_part = str(part) if not isinstance(part, str) else part
|
||||
new_parts.append(text_part)
|
||||
|
||||
content_type = "multimodal_text" if contains_image else "text"
|
||||
formatted_message = {
|
||||
"id": message_id,
|
||||
"author": {"role": message.get("role")},
|
||||
"content": {"content_type": content_type, "parts": new_parts},
|
||||
"metadata": {"attachments": attachments}
|
||||
}
|
||||
formatted_messages.append(formatted_message)
|
||||
logger.critical(f"formatted_message: {formatted_message}")
|
||||
|
||||
else:
|
||||
# 处理单个文本消息的情况
|
||||
formatted_message = {
|
||||
"id": message_id,
|
||||
"author": {"role": message.get("role")},
|
||||
"content": {"content_type": "text", "parts": [content]},
|
||||
"metadata": {}
|
||||
}
|
||||
formatted_messages.append(formatted_message)
|
||||
|
||||
# logger.debug(f"formatted_messages: {formatted_messages}")
|
||||
# return
|
||||
payload = {}
|
||||
|
||||
logger.info(f"model: {model}")
|
||||
@@ -327,6 +592,7 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model):
|
||||
payload = generate_gpts_payload(model, formatted_messages)
|
||||
if not payload:
|
||||
raise Exception('model is not accessible')
|
||||
logger.debug(f"payload: {payload}")
|
||||
response = requests.post(url, headers=headers, json=payload, stream=True)
|
||||
# print(response)
|
||||
return response
|
||||
@@ -775,7 +1041,7 @@ def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_
|
||||
def keep_alive(last_data_time, stop_event, queue, model, chat_message_id):
|
||||
while not stop_event.is_set():
|
||||
if time.time() - last_data_time[0] >=1:
|
||||
logger.debug(f"发送保活消息")
|
||||
# logger.debug(f"发送保活消息")
|
||||
# 当前时间戳
|
||||
timestamp = int(time.time())
|
||||
new_data = {
|
||||
|
Reference in New Issue
Block a user