[feat] 重磅更新:支持大部分GPTS模型

This commit is contained in:
Wizerd
2023-12-15 22:40:16 +08:00
parent 5da04bc35d
commit 5b4e2a31f3
4 changed files with 153 additions and 108 deletions

210
main.py
View File

@@ -17,6 +17,85 @@ def generate_unique_id(prefix):
return unique_id
def get_accessible_model_list():
return [config['name'] for config in gpts_configurations]
def find_model_config(model_name):
for config in gpts_configurations:
if config['name'] == model_name:
return config
return None
# 从 gpts.json 读取配置
def load_gpts_config(file_path):
with open(file_path, 'r', encoding='utf-8') as file:
return json.load(file)
# 根据 ID 发送请求并获取配置信息
def fetch_gizmo_info(base_url, proxy_api_prefix, model_id):
url = f"{base_url}/{proxy_api_prefix}/backend-api/gizmos/{model_id}"
headers = {
"Authorization": f"Bearer {KEY_FOR_GPTS_INFO}"
}
response = requests.get(url, headers=headers)
# print(f"fetch_gizmo_info_response: {response.text}")
if response.status_code == 200:
return response.json()
else:
return None
gpts_configurations = [
{
"name":"gpt-4-s"
},
{
"name":"gpt-4-mobile"
}
]
# 将配置添加到全局列表
def add_config_to_global_list(base_url, proxy_api_prefix, gpts_data):
global gpts_configurations
# print(f"gpts_data: {gpts_data}")
for model_name, model_info in gpts_data.items():
# print(f"model_name: {model_name}")
# print(f"model_info: {model_info}")
model_id = model_info['id']
gizmo_info = fetch_gizmo_info(base_url, proxy_api_prefix, model_id)
if gizmo_info:
gpts_configurations.append({
'name': model_name,
'id': model_id,
'config': gizmo_info
})
def generate_gpts_payload(model):
model_config = find_model_config(model)
if model_config:
gizmo_info = model_config['config']
gizmo_id = gizmo_info['gizmo']['id']
payload = {
"action": "next",
"messages": [],
"parent_message_id": str(uuid.uuid4()),
"model": "gpt-4-gizmo",
"timezone_offset_min": -480,
"history_and_training_disabled": False,
"conversation_mode": {
"gizmo": gizmo_info,
"kind": "gizmo_interaction",
"gizmo_id": gizmo_id
},
"force_paragen": False,
"force_rate_limit": False
}
return payload
else:
return None
# 创建 Flask 应用
app = Flask(__name__)
@@ -25,9 +104,10 @@ app = Flask(__name__)
BASE_URL = os.getenv('BASE_URL', '')
PROXY_API_PREFIX = os.getenv('PROXY_API_PREFIX', '')
UPLOAD_BASE_URL = os.getenv('UPLOAD_BASE_URL', '')
KEY_FOR_GPTS_INFO = os.getenv('KEY_FOR_GPTS_INFO', '')
VERSION = '0.0.11'
UPDATE_INFO = '修复一些偶现的bug'
VERSION = '0.1.0'
UPDATE_INFO = '适配大部分GPTS模型'
with app.app_context():
# 输出版本信息
@@ -47,6 +127,21 @@ with app.app_context():
print(f"==========================================")
print(f"GPTS 配置信息")
# 加载配置并添加到全局列表
gpts_data = load_gpts_config("./gpts.json")
add_config_to_global_list(BASE_URL, PROXY_API_PREFIX, gpts_data)
# print("当前可用GPTS" + get_accessible_model_list())
# 输出当前可用 GPTS name
print(f"当前可用 GPTS 列表: {get_accessible_model_list()}")
print(f"==========================================")
# print(f"GPTs Payload 生成测试")
# print(f"gpt-4-classic: {generate_gpts_payload('gpt-4-classic')}")
# 定义获取 token 的函数
def get_token():
@@ -59,7 +154,8 @@ def get_token():
return None
import os
accessable_model_list = ['gpt-4-classic', 'gpt-4-s', 'gpt-4-mobile']
# 定义发送请求的函数
def send_text_prompt_and_get_response(messages, api_key, stream, model):
@@ -82,103 +178,7 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model):
payload = {}
print(f"model: {model}")
if model == 'gpt-4-classic':
payload = {
# 构建 payload
"action": "next",
"messages": formatted_messages,
"parent_message_id": str(uuid.uuid4()),
"model": "gpt-4-gizmo",
"timezone_offset_min": -480,
"history_and_training_disabled": False,
"conversation_mode": {
"gizmo": {
"gizmo": {
"id": "g-YyyyMT9XH",
"organization_id": "org-OROoM5KiDq6bcfid37dQx4z4",
"short_url": "g-YyyyMT9XH-chatgpt-classic",
"author": {
"user_id": "user-u7SVk5APwT622QC7DPe41GHJ",
"display_name": "ChatGPT",
"link_to":None,
"selected_display": "name",
"is_verified":True
},
"voice": {
"id": "ember"
},
"workspace_id":None,
"model":None,
"instructions":None,
"settings":None,
"display": {
"name": "ChatGPT Classic",
"description": "The latest version of GPT-4 with no additional capabilities",
"welcome_message": "Hello",
"prompt_starters":None,
"profile_picture_url": "",
"categories": []
},
"share_recipient": "marketplace",
"updated_at": "2023-11-26T17:46:07.341305+00:00",
"last_interacted_at": "2023-12-11T09:49:34.943245+00:00",
"tags": [
"public",
"first_party"
],
"version":None,
"live_version":None,
"training_disabled":None,
"allowed_sharing_recipients":None,
"review_info":None,
"appeal_info":None,
"vanity_metrics":None
},
"tools": [],
"files": [],
"product_features": {
"attachments": {
"type": "retrieval",
"accepted_mime_types": [
"text/x-script.python",
"application/x-latext",
"text/x-c++",
"text/javascript",
"text/x-java",
"text/x-typescript",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
"text/x-csharp",
"text/plain",
"application/pdf",
"text/x-sh",
"text/markdown",
"text/x-c",
"text/x-ruby",
"text/x-tex",
"text/x-php",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/json",
"text/html",
"application/msword"
],
"image_mime_types": [
"image/webp",
"image/jpeg",
"image/png",
"image/gif"
],
"can_accept_all_mime_types":True
}
}
},
"kind": "gizmo_interaction",
"gizmo_id": "g-YyyyMT9XH"
},
"force_paragen":False,
"force_rate_limit":False
}
elif model == 'gpt-4-s':
if model == 'gpt-4-s':
payload = {
# 构建 payload
"action": "next",
@@ -202,6 +202,10 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model):
"history_and_training_disabled": False,
"conversation_mode":{"kind":"primary_assistant"},"force_paragen":False,"force_rate_limit":False
}
else:
payload = generate_gpts_payload(model)
if not payload:
raise Exception('model is not accessible')
response = requests.post(url, headers=headers, json=payload, stream=True)
# print(response)
return response
@@ -325,8 +329,10 @@ def chat_completions():
data = request.json
messages = data.get('messages')
model = data.get('model')
if model not in accessable_model_list:
return jsonify({"error": "model is not accessable"}), 401
accessible_model_list = get_accessible_model_list()
if model not in accessible_model_list:
return jsonify({"error": "model is not accessible"}), 401
stream = data.get('stream', False)
auth_header = request.headers.get('Authorization')