mirror of
https://github.com/Yanyutin753/RefreshToV1Api.git
synced 2025-10-15 15:41:21 +00:00
[feat] 支持文件生成
This commit is contained in:
@@ -10,4 +10,5 @@ services:
|
||||
- ./log:/app/log
|
||||
- ./images:/app/images
|
||||
- ./data:/app/data
|
||||
- ./files:/app/files
|
||||
|
129
main.py
129
main.py
@@ -175,14 +175,16 @@ CORS(app, resources={r"/images/*": {"origins": "*"}})
|
||||
PANDORA_UPLOAD_URL = 'files.pandoranext.com'
|
||||
|
||||
|
||||
VERSION = '0.3.9'
|
||||
VERSION = '0.4.0'
|
||||
# VERSION = 'test'
|
||||
UPDATE_INFO = '支持图片不落地,并支持base64绘图输出'
|
||||
UPDATE_INFO = '支持文件生成'
|
||||
# UPDATE_INFO = '【仅供临时测试使用】 '
|
||||
|
||||
with app.app_context():
|
||||
global gpts_configurations # 移到作用域的最开始
|
||||
|
||||
|
||||
|
||||
# 输出版本信息
|
||||
logger.info(f"==========================================")
|
||||
logger.info(f"Version: {VERSION}")
|
||||
@@ -211,11 +213,20 @@ with app.app_context():
|
||||
else:
|
||||
logger.info(f"pandora_api_prefix: {PROXY_API_PREFIX}")
|
||||
|
||||
if USE_OAIUSERCONTENT_URL == False:
|
||||
# 检测./images和./files文件夹是否存在,不存在则创建
|
||||
if not os.path.exists('./images'):
|
||||
os.makedirs('./images')
|
||||
if not os.path.exists('./files'):
|
||||
os.makedirs('./files')
|
||||
|
||||
|
||||
if not UPLOAD_BASE_URL:
|
||||
if USE_OAIUSERCONTENT_URL:
|
||||
logger.info("backend_container_url 未设置,将使用 oaiusercontent.com 作为图片域名")
|
||||
else:
|
||||
logger.warning("backend_container_url 未设置,图片生成功能将无法正常使用")
|
||||
|
||||
|
||||
else:
|
||||
logger.info(f"backend_container_url: {UPLOAD_BASE_URL}")
|
||||
@@ -752,6 +763,81 @@ def replace_complete_citation(text, citations):
|
||||
|
||||
return replaced_text, remaining_text, is_potential_citation
|
||||
|
||||
def is_valid_sandbox_combined_corrected_final_v2(text):
    """Return True while *text* may still grow into a ``(sandbox:/...)`` link.

    The streaming loop buffers output once it sees "(" and calls this on the
    buffer after every chunk.  The buffer is considered a legal in-progress
    state when it ends with one of:

    * "(" followed by any prefix of ``sandbox:/`` (including the empty
      prefix, i.e. a bare "("), or
    * ``(sandbox:/`` followed by a path without ")" and an optional
      closing ")".

    Anything may precede the "(" -- including newlines, which is why the
    match is done with ``re.DOTALL``.

    Args:
        text: The accumulated output buffer.

    Returns:
        bool: True if the buffer is a possible (or complete) sandbox link,
        False if it can no longer become one.
    """
    keyword = "sandbox:/"
    # Every proper prefix of the keyword ("sandbox:", "sandbox", ..., "s", "")
    # is accepted so that the keyword may be split across chunks at any
    # position, not only right after "(" or after "sandbox".
    partial_keyword = "|".join(
        re.escape(keyword[:end]) for end in range(len(keyword) - 1, -1, -1)
    )
    pattern = r'.*\((?:sandbox:/[^)]*\)?|' + partial_keyword + r')'

    # DOTALL: without it "." stops at "\n" and any buffer that contains a
    # line break before the "(" would be rejected.
    return re.fullmatch(pattern, text, flags=re.DOTALL) is not None
