[feat] 支持 GPT-4 文件上传

This commit is contained in:
Wizerd
2023-12-20 16:47:07 +08:00
parent 4add67b342
commit f595b9dcd1
3 changed files with 287 additions and 12 deletions

View File

@@ -38,7 +38,9 @@
- [x] 支持 日志等级划分
- [ ] 支持 gpt-4-vision
- [x] 支持 gpt-4-vision
- [ ] 支持 指定进程、线程数
- [ ] 优化 偶现的【0†source】引用bug

View File

@@ -21,4 +21,11 @@ services:
- ./log:/app/log
- ./images:/app/images
- ./gpts.json:/app/gpts.json
redis:
image: "redis:alpine"
command: redis-server --appendonly yes
ports:
- "46379:6379"
volumes:
- ./redis-data:/data

286
main.py
View File

@@ -14,6 +14,15 @@ import threading
from queue import Queue, Empty
import logging
from logging.handlers import TimedRotatingFileHandler
import uuid
import hashlib
import requests
import json
import hashlib
from PIL import Image
from io import BytesIO
from urllib.parse import urlparse, urlunparse
import base64
NEED_LOG_TO_FILE = os.getenv('NEED_LOG_TO_FILE', 'true').lower() == 'true'
@@ -141,10 +150,12 @@ KEY_FOR_GPTS_INFO = os.getenv('KEY_FOR_GPTS_INFO', '')
# 添加环境变量配置
API_PREFIX = os.getenv('API_PREFIX', '')
PANDORA_UPLOAD_URL = 'files.pandoranext.com'
VERSION = '0.1.10'
VERSION = '0.2.0'
# VERSION = 'test'
UPDATE_INFO = '修复SSE输出收尾标志未正常输出的BUG'
UPDATE_INFO = '支持 GPT-4 文件上传'
# UPDATE_INFO = '【仅供临时测试使用】 '
with app.app_context():
@@ -247,7 +258,169 @@ def get_token():
import os
def get_image_dimensions(file_content):
with Image.open(BytesIO(file_content)) as img:
return img.width, img.height
def determine_file_use_case(mime_type):
multimodal_types = ["image/jpeg", "image/webp", "image/png", "image/gif"]
my_files_types = ["text/x-php", "application/msword", "text/x-c", "text/html",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/json", "text/javascript", "application/pdf",
"text/x-java", "text/x-tex", "text/x-typescript", "text/x-sh",
"text/x-csharp", "application/vnd.openxmlformats-officedocument.presentationml.presentation",
"text/x-c++", "application/x-latext", "text/markdown", "text/plain",
"text/x-ruby", "text/x-script.python"]
if mime_type in multimodal_types:
return "multimodal"
elif mime_type in my_files_types:
return "my_files"
else:
return "ace_upload"
def upload_file(file_content, mime_type, api_key):
logger.debug("文件上传开始")
width = None
height = None
if mime_type.startswith('image/'):
try:
width, height = get_image_dimensions(file_content)
except Exception as e:
logger.error(f"图片信息获取异常, 切换为text/plain {e}")
mime_type = 'text/plain'
# logger.debug(f"文件内容: {file_content}")
file_size = len(file_content)
logger.debug(f"文件大小: {file_size}")
file_extension = get_file_extension(mime_type)
logger.debug(f"文件扩展名: {file_extension}")
sha256_hash = hashlib.sha256(file_content).hexdigest()
logger.debug(f"sha256_hash: {sha256_hash}")
file_name = f"{sha256_hash}{file_extension}"
logger.debug(f"文件名: {file_name}")
logger.debug(f"Use Case: {determine_file_use_case(mime_type)}")
if determine_file_use_case(mime_type) == "ace_upload":
mime_type = ''
logger.debug(f"非已知文件类型MINE置空")
# 第1步调用/backend-api/files接口获取上传URL
upload_api_url = f"{BASE_URL}/{PROXY_API_PREFIX}/backend-api/files"
upload_request_payload = {
"file_name": file_name,
"file_size": file_size,
"use_case": determine_file_use_case(mime_type)
}
headers = {
"Authorization": f"Bearer {api_key}"
}
upload_response = requests.post(upload_api_url, json=upload_request_payload, headers=headers)
logger.debug(f"upload_response: {upload_response.text}")
if upload_response.status_code != 200:
raise Exception("Failed to get upload URL")
upload_data = upload_response.json()
# 获取上传 URL 并替换域名
parsed_url = urlparse(upload_data.get("upload_url"))
new_netloc = PANDORA_UPLOAD_URL
new_url = urlunparse(parsed_url._replace(netloc=new_netloc))
upload_url = new_url
logger.debug(f"upload_url: {upload_url}")
file_id = upload_data.get("file_id")
logger.debug(f"file_id: {file_id}")
# 第2步上传文件
put_headers = {
'Content-Type': mime_type,
'x-ms-blob-type': 'BlockBlob' # 添加这个头部
}
put_response = requests.put(upload_url, data=file_content, headers=put_headers)
if put_response.status_code != 201:
logger.debug(f"put_response: {put_response.text}")
logger.debug(f"put_response status_code: {put_response.status_code}")
raise Exception("Failed to upload file")
# 第3步检测上传是否成功并检查响应
check_url = f"{BASE_URL}/{PROXY_API_PREFIX}/backend-api/files/{file_id}/uploaded"
check_response = requests.post(check_url, json={}, headers=headers)
logger.debug(f"check_response: {check_response.text}")
if check_response.status_code != 200:
raise Exception("Failed to check file upload completion")
check_data = check_response.json()
if check_data.get("status") != "success":
raise Exception("File upload completion check not successful")
return {
"file_id": file_id,
"file_name": file_name,
"size_bytes": file_size,
"mimeType": mime_type,
"width": width,
"height": height
}
def get_file_metadata(file_content, mime_type, api_key):
sha256_hash = hashlib.sha256(file_content).hexdigest()
logger.debug(f"sha256_hash: {sha256_hash}")
# 如果Redis中没有上传文件并保存新数据
new_file_data = upload_file(file_content, mime_type, api_key)
mime_type = new_file_data.get('mimeType')
# 为图片类型文件添加宽度和高度信息
if mime_type.startswith('image/'):
width, height = get_image_dimensions(file_content)
new_file_data['width'] = width
new_file_data['height'] = height
return new_file_data
def get_file_extension(mime_type):
# 基于 MIME 类型返回文件扩展名的映射表
extension_mapping = {
"image/jpeg": ".jpg",
"image/png": ".png",
"image/gif": ".gif",
"image/webp": ".webp",
"text/x-php": ".php",
"application/msword": ".doc",
"text/x-c": ".c",
"text/html": ".html",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/json": ".json",
"text/javascript": ".js",
"application/pdf": ".pdf",
"text/x-java": ".java",
"text/x-tex": ".tex",
"text/x-typescript": ".ts",
"text/x-sh": ".sh",
"text/x-csharp": ".cs",
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
"text/x-c++": ".cpp",
"application/x-latext": ".latex", # 这里可能需要根据实际情况调整
"text/markdown": ".md",
"text/plain": ".txt",
"text/x-ruby": ".rb",
"text/x-script.python": ".py",
# 其他 MIME 类型和扩展名...
}
return extension_mapping.get(mime_type, "")
my_files_types = [
"text/x-php", "application/msword", "text/x-c", "text/html",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/json", "text/javascript", "application/pdf",
"text/x-java", "text/x-tex", "text/x-typescript", "text/x-sh",
"text/x-csharp", "application/vnd.openxmlformats-officedocument.presentationml.presentation",
"text/x-c++", "application/x-latext", "text/markdown", "text/plain",
"text/x-ruby", "text/x-script.python"
]
# 定义发送请求的函数
def send_text_prompt_and_get_response(messages, api_key, stream, model):
@@ -255,18 +428,110 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model):
headers = {
"Authorization": f"Bearer {api_key}"
}
# 查找模型配置
model_config = find_model_config(model)
ori_model_name = ''
if model_config:
# 检查是否有 ori_name
ori_model_name = model_config.get('ori_name', model)
formatted_messages = []
# logger.debug(f"原始 messages: {messages}")
for message in messages:
message_id = str(uuid.uuid4())
formatted_message = {
"id": message_id,
"author": {"role": message.get("role")},
"content": {"content_type": "text", "parts": [message.get("content")]},
"metadata": {}
}
formatted_messages.append(formatted_message)
content = message.get("content")
if isinstance(content, list) and ori_model_name != 'gpt-3.5-turbo':
logger.debug(f"gpt-vision 调用")
new_parts = []
attachments = []
contains_image = False # 标记是否包含图片
for part in content:
if isinstance(part, dict) and "type" in part:
if part["type"] == "text":
new_parts.append(part["text"])
elif part["type"] == "image_url":
# logger.debug(f"image_url: {part['image_url']}")
file_url = part["image_url"]["url"]
if file_url.startswith('data:'):
# 处理 base64 编码的文件数据
mime_type, base64_data = file_url.split(';')[0], file_url.split(',')[1]
mime_type = mime_type.split(':')[1]
try:
file_content = base64.b64decode(base64_data)
except Exception as e:
logger.error(f"类型为 {mime_type} 的 base64 编码数据解码失败: {e}")
continue
else:
# 处理普通的文件URL
try:
file_response = requests.get(file_url)
file_content = file_response.content
mime_type = file_response.headers.get('Content-Type', '').split(';')[0].strip()
except Exception as e:
logger.error(f"获取文件 {file_url} 失败: {e}")
continue
logger.debug(f"mime_type: {mime_type}")
file_metadata = get_file_metadata(file_content, mime_type, api_key)
mime_type = file_metadata["mimeType"]
logger.debug(f"处理后 mime_type: {mime_type}")
if mime_type.startswith('image/'):
contains_image = True
new_part = {
"asset_pointer": f"file-service://{file_metadata['file_id']}",
"size_bytes": file_metadata["size_bytes"],
"width": file_metadata["width"],
"height": file_metadata["height"]
}
new_parts.append(new_part)
attachment = {
"name": file_metadata["file_name"],
"id": file_metadata["file_id"],
"mimeType": file_metadata["mimeType"],
"size": file_metadata["size_bytes"] # 添加文件大小
}
if mime_type.startswith('image/'):
attachment.update({
"width": file_metadata["width"],
"height": file_metadata["height"]
})
elif mime_type in my_files_types:
attachment.update({"fileTokenSize": len(file_metadata["file_name"])})
attachments.append(attachment)
else:
# 确保 part 是字符串
text_part = str(part) if not isinstance(part, str) else part
new_parts.append(text_part)
content_type = "multimodal_text" if contains_image else "text"
formatted_message = {
"id": message_id,
"author": {"role": message.get("role")},
"content": {"content_type": content_type, "parts": new_parts},
"metadata": {"attachments": attachments}
}
formatted_messages.append(formatted_message)
logger.critical(f"formatted_message: {formatted_message}")
else:
# 处理单个文本消息的情况
formatted_message = {
"id": message_id,
"author": {"role": message.get("role")},
"content": {"content_type": "text", "parts": [content]},
"metadata": {}
}
formatted_messages.append(formatted_message)
# logger.debug(f"formatted_messages: {formatted_messages}")
# return
payload = {}
logger.info(f"model: {model}")
@@ -327,6 +592,7 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model):
payload = generate_gpts_payload(model, formatted_messages)
if not payload:
raise Exception('model is not accessible')
logger.debug(f"payload: {payload}")
response = requests.post(url, headers=headers, json=payload, stream=True)
# print(response)
return response
@@ -775,7 +1041,7 @@ def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_
def keep_alive(last_data_time, stop_event, queue, model, chat_message_id):
while not stop_event.is_set():
if time.time() - last_data_time[0] >=1:
logger.debug(f"发送保活消息")
# logger.debug(f"发送保活消息")
# 当前时间戳
timestamp = int(time.time())
new_data = {