Mirror of https://github.com/Yanyutin753/RefreshToV1Api.git (synced 2025-12-13 02:00:14 +08:00)
[feat] Support tiktoken-based token counting
@@ -13,7 +13,7 @@ ENV PYTHONUNBUFFERED=1
 RUN chmod +x /app/start.sh
 
 # Install any required dependencies
-RUN pip install --no-cache-dir flask gunicorn requests Pillow flask-cors
+RUN pip install --no-cache-dir flask gunicorn requests Pillow flask-cors tiktoken
 
 # Run the Flask application when the container starts
 CMD ["/app/start.sh"]
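The only Dockerfile change is adding tiktoken to the pip install line, which the new token-counting helpers in main.py depend on. A minimal sanity check, not part of the commit (the image name is a placeholder, and tiktoken needs network access the first time it fetches an encoding):

import tiktoken

# Load the encoder used for gpt-3.5-turbo and count a sample string.
# Run inside the built image, e.g.: docker run --rm <image> python check_tiktoken.py
encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
print(len(encoder.encode("hello world")))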
main.py | 50
@@ -173,9 +173,9 @@ CORS(app, resources={r"/images/*": {"origins": "*"}})
 PANDORA_UPLOAD_URL = 'files.pandoranext.com'
 
 
-VERSION = '0.3.2'
+VERSION = '0.3.3'
 # VERSION = 'test'
-UPDATE_INFO = '支持bot模式非markdown图片输出'
+UPDATE_INFO = '支持基于tiktoken的token字符统计'
 # UPDATE_INFO = '【仅供临时测试使用】 '
 
 with app.app_context():
@@ -1127,6 +1127,40 @@ def keep_alive(last_data_time, stop_event, queue, model, chat_message_id):
         last_data_time[0] = time.time()
         time.sleep(1)
 
+
+import tiktoken
+
+def count_tokens(text, model_name):
+    """
+    Count the number of tokens for a given text using a specified model.
+
+    :param text: The text to be tokenized.
+    :param model_name: The name of the model to use for tokenization.
+    :return: Number of tokens in the text for the specified model.
+    """
+    # Get the encoder for the specified model
+    if model_name == 'gpt-3.5-turbo':
+        model_name = 'gpt-3.5-turbo'
+    else:
+        model_name = 'gpt-4'
+    encoder = tiktoken.encoding_for_model(model_name)
+
+    # Encode the text and count the number of tokens
+    token_list = encoder.encode(text)
+    return len(token_list)
+
+
+def count_total_input_words(messages, model):
+    """
+    Count the total number of words in all messages' content.
+    """
+    total_words = 0
+    for message in messages:
+        content = message.get("content", "")
+        # logger.info(f"message: {content}")
+        total_words += count_tokens(content, model)
+
+    return total_words
+
 import threading
 import time
 # Define the Flask routes
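The two helpers above carry the counting logic: count_tokens resolves an encoding via tiktoken.encoding_for_model (any model other than gpt-3.5-turbo falls back to the gpt-4 encoding), and count_total_input_words sums the counts over every message's content field. A self-contained sketch of what they compute, using tiktoken directly on made-up messages (illustrative, not code from the repo):

import tiktoken

# Made-up request messages, mirroring the shape the chat_completions route receives.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Count my tokens, please."},
]

encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")

# Equivalent of count_total_input_words(messages, "gpt-3.5-turbo"):
prompt_tokens = sum(len(encoder.encode(m.get("content", ""))) for m in messages)
print(prompt_tokens)

Note that this counts only the message bodies; the extra per-message tokens that OpenAI's chat format adds for roles and separators are not included, so the reported numbers are an approximation.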
@@ -1212,6 +1246,12 @@ def chat_completions():
         for _ in generate():
             pass
         # Build the response JSON structure
+        ori_model_name = ''
+        model_config = find_model_config(model)
+        if model_config:
+            ori_model_name = model_config.get('ori_name', model)
+        input_tokens = count_total_input_words(messages, ori_model_name)
+        comp_tokens = count_tokens(all_new_text, ori_model_name)
         response_json = {
             "id": generate_unique_id("chatcmpl"),
             "object": "chat.completion",
@@ -1229,9 +1269,9 @@ def chat_completions():
             ],
             "usage": {
                 # The token counts here need to be calculated from the actual content
-                "prompt_tokens": 0,
-                "completion_tokens": 0,
-                "total_tokens": 0
+                "prompt_tokens": input_tokens,
+                "completion_tokens": comp_tokens,
+                "total_tokens": input_tokens + comp_tokens
             },
             "system_fingerprint": None
         }
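With input_tokens and comp_tokens in scope, the previously hard-coded zeros in the usage block become real values. A stand-alone sketch of the resulting structure, with placeholder strings standing in for the request's messages and the streamed all_new_text:

import tiktoken

# Non-gpt-3.5-turbo models resolve to the gpt-4 encoding in count_tokens.
enc = tiktoken.encoding_for_model("gpt-4")

input_tokens = len(enc.encode("Hello"))                     # stand-in for count_total_input_words(messages, ...)
comp_tokens = len(enc.encode("Hi, how can I help today?"))  # stand-in for count_tokens(all_new_text, ...)

usage = {
    "prompt_tokens": input_tokens,
    "completion_tokens": comp_tokens,
    "total_tokens": input_tokens + comp_tokens,
}
print(usage)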