From 6364eebde360216f837046467b95da5dac03dcaf Mon Sep 17 00:00:00 2001 From: Clivia <132346501+Yanyutin753@users.noreply.github.com> Date: Tue, 6 Feb 2024 12:12:23 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E8=87=AA=E5=8A=A8=E5=88=B7?= =?UTF-8?q?=E6=96=B0key=5Ffor=5Fgpts=5Finfo,=E5=A1=AB=E5=85=A5refresh=5Fto?= =?UTF-8?q?ken=E5=8D=B3=E5=8F=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 177 +++++++++++++++++++++++++++++++++----------------------- 1 file changed, 105 insertions(+), 72 deletions(-) diff --git a/main.py b/main.py index 72d163c..22e2e5f 100644 --- a/main.py +++ b/main.py @@ -67,7 +67,6 @@ REFRESH_TOACCESS_ENABLEOAI = REFRESH_TOACCESS.get('enableOai', 'true').lower() = REFRESH_TOACCESS_NINJA_REFRESHTOACCESS_URL = REFRESH_TOACCESS.get('ninja_refreshToAccess_Url', '') STEAM_SLEEP_TIME = REFRESH_TOACCESS.get('steam_sleep_time', 0) - NEED_DELETE_CONVERSATION_AFTER_RESPONSE = CONFIG.get('need_delete_conversation_after_response', 'true').lower() == 'true' @@ -162,13 +161,83 @@ def load_gpts_config(file_path): return json.load(file) +# 官方refresh_token刷新access_token +def oaiGetAccessToken(refresh_token): + logger.info("将通过这个网址请求access_token:https://auth0.openai.com/oauth/token") + url = "https://auth0.openai.com/oauth/token" + headers = { + "Content-Type": "application/json" + } + data = { + "redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", + "grant_type": "refresh_token", + "client_id": "pdlLIX2Y72MIl2rhLhTE9VV9bN905kBh", + "refresh_token": refresh_token + } + try: + response = requests.post(url, headers=headers, json=data) + # 如果响应的状态码不是 200,将引发 HTTPError 异常 + response.raise_for_status() + + # 拿到access_token + json_response = response.json() + access_token = json_response.get('access_token') + + # 检查 access_token 是否有效 + if not access_token or not access_token.startswith("eyJhb"): + logger.error("access_token 无效.") + return None + + return access_token + + except requests.HTTPError as http_err: + logger.error(f"HTTP error occurred: {http_err}") + except Exception as err: + logger.error(f"Other error occurred: {err}") + return None + + +# ninja获得access_token +def ninjaGetAccessToken(refresh_token, getAccessTokenUrl): + try: + logger.info("将通过这个网址请求access_token:" + getAccessTokenUrl) + headers = {"Authorization": "Bearer " + refresh_token} + response = requests.post(getAccessTokenUrl, headers=headers) + if not response.ok: + logger.error("Request 失败: " + response.text.strip()) + return None + access_token = None + try: + jsonResponse = response.json() + access_token = jsonResponse.get("access_token") + except json.JSONDecodeError: + logger.exception("Failed to decode JSON response.") + if response.status_code == 200 and access_token and access_token.startswith("eyJhb"): + return access_token + except Exception as e: + logger.exception("获取access token失败.") + return None + + +def updateGptsKey(): + global KEY_FOR_GPTS_INFO + if not KEY_FOR_GPTS_INFO == '' and not KEY_FOR_GPTS_INFO.startswith("eyJhb"): + if REFRESH_TOACCESS_ENABLEOAI: + access_token = oaiGetAccessToken(KEY_FOR_GPTS_INFO) + else: + access_token = ninjaGetAccessToken(REFRESH_TOACCESS_NINJA_REFRESHTOACCESS_URL, KEY_FOR_GPTS_INFO) + if access_token.startswith("eyJhb"): + KEY_FOR_GPTS_INFO = access_token + logging.info("KEY_FOR_GPTS_INFO被更新:" + KEY_FOR_GPTS_INFO) + + # 根据 ID 发送请求并获取配置信息 def fetch_gizmo_info(base_url, proxy_api_prefix, model_id): url = f"{base_url}{proxy_api_prefix}/backend-api/gizmos/{model_id}" + updateGptsKey() headers = { "Authorization": f"Bearer {KEY_FOR_GPTS_INFO}" } - response = requests.get(url, headers=headers) # logger.debug(f"fetch_gizmo_info_response: {response.text}") if response.status_code == 200: @@ -202,12 +271,17 @@ def add_config_to_global_list(base_url, proxy_api_prefix, gpts_data): if gizmo_info: redis_client.set(model_id, str(gizmo_info)) logger.info(f"Cached gizmo info for {model_name}, {model_id}") + if gizmo_info: - gpts_configurations.append({ - 'name': model_name, - 'id': model_id, - 'config': gizmo_info - }) + # 检查模型名称是否已经在列表中 + if not any(d['name'] == model_name for d in gpts_configurations): + gpts_configurations.append({ + 'name': model_name, + 'id': model_id, + 'config': gizmo_info + }) + else: + logger.info(f"Model already exists in the list, skipping...") def generate_gpts_payload(model, messages): @@ -242,7 +316,6 @@ CORS(app, resources={r"/images/*": {"origins": "*"}}) scheduler = APScheduler() scheduler.init_app(app) scheduler.start() - # PANDORA_UPLOAD_URL = 'files.pandoranext.com' @@ -338,11 +411,11 @@ with app.app_context(): logger.info(f"enabled_plugin_output: {BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT}") # ninjaToV1Api_refresh - + logger.info(f"REFRESH_TOACCESS_ENABLEOAI: {REFRESH_TOACCESS_ENABLEOAI}") logger.info(f"REFRESH_TOACCESS_NINJA_REFRESHTOACCESS_URL: {REFRESH_TOACCESS_NINJA_REFRESHTOACCESS_URL}") logger.info(f"STEAM_SLEEP_TIME: {STEAM_SLEEP_TIME}") - + if not BASE_URL: raise Exception('upstream_base_url is not set') else: @@ -2150,68 +2223,11 @@ def count_total_input_words(messages, model): return total_words -# 官方refresh_token刷新access_token -def oaiGetAccessToken(refresh_token): - logger.info("将通过这个网址请求access_token:https://auth0.openai.com/oauth/token") - url = "https://auth0.openai.com/oauth/token" - headers = { - "Content-Type": "application/json" - } - data = { - "redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", - "grant_type": "refresh_token", - "client_id": "pdlLIX2Y72MIl2rhLhTE9VV9bN905kBh", - "refresh_token": refresh_token - } - try: - response = requests.post(url, headers=headers, json=data) - # 如果响应的状态码不是 200,将引发 HTTPError 异常 - response.raise_for_status() - - # 拿到access_token - json_response = response.json() - access_token = json_response.get('access_token') - - # 检查 access_token 是否有效 - if not access_token or not access_token.startswith("eyJhb"): - logger.error("access_token 无效.") - return None - - return access_token - - except requests.HTTPError as http_err: - logger.error(f"HTTP error occurred: {http_err}") - except Exception as err: - logger.error(f"Other error occurred: {err}") - return None - - -# ninja获得access_token -def ninjaGetAccessToken(refresh_token, getAccessTokenUrl): - try: - logger.info("将通过这个网址请求access_token:" + getAccessTokenUrl) - headers = {"Authorization": "Bearer " + refresh_token} - response = requests.post(getAccessTokenUrl, headers=headers) - if not response.ok: - logger.error("Request 失败: " + response.text.strip()) - return None - access_token = None - try: - jsonResponse = response.json() - access_token = jsonResponse.get("access_token") - except json.JSONDecodeError: - logger.exception("Failed to decode JSON response.") - if response.status_code == 200 and access_token and access_token.startswith("eyJhb"): - return access_token - except Exception as e: - logger.exception("获取access token失败.") - return None - - # 添加缓存 def add_to_dict(key, value): global refresh_dict refresh_dict[key] = value + logger.info("添加access_token缓存成功.............") import threading @@ -2270,7 +2286,7 @@ def chat_completions(): # 启动数据处理线程 fetcher_thread = threading.Thread(target=data_fetcher, args=( - data_queue, stop_event, last_data_time, api_key, chat_message_id, model, "url", messages)) + data_queue, stop_event, last_data_time, api_key, chat_message_id, model, "url", messages)) fetcher_thread.start() # 启动保活线程 @@ -2445,7 +2461,7 @@ def images_generations(): # 启动数据处理线程 fetcher_thread = threading.Thread(target=data_fetcher, args=( - data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages)) + data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages)) fetcher_thread.start() # 启动保活线程 @@ -2578,11 +2594,11 @@ def get_file(filename): return send_from_directory('files', filename) - # 内置自动刷新access_token def updateRefresh_dict(): success_num = 0 error_num = 0 + logger.info(f"==========================================") logging.info("开始更新access_token.........") for key in refresh_dict: if REFRESH_TOACCESS_ENABLEOAI: @@ -2596,8 +2612,25 @@ def updateRefresh_dict(): add_to_dict(refresh_token, access_token) success_num += 1 logging.info("更新成功: " + str(success_num) + ", 失败: " + str(error_num)) + logger.info(f"==========================================") + logging.info("开始更新KEY_FOR_GPTS_INFO.........") + updateGptsKey() + # 配置GPTS + logger.info(f"GPTS 配置信息.....................") + # 加载配置并添加到全局列表 + gpts_data = load_gpts_config("./data/gpts.json") + add_config_to_global_list(BASE_URL, PROXY_API_PREFIX, gpts_data) + + accessible_model_list = get_accessible_model_list() + logger.info(f"当前可用 GPTS 列表: {accessible_model_list}") + + # 检查列表中是否有重复的模型名称 + if len(accessible_model_list) != len(set(accessible_model_list)): + raise Exception("检测到重复的模型名称,请检查环境变量或配置文件。") + + logger.info(f"==========================================") + -# 在 APScheduler 中注册你的任务 # 每天3点自动刷新 scheduler.add_job(id='updateRefresh_run', func=updateRefresh_dict, trigger='cron', hour=3, minute=0)