diff --git a/Dockerfile b/Dockerfile
index 9a98798..7a1bd2d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -13,7 +13,7 @@ ENV PYTHONUNBUFFERED=1
 RUN chmod +x /app/start.sh
 
 # 安装任何所需的依赖项
-RUN pip install --no-cache-dir flask gunicorn requests Pillow flask-cors
+RUN pip install --no-cache-dir flask gunicorn requests Pillow flask-cors tiktoken
 
 # 在容器启动时运行 Flask 应用
 CMD ["/app/start.sh"]
diff --git a/main.py b/main.py
index 5be2a9d..7f88e26 100644
--- a/main.py
+++ b/main.py
@@ -173,9 +173,9 @@ CORS(app, resources={r"/images/*": {"origins": "*"}})
 
 PANDORA_UPLOAD_URL = 'files.pandoranext.com'
 
-VERSION = '0.3.2'
+VERSION = '0.3.3'
 # VERSION = 'test'
-UPDATE_INFO = '支持bot模式非markdown图片输出'
+UPDATE_INFO = '支持基于tiktoken的token字符统计'
 # UPDATE_INFO = '【仅供临时测试使用】 '
 
 with app.app_context():
@@ -1127,6 +1127,43 @@ def keep_alive(last_data_time, stop_event, queue, model, chat_message_id):
             last_data_time[0] = time.time()
         time.sleep(1)
 
+
+import tiktoken
+
+def count_tokens(text, model_name):
+    """
+    Count the number of tokens in a text using a model's tiktoken encoding.
+
+    :param text: The text to be tokenized. None counts as 0 tokens; any other
+                 non-string value is converted with str() first.
+    :param model_name: The model whose encoding is used. Every name other than
+                       'gpt-3.5-turbo' is counted with the 'gpt-4' encoding.
+    :return: Number of tokens in the text for the selected encoding.
+    """
+    if text is None:
+        return 0
+    if not isinstance(text, str):
+        text = str(text)
+    # Only two encodings are distinguished; unknown names fall back to gpt-4.
+    if model_name != 'gpt-3.5-turbo':
+        model_name = 'gpt-4'
+    encoder = tiktoken.encoding_for_model(model_name)
+
+    # disallowed_special=() encodes special-token text such as "<|endoftext|>"
+    # as ordinary text instead of raising ValueError on user input.
+    return len(encoder.encode(text, disallowed_special=()))
+
+def count_total_input_words(messages, model):
+    """
+    Count the total number of tokens in all messages' content.
+
+    :param messages: Iterable of message dicts. A missing or None "content"
+                     counts as 0 tokens.
+    :param model: Model name passed on to count_tokens().
+    :return: Sum of the token counts of every message's content.
+    """
+    return sum(count_tokens(message.get("content"), model) for message in messages)
+
 import threading
 import time
 # 定义 Flask 路由
@@ -1212,6 +1249,12 @@ def chat_completions():
         for _ in generate():
             pass
         # 构造响应的 JSON 结构
+        ori_model_name = ''
+        model_config = find_model_config(model)
+        if model_config:
+            ori_model_name = model_config.get('ori_name', model)
+        input_tokens = count_total_input_words(messages, ori_model_name)
+        comp_tokens = count_tokens(all_new_text, ori_model_name)
         response_json = {
             "id": generate_unique_id("chatcmpl"),
             "object": "chat.completion",
@@ -1229,9 +1272,9 @@ def chat_completions():
             ],
             "usage": {
                 # 这里的 token 计数需要根据实际情况计算
-                "prompt_tokens": 0,
-                "completion_tokens": 0,
-                "total_tokens": 0
+                "prompt_tokens": input_tokens,
+                "completion_tokens": comp_tokens,
+                "total_tokens": input_tokens + comp_tokens
             },
             "system_fingerprint": None
         }