perf: skip token truncation when the text is short

This commit is contained in:
archer
2023-04-25 00:16:52 +08:00
parent 3294be5e7f
commit ce68791c3c

View File

@@ -6,7 +6,7 @@ import { OpenApi, User } from '../mongo';
import { formatPrice } from '@/utils/user';
import { ERROR_ENUM } from '../errorCode';
import { countChatTokens } from '@/utils/tools';
import { ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai';
import { ChatModelEnum } from '@/constants/model';
/* 密码加密 */
@@ -88,6 +88,13 @@ export const authOpenApiKey = async (req: NextApiRequest) => {
export const httpsAgent = (fast: boolean) =>
fast ? global.httpsAgentFast : global.httpsAgentNormal;
/**
 * Normalizes whitespace in a piece of text.
 *
 * - Runs of consecutive line breaks (LF or CRLF) collapse to the first break
 *   of the run, which removes blank lines.
 * - Runs of any other whitespace (spaces, tabs, ...) collapse to one space.
 * - Leading and trailing whitespace is removed.
 *
 * @param str - Raw text to normalize.
 * @returns The normalized text.
 */
const simplifyStr = (str: string) =>
  str
    // Collapse consecutive line breaks. Matching `\r?\n` (not just `\n`) makes
    // CRLF-separated blank lines collapse as well; keeping `$1` preserves the
    // line-ending style of the first break in the run.
    .replace(/(\r?\n)(?:\r?\n)+/g, '$1')
    // Collapse runs of whitespace other than CR/LF into a single space.
    .replace(/[^\S\r\n]+/g, ' ')
    .trim();
/* 聊天内容 tokens 截断 */
export const openaiChatFilter = ({
model,
@@ -98,40 +105,44 @@ export const openaiChatFilter = ({
prompts: ChatItemType[];
maxTokens: number;
}) => {
const formatPrompts = prompts.map((item) => ({
obj: item.obj,
value: item.value
// .replace(/[\u3000\u3001\uff01-\uff5e\u3002]/g, ' ') // 中文标点改空格
.replace(/\n+/g, '\n') // 连续空行
.replace(/[^\S\r\n]+/g, ' ') // 连续空白内容
.trim()
}));
let chats: ChatItemType[] = [];
let systemPrompt: ChatItemType | null = null;
// System 词保留
if (formatPrompts[0]?.obj === 'SYSTEM') {
systemPrompt = formatPrompts.shift() as ChatItemType;
}
// 格式化文本内容成 chatgpt 格式
// role map
const map = {
Human: ChatCompletionRequestMessageRoleEnum.User,
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
};
let rawTextLen = 0;
const formatPrompts = prompts.map((item) => {
const val = simplifyStr(item.value);
rawTextLen += val.length;
return {
role: map[item.obj],
content: val
};
});
// 长度太小时,不需要进行 token 截断
if (rawTextLen < maxTokens * 0.5) {
return formatPrompts;
}
// 根据 tokens 截断内容
const chats: ChatCompletionRequestMessage[] = [];
let systemPrompt: ChatCompletionRequestMessage | null = null;
// System 词保留
if (formatPrompts[0]?.role === 'system') {
systemPrompt = formatPrompts.shift() as ChatCompletionRequestMessage;
}
let messages: { role: ChatCompletionRequestMessageRoleEnum; content: string }[] = [];
// 从后往前截取对话内容
for (let i = formatPrompts.length - 1; i >= 0; i--) {
chats.unshift(formatPrompts[i]);
messages = (systemPrompt ? [systemPrompt, ...chats] : chats).map((item) => ({
role: map[item.obj],
content: item.value
}));
messages = systemPrompt ? [systemPrompt, ...chats] : chats;
const tokens = countChatTokens({
model,
@@ -147,7 +158,7 @@ export const openaiChatFilter = ({
return messages;
};
/* system 内容截断 */
/* system 内容截断. 相似度从高到低 */
export const systemPromptFilter = ({
model,
prompts,
@@ -161,7 +172,7 @@ export const systemPromptFilter = ({
// 从前往前截取
for (let i = 0; i < prompts.length; i++) {
const prompt = prompts[i].replace(/\n+/g, '\n');
const prompt = simplifyStr(prompts[i]);
splitText += `${prompt}\n`;
const tokens = countChatTokens({ model, messages: [{ role: 'system', content: splitText }] });
@@ -170,5 +181,5 @@ export const systemPromptFilter = ({
}
}
return splitText.slice(0, splitText.length - 1).replace(/\n+/g, '\n');
return splitText.slice(0, splitText.length - 1);
};