更新自动刷新key_for_gpts_info,填入refresh_token即可

This commit is contained in:
Clivia
2024-02-06 12:12:23 +08:00
committed by GitHub
parent 7bc6a6dc6c
commit 6364eebde3

173
main.py
View File

@@ -67,7 +67,6 @@ REFRESH_TOACCESS_ENABLEOAI = REFRESH_TOACCESS.get('enableOai', 'true').lower() =
REFRESH_TOACCESS_NINJA_REFRESHTOACCESS_URL = REFRESH_TOACCESS.get('ninja_refreshToAccess_Url', '')
STEAM_SLEEP_TIME = REFRESH_TOACCESS.get('steam_sleep_time', 0)
NEED_DELETE_CONVERSATION_AFTER_RESPONSE = CONFIG.get('need_delete_conversation_after_response',
'true').lower() == 'true'
@@ -162,13 +161,83 @@ def load_gpts_config(file_path):
return json.load(file)
# 官方refresh_token刷新access_token
def oaiGetAccessToken(refresh_token):
logger.info("将通过这个网址请求access_tokenhttps://auth0.openai.com/oauth/token")
url = "https://auth0.openai.com/oauth/token"
headers = {
"Content-Type": "application/json"
}
data = {
"redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback",
"grant_type": "refresh_token",
"client_id": "pdlLIX2Y72MIl2rhLhTE9VV9bN905kBh",
"refresh_token": refresh_token
}
try:
response = requests.post(url, headers=headers, json=data)
# 如果响应的状态码不是 200将引发 HTTPError 异常
response.raise_for_status()
# 拿到access_token
json_response = response.json()
access_token = json_response.get('access_token')
# 检查 access_token 是否有效
if not access_token or not access_token.startswith("eyJhb"):
logger.error("access_token 无效.")
return None
return access_token
except requests.HTTPError as http_err:
logger.error(f"HTTP error occurred: {http_err}")
except Exception as err:
logger.error(f"Other error occurred: {err}")
return None
# ninja获得access_token
def ninjaGetAccessToken(refresh_token, getAccessTokenUrl):
try:
logger.info("将通过这个网址请求access_token" + getAccessTokenUrl)
headers = {"Authorization": "Bearer " + refresh_token}
response = requests.post(getAccessTokenUrl, headers=headers)
if not response.ok:
logger.error("Request 失败: " + response.text.strip())
return None
access_token = None
try:
jsonResponse = response.json()
access_token = jsonResponse.get("access_token")
except json.JSONDecodeError:
logger.exception("Failed to decode JSON response.")
if response.status_code == 200 and access_token and access_token.startswith("eyJhb"):
return access_token
except Exception as e:
logger.exception("获取access token失败.")
return None
def updateGptsKey():
global KEY_FOR_GPTS_INFO
if not KEY_FOR_GPTS_INFO == '' and not KEY_FOR_GPTS_INFO.startswith("eyJhb"):
if REFRESH_TOACCESS_ENABLEOAI:
access_token = oaiGetAccessToken(KEY_FOR_GPTS_INFO)
else:
access_token = ninjaGetAccessToken(REFRESH_TOACCESS_NINJA_REFRESHTOACCESS_URL, KEY_FOR_GPTS_INFO)
if access_token.startswith("eyJhb"):
KEY_FOR_GPTS_INFO = access_token
logging.info("KEY_FOR_GPTS_INFO被更新:" + KEY_FOR_GPTS_INFO)
# 根据 ID 发送请求并获取配置信息
def fetch_gizmo_info(base_url, proxy_api_prefix, model_id):
url = f"{base_url}{proxy_api_prefix}/backend-api/gizmos/{model_id}"
updateGptsKey()
headers = {
"Authorization": f"Bearer {KEY_FOR_GPTS_INFO}"
}
response = requests.get(url, headers=headers)
# logger.debug(f"fetch_gizmo_info_response: {response.text}")
if response.status_code == 200:
@@ -202,12 +271,17 @@ def add_config_to_global_list(base_url, proxy_api_prefix, gpts_data):
if gizmo_info:
redis_client.set(model_id, str(gizmo_info))
logger.info(f"Cached gizmo info for {model_name}, {model_id}")
if gizmo_info:
gpts_configurations.append({
'name': model_name,
'id': model_id,
'config': gizmo_info
})
# 检查模型名称是否已经在列表中
if not any(d['name'] == model_name for d in gpts_configurations):
gpts_configurations.append({
'name': model_name,
'id': model_id,
'config': gizmo_info
})
else:
logger.info(f"Model already exists in the list, skipping...")
def generate_gpts_payload(model, messages):
@@ -242,7 +316,6 @@ CORS(app, resources={r"/images/*": {"origins": "*"}})
scheduler = APScheduler()
scheduler.init_app(app)
scheduler.start()
# PANDORA_UPLOAD_URL = 'files.pandoranext.com'
@@ -2150,68 +2223,11 @@ def count_total_input_words(messages, model):
return total_words
# 官方refresh_token刷新access_token
def oaiGetAccessToken(refresh_token):
logger.info("将通过这个网址请求access_tokenhttps://auth0.openai.com/oauth/token")
url = "https://auth0.openai.com/oauth/token"
headers = {
"Content-Type": "application/json"
}
data = {
"redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback",
"grant_type": "refresh_token",
"client_id": "pdlLIX2Y72MIl2rhLhTE9VV9bN905kBh",
"refresh_token": refresh_token
}
try:
response = requests.post(url, headers=headers, json=data)
# 如果响应的状态码不是 200将引发 HTTPError 异常
response.raise_for_status()
# 拿到access_token
json_response = response.json()
access_token = json_response.get('access_token')
# 检查 access_token 是否有效
if not access_token or not access_token.startswith("eyJhb"):
logger.error("access_token 无效.")
return None
return access_token
except requests.HTTPError as http_err:
logger.error(f"HTTP error occurred: {http_err}")
except Exception as err:
logger.error(f"Other error occurred: {err}")
return None
# ninja获得access_token
def ninjaGetAccessToken(refresh_token, getAccessTokenUrl):
try:
logger.info("将通过这个网址请求access_token" + getAccessTokenUrl)
headers = {"Authorization": "Bearer " + refresh_token}
response = requests.post(getAccessTokenUrl, headers=headers)
if not response.ok:
logger.error("Request 失败: " + response.text.strip())
return None
access_token = None
try:
jsonResponse = response.json()
access_token = jsonResponse.get("access_token")
except json.JSONDecodeError:
logger.exception("Failed to decode JSON response.")
if response.status_code == 200 and access_token and access_token.startswith("eyJhb"):
return access_token
except Exception as e:
logger.exception("获取access token失败.")
return None
# 添加缓存
def add_to_dict(key, value):
global refresh_dict
refresh_dict[key] = value
logger.info("添加access_token缓存成功.............")
import threading
@@ -2270,7 +2286,7 @@ def chat_completions():
# 启动数据处理线程
fetcher_thread = threading.Thread(target=data_fetcher, args=(
data_queue, stop_event, last_data_time, api_key, chat_message_id, model, "url", messages))
data_queue, stop_event, last_data_time, api_key, chat_message_id, model, "url", messages))
fetcher_thread.start()
# 启动保活线程
@@ -2445,7 +2461,7 @@ def images_generations():
# 启动数据处理线程
fetcher_thread = threading.Thread(target=data_fetcher, args=(
data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages))
data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages))
fetcher_thread.start()
# 启动保活线程
@@ -2578,11 +2594,11 @@ def get_file(filename):
return send_from_directory('files', filename)
# 内置自动刷新access_token
def updateRefresh_dict():
success_num = 0
error_num = 0
logger.info(f"==========================================")
logging.info("开始更新access_token.........")
for key in refresh_dict:
if REFRESH_TOACCESS_ENABLEOAI:
@@ -2596,8 +2612,25 @@ def updateRefresh_dict():
add_to_dict(refresh_token, access_token)
success_num += 1
logging.info("更新成功: " + str(success_num) + ", 失败: " + str(error_num))
logger.info(f"==========================================")
logging.info("开始更新KEY_FOR_GPTS_INFO.........")
updateGptsKey()
# 配置GPTS
logger.info(f"GPTS 配置信息.....................")
# 加载配置并添加到全局列表
gpts_data = load_gpts_config("./data/gpts.json")
add_config_to_global_list(BASE_URL, PROXY_API_PREFIX, gpts_data)
accessible_model_list = get_accessible_model_list()
logger.info(f"当前可用 GPTS 列表: {accessible_model_list}")
# 检查列表中是否有重复的模型名称
if len(accessible_model_list) != len(set(accessible_model_list)):
raise Exception("检测到重复的模型名称,请检查环境变量或配置文件。")
logger.info(f"==========================================")
# 在 APScheduler 中注册你的任务
# 每天3点自动刷新
scheduler.add_job(id='updateRefresh_run', func=updateRefresh_dict, trigger='cron', hour=3, minute=0)