适配调用team对话,提供查询ChatGPT-Account-ID的/getAccountID接口

This commit is contained in:
Yanyutin753
2024-04-04 22:01:45 +08:00
parent 6530dc6029
commit 3c9b6c12cc

99
main.py
View File

@@ -221,7 +221,6 @@ def oaiFreeGetAccessToken(getAccessTokenUrl, refresh_token):
'refresh_token': refresh_token,
}
response = requests.post(getAccessTokenUrl, data=data)
logging.info(response.text)
if not response.ok:
logger.error("Request 失败: " + response.text.strip())
return None
@@ -336,9 +335,9 @@ scheduler.start()
# PANDORA_UPLOAD_URL = 'files.pandoranext.com'
VERSION = '0.7.9.0'
VERSION = '0.7.9.1'
# VERSION = 'test'
UPDATE_INFO = '接入oaifree'
UPDATE_INFO = '适配调用team对话,提供查询ChatGPT-Account-ID的/getAccountID接口'
# UPDATE_INFO = '【仅供临时测试使用】 '
with app.app_context():
@@ -702,7 +701,7 @@ my_files_types = [
# 定义发送请求的函数
def send_text_prompt_and_get_response(messages, api_key, stream, model, proxy_api_prefix):
def send_text_prompt_and_get_response(messages, api_key, account_id, stream, model, proxy_api_prefix):
url = f"{BASE_URL}{proxy_api_prefix}/backend-api/conversation"
headers = {
"Authorization": f"Bearer {api_key}"
@@ -920,6 +919,11 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model, proxy_ap
payload["arkose_token"] = token
# 在headers中添加新字段
headers["Openai-Sentinel-Arkose-Token"] = token
# 用于调用ChatGPT Team次数
if account_id:
headers["ChatGPT-Account-ID"] = account_id
logger.debug(f"headers: {headers}")
logger.debug(f"payload: {payload}")
response = requests.post(url, headers=headers, json=payload, stream=True)
@@ -1139,7 +1143,8 @@ def replace_sandbox(text, conversation_id, message_id, api_key, proxy_api_prefix
return replaced_text
def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, proxy_api_prefix):
def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model,
proxy_api_prefix):
all_new_text = ""
first_output = True
@@ -1372,7 +1377,7 @@ def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_
if is_complete_sandbox_format(file_output_buffer):
# 替换完整的引用格式
replaced_text = replace_sandbox(file_output_buffer, conversation_id,
message_id, api_key,proxy_api_prefix)
message_id, api_key, proxy_api_prefix)
# print(replaced_text) # 输出替换后的文本
new_text = replaced_text
file_output_accumulating = False
@@ -1756,13 +1761,18 @@ def chat_completions():
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer '):
return jsonify({"error": "Authorization header is missing or invalid"}), 401
api_key = auth_header.split(' ')[1]
try:
api_key = auth_header.split(' ')[1].split(',')[0].strip()
account_id = auth_header.split(' ')[1].split(',')[1].strip()
logging.info(f"{api_key}:{account_id}")
except IndexError:
account_id = None
if not api_key.startswith("eyJhb"):
refresh_token = api_key
if api_key in refresh_dict:
logger.info(f"从缓存读取到api_key.........。")
api_key = refresh_dict.get(api_key)
else:
refresh_token = api_key
if REFRESH_TOACCESS_ENABLEOAI:
api_key = oaiGetAccessToken(api_key)
else:
@@ -1772,7 +1782,8 @@ def chat_completions():
add_to_dict(refresh_token, api_key)
logger.info(f"api_key: {api_key}")
upstream_response = send_text_prompt_and_get_response(messages, api_key, stream, model, proxy_api_prefix)
upstream_response = send_text_prompt_and_get_response(messages, api_key, account_id, stream, model,
proxy_api_prefix)
# 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text
all_new_text = ""
@@ -1791,7 +1802,8 @@ def chat_completions():
# 启动数据处理线程
fetcher_thread = threading.Thread(target=data_fetcher, args=(
upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model,proxy_api_prefix))
upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model,
proxy_api_prefix))
fetcher_thread.start()
# 启动保活线程
@@ -1848,9 +1860,9 @@ def chat_completions():
fetcher_thread.join()
keep_alive_thread.join()
if conversation_id:
# print(f"准备删除的会话id {conversation_id}")
delete_conversation(conversation_id, api_key,proxy_api_prefix)
# if conversation_id:
# # print(f"准备删除的会话id {conversation_id}")
# delete_conversation(conversation_id, api_key,proxy_api_prefix)
if not stream:
# 执行流式响应的生成函数来累积 all_new_text
@@ -1924,14 +1936,19 @@ def images_generations():
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer '):
return jsonify({"error": "Authorization header is missing or invalid"}), 401
api_key = auth_header.split(' ')[1]
try:
api_key = auth_header.split(' ')[1].split(',')[0].strip()
account_id = auth_header.split(' ')[1].split(',')[1].strip()
logging.info(f"{api_key}:{account_id}")
except IndexError:
account_id = None
if not api_key.startswith("eyJhb"):
refresh_token = api_key
if api_key in refresh_dict:
logger.info(f"从缓存读取到api_key.........")
api_key = refresh_dict.get(api_key)
else:
if REFRESH_TOACCESS_ENABLEOAI:
refresh_token = api_key
api_key = oaiGetAccessToken(api_key)
else:
api_key = oaiFreeGetAccessToken(REFRESH_TOACCESS_OAIFREE_REFRESHTOACCESS_URL, api_key)
@@ -1951,7 +1968,7 @@ def images_generations():
}
]
upstream_response = send_text_prompt_and_get_response(messages, api_key, False, model,proxy_api_prefix)
upstream_response = send_text_prompt_and_get_response(messages, api_key, account_id, False, model, proxy_api_prefix)
# 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text
all_new_text = ""
@@ -2367,6 +2384,56 @@ def get_file(filename):
return send_from_directory('files', filename)
@app.route(f'/{API_PREFIX}/getAccountID' if API_PREFIX else '/getAccountID', methods=['POST'])
@cross_origin() # 使用装饰器来允许跨域请求
def getAccountID():
logger.info(f"New Img Request")
proxy_api_prefix = getPROXY_API_PREFIX(lock)
if proxy_api_prefix == None:
return jsonify({"error": "PROXY_API_PREFIX is not accessible"}), 401
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer '):
return jsonify({"error": "Authorization header is missing or invalid"}), 401
try:
api_key = auth_header.split(' ')[1].split(',')[0].strip()
account_id = auth_header.split(' ')[1].split(',')[1].strip()
except IndexError:
account_id = None
if not api_key.startswith("eyJhb"):
refresh_token = api_key
if api_key in refresh_dict:
logger.info(f"从缓存读取到api_key.........")
api_key = refresh_dict.get(api_key)
else:
if REFRESH_TOACCESS_ENABLEOAI:
api_key = oaiGetAccessToken(api_key)
else:
api_key = oaiFreeGetAccessToken(REFRESH_TOACCESS_OAIFREE_REFRESHTOACCESS_URL, api_key)
if not api_key.startswith("eyJhb"):
return jsonify({"error": "refresh_token is wrong or refresh_token url is wrong!"}), 401
add_to_dict(refresh_token, api_key)
logger.info(f"api_key: {api_key}")
url = f"{BASE_URL}{proxy_api_prefix}/backend-api/accounts/check/v4-2023-04-27"
headers = {
"Authorization": "Bearer " + api_key
}
res = requests.get(url, headers=headers)
if res.status_code == 200:
data = res.json()
result = {"plus": set(), "team": set()}
for account_id, account_data in data["accounts"].items():
plan_type = account_data["account"]["plan_type"]
if plan_type == "team":
result[plan_type].add(account_id)
elif plan_type == "plus":
result[plan_type].add(account_id)
result = {plan_type: list(ids) for plan_type, ids in result.items()}
return jsonify(result)
else:
return jsonify({"error": "Request failed."}), 400
# 内置自动刷新access_token
def updateRefresh_dict():
success_num = 0