[feat] 支持文件生成

This commit is contained in:
Wizerd
2024-01-02 11:28:59 +08:00
parent 95348eb662
commit cc89d65136
2 changed files with 128 additions and 2 deletions

View File

@@ -10,4 +10,5 @@ services:
- ./log:/app/log
- ./images:/app/images
- ./data:/app/data
- ./files:/app/files

129
main.py
View File

@@ -175,14 +175,16 @@ CORS(app, resources={r"/images/*": {"origins": "*"}})
PANDORA_UPLOAD_URL = 'files.pandoranext.com'
VERSION = '0.3.9'
VERSION = '0.4.0'
# VERSION = 'test'
UPDATE_INFO = '支持图片不落地并支持base64绘图输出'
UPDATE_INFO = '支持文件生成'
# UPDATE_INFO = '【仅供临时测试使用】 '
with app.app_context():
global gpts_configurations # 移到作用域的最开始
# 输出版本信息
logger.info(f"==========================================")
logger.info(f"Version: {VERSION}")
@@ -211,11 +213,20 @@ with app.app_context():
else:
logger.info(f"pandora_api_prefix: {PROXY_API_PREFIX}")
if USE_OAIUSERCONTENT_URL == False:
# 检测./images和./files文件夹是否存在不存在则创建
if not os.path.exists('./images'):
os.makedirs('./images')
if not os.path.exists('./files'):
os.makedirs('./files')
if not UPLOAD_BASE_URL:
if USE_OAIUSERCONTENT_URL:
logger.info("backend_container_url 未设置,将使用 oaiusercontent.com 作为图片域名")
else:
logger.warning("backend_container_url 未设置,图片生成功能将无法正常使用")
else:
logger.info(f"backend_container_url: {UPLOAD_BASE_URL}")
@@ -752,6 +763,81 @@ def replace_complete_citation(text, citations):
return replaced_text, remaining_text, is_potential_citation
# Shapes a streamed buffer may take while a "(sandbox:/...)" link is still
# arriving.  re.DOTALL lets the leading ".*" span newlines, so text that
# precedes the link on earlier lines does not disqualify the buffer.
_SANDBOX_PARTIAL_PATTERNS = tuple(
    re.compile(pattern, re.DOTALL)
    for pattern in (
        r'.*\(sandbox:\/[^)]*\)?',  # "(sandbox:/" plus a possibly unfinished path
        r'.*\(',                    # buffer ends right after an opening "("
        r'.*\(sandbox(:|$)',        # buffer ends with "(sandbox" or "(sandbox:"
    )
)


def is_valid_sandbox_combined_corrected_final_v2(text):
    """Return True if *text* could still grow into a ``(sandbox:/...)`` link.

    The whole buffer must match one of the accepted partial shapes: it ends
    with ``(``, with ``(sandbox`` / ``(sandbox:``, or with ``(sandbox:/``
    followed by path characters and an optional closing ``)``.

    Unlike a plain ``.*`` match, the prefix before the link may contain
    newlines, so a link that follows a paragraph break is still recognised.

    Args:
        text: The accumulated output buffer.

    Returns:
        bool: True when the buffer matches any accepted partial format.
    """
    return any(
        pattern.fullmatch(text) is not None
        for pattern in _SANDBOX_PARTIAL_PATTERNS
    )
def is_complete_sandbox_format(text):
    """Return True if *text* ends with a finished ``(sandbox:/<path>)`` link.

    The whole buffer must match: any prefix, then ``(sandbox:/``, at least one
    path character, and the closing ``)`` as the final character.  re.DOTALL
    lets the prefix contain newlines, so a link that appears after a line
    break is still detected as complete.

    Args:
        text: The accumulated output buffer.

    Returns:
        bool: True when the buffer ends with a complete sandbox link.
    """
    pattern = r'.*\(sandbox\:\/[^)]+\)'
    return re.fullmatch(pattern, text, flags=re.DOTALL) is not None
