feat: chat content use tiktoken count

archer committed 2023-04-24 16:46:39 +08:00
parent adbaa8b37b
commit 1f112f7715
23 changed files with 182 additions and 836 deletions

View File

@@ -5,7 +5,7 @@ import { getOpenAIApi } from '@/service/utils/auth';
 import { httpsAgent } from './tools';
 import { User } from '../models/user';
 import { formatPrice } from '@/utils/user';
-import { ChatModelNameEnum } from '@/constants/model';
+import { embeddingModel } from '@/constants/model';
 import { pushGenerateVectorBill } from '../events/pushBill';
 /* get the user's OpenAI API info */
@@ -80,7 +80,7 @@ export const openaiCreateEmbedding = async ({
   const res = await chatAPI
     .createEmbedding(
       {
-        model: ChatModelNameEnum.VECTOR,
+        model: embeddingModel,
         input: text
       },
       {
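The two hunks above drop the ChatModelNameEnum.VECTOR enum member in favor of a plain embeddingModel constant. Its definition in '@/constants/model' is outside this diff; a plausible shape, assuming OpenAI's standard embedding model:

// Assumed definition, not shown in this commit's hunks:
export const embeddingModel = 'text-embedding-ada-002';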
@@ -134,11 +134,11 @@ export const gpt35StreamResponse = ({
       try {
         const json = JSON.parse(data);
         const content: string = json?.choices?.[0].delta.content || '';
-        // console.log('content:', content);
         if (!content || (responseContent === '' && content === '\n')) return;
         responseContent += content;
-        !stream.destroyed && stream.push(content.replace(/\n/g, '<br/>'));
+        if (!stream.destroyed && content) {
+          stream.push(content.replace(/\n/g, '<br/>'));
+        }
       } catch (error) {
         error;
       }
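For reference, the data string parsed above is one chunk of the OpenAI chat-completions SSE stream. An illustrative payload (assumed shape, not taken from this commit):

// One streamed delta as the parser above sees it:
const data = '{"choices":[{"delta":{"content":"Hi"},"index":0,"finish_reason":null}]}';
const json = JSON.parse(data);
const content: string = json?.choices?.[0].delta.content || ''; // "Hi"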

View File

@@ -2,10 +2,12 @@ import type { NextApiRequest } from 'next';
 import crypto from 'crypto';
 import jwt from 'jsonwebtoken';
 import { ChatItemType } from '@/types/chat';
-import { encode } from 'gpt-token-utils';
 import { OpenApi, User } from '../mongo';
 import { formatPrice } from '@/utils/user';
 import { ERROR_ENUM } from '../errorCode';
+import { countChatTokens } from '@/utils/tools';
+import { ChatCompletionRequestMessageRoleEnum } from 'openai';
+import { ChatModelEnum } from '@/constants/model';
 /* password hashing */
 export const hashPassword = (psw: string) => {
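The headline change of this commit is swapping gpt-token-utils' plain encode() for a chat-aware counter. countChatTokens lives in '@/utils/tools' and its body is not part of this diff; a minimal sketch of what it might look like, assuming the js-tiktoken package and OpenAI's published per-message token overhead:

import { getEncoding } from 'js-tiktoken';

// cl100k_base is the encoding used by gpt-3.5-turbo and gpt-4.
const enc = getEncoding('cl100k_base');

export const countChatTokens = ({
  model,
  messages
}: {
  model: string; // accepted for parity with the call sites; unused in this sketch
  messages: { role: string; content: string }[];
}) => {
  let total = 0;
  for (const item of messages) {
    // Each message carries a few framing tokens on top of its encoded
    // role and content (per the OpenAI cookbook's counting recipe).
    total += 3;
    total += enc.encode(item.role).length;
    total += enc.encode(item.content).length;
  }
  // Every reply is primed with <|start|>assistant<|message|>.
  return total + 3;
};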
@@ -86,8 +88,16 @@ export const authOpenApiKey = async (req: NextApiRequest) => {
 export const httpsAgent = (fast: boolean) =>
   fast ? global.httpsAgentFast : global.httpsAgentNormal;
-/* truncate by tokens */
-export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) => {
+/* truncate chat content by tokens */
+export const openaiChatFilter = ({
+  model,
+  prompts,
+  maxTokens
+}: {
+  model: `${ChatModelEnum}`;
+  prompts: ChatItemType[];
+  maxTokens: number;
+}) => {
   const formatPrompts = prompts.map((item) => ({
     obj: item.obj,
     value: item.value
@@ -97,41 +107,64 @@ export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) =>
       .trim()
   }));
-  let res: ChatItemType[] = [];
+  let chats: ChatItemType[] = [];
   let systemPrompt: ChatItemType | null = null;
   // keep the System prompt
   if (formatPrompts[0]?.obj === 'SYSTEM') {
     systemPrompt = formatPrompts.shift() as ChatItemType;
-    maxTokens -= encode(formatPrompts[0].value).length;
   }
-  // truncate from the end backwards
+  // format the content into chatgpt message shape
+  const map = {
+    Human: ChatCompletionRequestMessageRoleEnum.User,
+    AI: ChatCompletionRequestMessageRoleEnum.Assistant,
+    SYSTEM: ChatCompletionRequestMessageRoleEnum.System
+  };
+  let messages: { role: ChatCompletionRequestMessageRoleEnum; content: string }[] = [];
+  // truncate the conversation from the end backwards
   for (let i = formatPrompts.length - 1; i >= 0; i--) {
-    const tokens = encode(formatPrompts[i].value).length;
-    res.unshift(formatPrompts[i]);
+    chats.unshift(formatPrompts[i]);
+    messages = (systemPrompt ? [systemPrompt, ...chats] : chats).map((item) => ({
+      role: map[item.obj],
+      content: item.value
+    }));
+    const tokens = countChatTokens({
+      model,
+      messages
+    });
     /* total tokens exceed the limit */
     if (tokens >= maxTokens) {
       break;
     }
-    maxTokens -= tokens;
   }
-  return systemPrompt ? [systemPrompt, ...res] : res;
+  return messages;
 };
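A hedged usage sketch of the new call shape; the prompt values and the 3000-token budget here are made up for illustration:

const messages = openaiChatFilter({
  model: 'gpt-3.5-turbo',
  prompts: [
    { obj: 'SYSTEM', value: 'You are a helpful assistant.' },
    { obj: 'Human', value: 'Hello!' }
  ],
  maxTokens: 3000
});
// => [{ role: 'system', content: '...' }, { role: 'user', content: 'Hello!' }]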
 /* truncate system content */
-export const systemPromptFilter = (prompts: string[], maxTokens: number) => {
+export const systemPromptFilter = ({
+  model,
+  prompts,
+  maxTokens
+}: {
+  model: 'gpt-4' | 'gpt-4-32k' | 'gpt-3.5-turbo';
+  prompts: string[];
+  maxTokens: number;
+}) => {
   let splitText = '';
   // truncate from the front
   for (let i = 0; i < prompts.length; i++) {
-    const prompt = prompts[i];
+    const prompt = prompts[i].replace(/\n+/g, '\n');
     splitText += `${prompt}\n`;
-    const tokens = encode(splitText).length;
+    const tokens = countChatTokens({ model, messages: [{ role: 'system', content: splitText }] });
     if (tokens >= maxTokens) {
       break;
     }
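The rest of the function (presumably returning the accumulated splitText) falls outside the excerpt. A hedged call-site sketch; kbChunks is a hypothetical variable standing in for strings pulled from a knowledge-base search:

// kbChunks is hypothetical; the return value is assumed to be the
// truncated system text, which this diff does not show.
const systemText = systemPromptFilter({
  model: 'gpt-3.5-turbo',
  prompts: kbChunks,
  maxTokens: 1500
});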