|
||||
|
||||
|
||||
|
||||
def is_complete_sandbox_format(text):
    """Return True if *text* ends with a complete ``(sandbox:/<path>)`` link.

    A complete link is "(sandbox:/", at least one character that is not ")",
    and a closing ")" as the very last character of *text*.  Any text --
    including newlines -- may precede the link.

    Args:
        text: The accumulated output buffer.

    Returns:
        bool: True when the buffer ends with a closed sandbox link.
    """
    pattern = r'.*\(sandbox\:\/[^)]+\)'
    # DOTALL so that a line break before the link does not defeat the
    # leading ".*" (fullmatch must consume the whole buffer).
    return bool(re.fullmatch(pattern, text, flags=re.DOTALL))
|
||||
|
||||
import urllib.parse
|
||||
|
||||
def replace_sandbox(text, conversation_id, message_id, api_key):
    """Replace every ``(sandbox:<path>)`` link in *text* with a usable link.

    For each match the upstream "interpreter/download" endpoint is asked for
    a download URL.  When ``USE_OAIUSERCONTENT_URL`` is False the file is
    downloaded into ``./files`` and the link is rewritten to
    ``{UPLOAD_BASE_URL}/files/<timestamped name>``; otherwise the upstream
    download URL is used directly.  If the lookup or the download fails, the
    error is logged and the original ``(sandbox:...)`` text is left in place.

    Args:
        text: Text that may contain ``(sandbox:<path>)`` links.
        conversation_id: Upstream conversation id used for the lookup.
        message_id: Upstream message id used for the lookup.
        api_key: Bearer token sent to the upstream API.

    Returns:
        str: *text* with every resolvable sandbox link replaced.
    """

    def replace_match(match):
        # Build the replacement for a single "(sandbox:<path>)" occurrence.
        sandbox_path = match.group(1)
        download_url = get_download_url(conversation_id, message_id, sandbox_path)
        if not download_url:
            # Lookup failed (already logged): keep the original link instead
            # of emitting "(None)" or crashing further down.
            return match.group(0)

        # Compared with "== False" on purpose: the flag is defined outside
        # this block and only an explicit False selects local storage.
        if USE_OAIUSERCONTENT_URL == False:
            # Fall back to the sandbox file name when the URL carries none.
            file_name = extract_filename(download_url) or os.path.basename(sandbox_path)
            timestamped_file_name = timestamp_filename(file_name)
            try:
                download_file(download_url, timestamped_file_name)
            except (requests.RequestException, OSError) as e:
                logger.error(f"下载文件失败: {e}")
                return match.group(0)
            # Percent-encode so spaces / non-ASCII names yield a valid link.
            return f"({UPLOAD_BASE_URL}/files/{urllib.parse.quote(timestamped_file_name)})"
        return f"({download_url})"

    def get_download_url(conversation_id, message_id, sandbox_path):
        # Ask the upstream API for the download URL of a sandbox file.
        # Escape characters that would break the query string ("&", "#",
        # spaces, ...); "/" and "%" are left alone so ordinary and already
        # percent-encoded paths are sent exactly as before.
        encoded_path = urllib.parse.quote(sandbox_path, safe="/%")
        sandbox_info_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/conversation/{conversation_id}/interpreter/download?message_id={message_id}&sandbox_path={encoded_path}"

        headers = {
            "Authorization": f"Bearer {api_key}"
        }

        try:
            response = requests.get(sandbox_info_url, headers=headers)
        except requests.RequestException as e:
            logger.error(f"获取下载 URL 失败: {e}")
            return None

        if response.status_code == 200:
            try:
                return response.json().get("download_url")
            except ValueError:
                # Body was not valid JSON; fall through to the error path.
                pass
        logger.error(f"获取下载 URL 失败: {response.text}")
        return None

    def extract_filename(url):
        # Pull the file name out of the URL's "rscd" query parameter
        # (everything after "filename=").
        parsed_url = urllib.parse.urlparse(url)
        query_params = urllib.parse.parse_qs(parsed_url.query)
        filename = query_params.get("rscd", [""])[0].split("filename=")[-1]
        # The value comes from a remote URL: drop surrounding quotes and any
        # directory components so it can only name a file inside ./files.
        return os.path.basename(filename.strip().strip('"'))

    def timestamp_filename(filename):
        # Prefix the file name with the current local time (YYYYmmddHHMMSS).
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        return f"{timestamp}_{filename}"

    def download_file(download_url, filename):
        # Stream the file at download_url into ./files/<filename>.
        os.makedirs("./files", exist_ok=True)
        file_path = f"./files/{filename}"
        with requests.get(download_url, stream=True) as r:
            # Do not save an HTTP error page as if it were the file.
            r.raise_for_status()
            with open(file_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

    # Replace every "(sandbox:xxx)" occurrence in the text.
    replaced_text = re.sub(r'\(sandbox:([^)]+)\)', replace_match, text)
    return replaced_text
|
||||
|
||||
def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model):
|
||||
all_new_text = ""
|
||||
|
||||
@@ -768,6 +854,8 @@ def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_
|
||||
conversation_id = ''
|
||||
citation_buffer = ""
|
||||
citation_accumulating = False
|
||||
file_output_buffer = ""
|
||||
file_output_accumulating = False
|
||||
for chunk in upstream_response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
buffer += chunk.decode('utf-8')
|
||||
@@ -795,6 +883,7 @@ def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_
|
||||
if message == {} or message == None:
|
||||
logger.debug(f"message 为空: data_json: {data_json}")
|
||||
|
||||
message_id = message.get("id")
|
||||
message_status = message.get("status")
|
||||
content = message.get("content", {})
|
||||
role = message.get("author", {}).get("role")
|
||||
@@ -962,6 +1051,34 @@ def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_
|
||||
citation_accumulating = False
|
||||
citation_buffer = ""
|
||||
|
||||
if "(" in new_text and not file_output_accumulating and not citation_accumulating:
|
||||
file_output_accumulating = True
|
||||
file_output_buffer = file_output_buffer + new_text
|
||||
logger.debug(f"开始积累文件输出: {file_output_buffer}")
|
||||
elif file_output_accumulating:
|
||||
file_output_buffer += new_text
|
||||
logger.debug(f"积累文件输出: {file_output_buffer}")
|
||||
if file_output_accumulating:
|
||||
if is_valid_sandbox_combined_corrected_final_v2(file_output_buffer):
|
||||
logger.debug(f"合法文件输出格式: {file_output_buffer}")
|
||||
# 继续积累
|
||||
if is_complete_sandbox_format(file_output_buffer):
|
||||
# 替换完整的引用格式
|
||||
replaced_text = replace_sandbox(file_output_buffer, conversation_id, message_id, api_key)
|
||||
# print(replaced_text) # 输出替换后的文本
|
||||
new_text = replaced_text
|
||||
file_output_accumulating = False
|
||||
file_output_buffer = ""
|
||||
logger.debug(f"替换完整的文件输出格式: {new_text}")
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
# 不是合法格式,放弃积累并响应
|
||||
logger.debug(f"不合法格式: {file_output_buffer}")
|
||||
new_text = file_output_buffer
|
||||
file_output_accumulating = False
|
||||
file_output_buffer = ""
|
||||
|
||||
|
||||
# Python 工具执行输出特殊处理
|
||||
if role == "tool" and name == "python" and last_content_type != "execution_output" and content_type != None:
|
||||
@@ -1789,6 +1906,14 @@ def get_image(filename):
|
||||
return "文件不存在哦!", 404
|
||||
return send_from_directory('images', filename)
|
||||
|
||||
@app.route('/files/<filename>')
@cross_origin()  # decorator that allows cross-origin requests
def get_file(filename):
    """Serve a previously stored file from the ``files`` directory.

    Responds with the file when ``files/<filename>`` exists as a regular
    file, otherwise with a 404 and a short "file not found" message.
    """
    stored_path = os.path.join('files', filename)
    if os.path.isfile(stored_path):
        return send_from_directory('files', filename)
    # Nothing stored under that name.
    return "文件不存在哦!", 404
|
||||
|
||||
# 运行 Flask 应用
|
||||
if __name__ == '__main__':
|
||||
app.run(host='0.0.0.0')
|
||||
|
Reference in New Issue
Block a user