import urllib.parse
def replace_sandbox(text, conversation_id, message_id, api_key):
    """Rewrite every ``(sandbox:<path>)`` link in *text* to a download link.

    For each link, the upstream interpreter-download endpoint is asked for a
    download URL.  When ``USE_OAIUSERCONTENT_URL`` is ``False`` the file is
    saved under ``./files`` with a timestamp prefix and the link is rewritten
    to ``({UPLOAD_BASE_URL}/files/<name>)``; otherwise the link is rewritten
    to the upstream download URL.

    If the URL lookup or the download fails, the error is logged and that
    link is left exactly as it appeared in *text* instead of raising.

    Args:
        text: Text that may contain ``(sandbox:<path>)`` links.
        conversation_id: Conversation id sent to the download endpoint.
        message_id: Message id sent to the download endpoint.
        api_key: Bearer token used for the download-URL request.

    Returns:
        str: *text* with every resolvable sandbox link replaced.
    """

    def replace_match(match):
        sandbox_path = match.group(1)
        download_url = get_download_url(conversation_id, message_id, sandbox_path)
        if not download_url:
            # Lookup failed (already logged): keep the original link text.
            return match.group(0)
        if USE_OAIUSERCONTENT_URL == False:
            # Fall back to the sandbox path's basename when the URL carries
            # no usable filename.
            file_name = (
                extract_filename(download_url)
                or os.path.basename(sandbox_path)
                or "file"
            )
            timestamped_file_name = timestamp_filename(file_name)
            if not download_file(download_url, timestamped_file_name):
                return match.group(0)
            return f"({UPLOAD_BASE_URL}/files/{timestamped_file_name})"
        return f"({download_url})"

    def get_download_url(conversation_id, message_id, sandbox_path):
        # Ask the upstream API for the download URL; None signals failure.
        sandbox_info_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/conversation/{conversation_id}/interpreter/download?message_id={message_id}&sandbox_path={sandbox_path}"
        headers = {
            "Authorization": f"Bearer {api_key}"
        }
        try:
            response = requests.get(sandbox_info_url, headers=headers)
            if response.status_code == 200:
                return response.json().get("download_url")
            logger.error(f"获取下载 URL 失败: {response.text}")
        except (requests.RequestException, ValueError) as e:
            # Network errors and non-JSON bodies are treated as lookup failure.
            logger.error(f"获取下载 URL 失败: {e}")
        return None

    def extract_filename(url):
        # Take the text after "filename=" in the "rscd" query parameter.
        parsed_url = urllib.parse.urlparse(url)
        query_params = urllib.parse.parse_qs(parsed_url.query)
        filename = query_params.get("rscd", [""])[0].split("filename=")[-1]
        # Drop surrounding quotes and any directory component so the name
        # cannot point outside ./files.
        filename = filename.strip().strip('"\'').replace("\\", "/")
        return os.path.basename(filename)

    def timestamp_filename(filename):
        # Prefix the filename with the current local timestamp.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        return f"{timestamp}_{filename}"

    def download_file(download_url, filename):
        # Stream the file into ./files; return True on success.
        file_path = f"./files/{filename}"
        try:
            os.makedirs("./files", exist_ok=True)
            with requests.get(download_url, stream=True) as r:
                # Do not save an HTTP error body as if it were the file.
                r.raise_for_status()
                with open(file_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
        except (requests.RequestException, OSError) as e:
            logger.error(f"下载文件失败: {e}")
            return False
        return True

    # Replace every "(sandbox:xxx)" occurrence.
    replaced_text = re.sub(r'\(sandbox:([^)]+)\)', replace_match, text)
    return replaced_text
def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model):
all_new_text = ""
@@ -768,6 +854,8 @@ def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_
conversation_id = ''
citation_buffer = ""
citation_accumulating = False
file_output_buffer = ""
file_output_accumulating = False
for chunk in upstream_response.iter_content(chunk_size=1024):
if chunk:
buffer += chunk.decode('utf-8')
@@ -795,6 +883,7 @@ def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_
if message == {} or message == None:
logger.debug(f"message 为空: data_json: {data_json}")
message_id = message.get("id")
message_status = message.get("status")
content = message.get("content", {})
role = message.get("author", {}).get("role")
@@ -962,6 +1051,34 @@ def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_
citation_accumulating = False
citation_buffer = ""
if "(" in new_text and not file_output_accumulating and not citation_accumulating:
file_output_accumulating = True
file_output_buffer = file_output_buffer + new_text
logger.debug(f"开始积累文件输出: {file_output_buffer}")
elif file_output_accumulating:
file_output_buffer += new_text
logger.debug(f"积累文件输出: {file_output_buffer}")
if file_output_accumulating:
if is_valid_sandbox_combined_corrected_final_v2(file_output_buffer):
logger.debug(f"合法文件输出格式: {file_output_buffer}")
# 继续积累
if is_complete_sandbox_format(file_output_buffer):
# 替换完整的引用格式
replaced_text = replace_sandbox(file_output_buffer, conversation_id, message_id, api_key)
# print(replaced_text) # 输出替换后的文本
new_text = replaced_text
file_output_accumulating = False
file_output_buffer = ""
logger.debug(f"替换完整的文件输出格式: {new_text}")
else:
continue
else:
# 不是合法格式,放弃积累并响应
logger.debug(f"不合法格式: {file_output_buffer}")
new_text = file_output_buffer
file_output_accumulating = False
file_output_buffer = ""
# Python 工具执行输出特殊处理
if role == "tool" and name == "python" and last_content_type != "execution_output" and content_type != None:
@@ -1789,6 +1906,14 @@ def get_image(filename):
return "文件不存在哦!", 404
return send_from_directory('images', filename)
@app.route('/files/<filename>')
@cross_origin()  # allow cross-origin requests for this endpoint
def get_file(filename):
    """Serve *filename* from the ``files`` directory, or answer 404 if absent."""
    stored_path = os.path.join('files', filename)
    if os.path.isfile(stored_path):
        return send_from_directory('files', filename)
    # Nothing stored under that name.
    return "文件不存在哦!", 404
# Run the Flask application, listening on all network interfaces.
if __name__ == '__main__':
    app.run(host='0.0.0.0')