mirror of
https://github.com/Yanyutin753/RefreshToV1Api.git
synced 2025-10-15 15:41:21 +00:00
[feat] 支持基于tiktoken的token字符统计
This commit is contained in:
@@ -13,7 +13,7 @@ ENV PYTHONUNBUFFERED=1
|
||||
RUN chmod +x /app/start.sh
|
||||
|
||||
# 安装任何所需的依赖项
|
||||
RUN pip install --no-cache-dir flask gunicorn requests Pillow flask-cors
|
||||
RUN pip install --no-cache-dir flask gunicorn requests Pillow flask-cors tiktoken
|
||||
|
||||
# 在容器启动时运行 Flask 应用
|
||||
CMD ["/app/start.sh"]
|
||||
|
50
main.py
50
main.py
@@ -173,9 +173,9 @@ CORS(app, resources={r"/images/*": {"origins": "*"}})
|
||||
PANDORA_UPLOAD_URL = 'files.pandoranext.com'
|
||||
|
||||
|
||||
VERSION = '0.3.2'
|
||||
VERSION = '0.3.3'
|
||||
# VERSION = 'test'
|
||||
UPDATE_INFO = '支持bot模式非markdown图片输出'
|
||||
UPDATE_INFO = '支持基于tiktoken的token字符统计'
|
||||
# UPDATE_INFO = '【仅供临时测试使用】 '
|
||||
|
||||
with app.app_context():
|
||||
@@ -1127,6 +1127,40 @@ def keep_alive(last_data_time, stop_event, queue, model, chat_message_id):
|
||||
last_data_time[0] = time.time()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
import tiktoken
|
||||
|
||||
def count_tokens(text, model_name):
    """
    Count the number of tokens in a text using the tiktoken encoder of a model.

    :param text: The text to be tokenized. ``None`` or an empty value counts
        as zero tokens.
    :param model_name: The name of the model to use for tokenization. Only
        'gpt-3.5-turbo' is used as-is; every other value falls back to the
        'gpt-4' encoder.
    :return: Number of tokens in the text for the selected encoder.
    """
    # Nothing to tokenize: return early instead of passing None/"" to the encoder.
    if not text:
        return 0

    # Select the encoder: anything that is not exactly 'gpt-3.5-turbo' is
    # tokenized with the 'gpt-4' encoder.
    encoder_model = model_name if model_name == 'gpt-3.5-turbo' else 'gpt-4'
    encoder = tiktoken.encoding_for_model(encoder_model)

    # Encode the text and count the tokens. disallowed_special=() makes
    # literal special-token markers in the text (e.g. "<|endoftext|>") get
    # encoded as ordinary text; with the default setting encode() raises
    # ValueError on them, which would fail the whole request.
    token_list = encoder.encode(text, disallowed_special=())
    return len(token_list)
|
||||
|
||||
def count_total_input_words(messages, model):
    """
    Count the total number of tokens in the ``content`` of all messages.

    Despite the function name, the result is a token count (each content is
    measured with ``count_tokens``), not a word count.

    :param messages: Iterable of message dicts; ``None`` is treated as empty.
    :param model: Model name forwarded to ``count_tokens``.
    :return: Sum of the token counts of every message's content.
    """
    total_tokens = 0
    for message in messages or ():
        # .get("content", "") would still yield None when the key is present
        # with a null value, so missing/None/empty are all handled below.
        content = message.get("content")

        if isinstance(content, list):
            # NOTE(review): assumes OpenAI-style multi-part content, i.e. a
            # list of dicts such as {"type": "text", "text": "..."} — confirm
            # against the callers. Only the text parts are counted.
            content = "".join(
                part.get("text") or ""
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )

        # Skip messages without any text to tokenize.
        if not content:
            continue

        total_tokens += count_tokens(content, model)

    return total_tokens
|
||||
|
||||
import threading
|
||||
import time
|
||||
# 定义 Flask 路由
|
||||
@@ -1212,6 +1246,12 @@ def chat_completions():
|
||||
for _ in generate():
|
||||
pass
|
||||
# 构造响应的 JSON 结构
|
||||
ori_model_name = ''
|
||||
model_config = find_model_config(model)
|
||||
if model_config:
|
||||
ori_model_name = model_config.get('ori_name', model)
|
||||
input_tokens = count_total_input_words(messages, ori_model_name)
|
||||
comp_tokens = count_tokens(all_new_text, ori_model_name)
|
||||
response_json = {
|
||||
"id": generate_unique_id("chatcmpl"),
|
||||
"object": "chat.completion",
|
||||
@@ -1229,9 +1269,9 @@ def chat_completions():
|
||||
],
|
||||
"usage": {
|
||||
# 这里的 token 计数需要根据实际情况计算
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
"prompt_tokens": input_tokens,
|
||||
"completion_tokens": comp_tokens,
|
||||
"total_tokens": input_tokens + comp_tokens
|
||||
},
|
||||
"system_fingerprint": None
|
||||
}
|
||||
|
Reference in New Issue
Block a user