From c3ccbcb7f63debf08fe6af2a0eb21dfeb1429146 Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Tue, 28 Mar 2023 00:36:26 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E8=BE=93=E5=85=A5=E8=B6=85=E9=95=BF?= =?UTF-8?q?=E6=8F=90=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/constants/model.ts | 3 +++ src/pages/api/chat/chatGpt.ts | 10 +++++----- src/pages/api/chat/delChatRecordByIndex.ts | 22 ++++++++++++++++++++-- src/pages/api/chat/gpt3.ts | 2 +- src/pages/chat/index.tsx | 20 ++++++++++++++++---- 5 files changed, 45 insertions(+), 12 deletions(-) diff --git a/src/constants/model.ts b/src/constants/model.ts index f7a2c8ca3..ef1a4800d 100644 --- a/src/constants/model.ts +++ b/src/constants/model.ts @@ -12,6 +12,7 @@ export type ModelConstantsData = { model: `${ChatModelNameEnum}`; trainName: string; // 空字符串代表不能训练 maxToken: number; + contextMaxToken: number; maxTemperature: number; trainedMaxToken: number; // 训练后最大多少tokens price: number; // 多少钱 / 1token,单位: 0.00001元 @@ -24,6 +25,7 @@ export const modelList: ModelConstantsData[] = [ model: ChatModelNameEnum.GPT35, trainName: '', maxToken: 4000, + contextMaxToken: 7500, trainedMaxToken: 2000, maxTemperature: 2, price: 3 @@ -34,6 +36,7 @@ export const modelList: ModelConstantsData[] = [ // model: ChatModelNameEnum.GPT3, // trainName: 'davinci', // maxToken: 4000, + // contextMaxToken: 7500, // trainedMaxToken: 2000, // maxTemperature: 2, // price: 30 diff --git a/src/pages/api/chat/chatGpt.ts b/src/pages/api/chat/chatGpt.ts index 871bc4736..8a877bbaf 100644 --- a/src/pages/api/chat/chatGpt.ts +++ b/src/pages/api/chat/chatGpt.ts @@ -44,6 +44,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization); const model: ModelSchema = chat.modelId; + const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); + if (!modelConstantsData) { + throw new Error('模型异常,请用 chatgpt 模型'); + } // 读取对话内容 const prompts = [...chat.content, prompt]; @@ -57,7 +61,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } // 控制在 tokens 数量,防止超出 - const filterPrompts = openaiChatFilter(prompts, 7500); + const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken); // 格式化文本内容成 chatgpt 格式 const map = { @@ -73,10 +77,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) ); // console.log(formatPrompts); // 计算温度 - const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); - if (!modelConstantsData) { - throw new Error('模型异常'); - } const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); // 获取 chatAPI diff --git a/src/pages/api/chat/delChatRecordByIndex.ts b/src/pages/api/chat/delChatRecordByIndex.ts index d0596f014..f3c5278da 100644 --- a/src/pages/api/chat/delChatRecordByIndex.ts +++ b/src/pages/api/chat/delChatRecordByIndex.ts @@ -9,13 +9,31 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) if (!chatId || !index) { throw new Error('缺少参数'); } - console.log(index); + await connectToDatabase(); + const chatRecord = await Chat.findById(chatId); + + if (!chatRecord) { + throw new Error('找不到对话'); + } + + // 重新计算 index,跳过已经被删除的内容 + let unDeleteIndex = +index; + let deletedIndex = 0; + for (deletedIndex = 0; deletedIndex < chatRecord.content.length; deletedIndex++) { + if (!chatRecord.content[deletedIndex].deleted) { + unDeleteIndex--; + if (unDeleteIndex < 0) { + break; + } + } + } + // 删除最一条数据库记录, 也就是预发送的那一条 await Chat.findByIdAndUpdate(chatId, { $set: { - [`content.${index}.deleted`]: true, + [`content.${deletedIndex}.deleted`]: true, updateTime: Date.now() } }); diff --git a/src/pages/api/chat/gpt3.ts b/src/pages/api/chat/gpt3.ts index 2bcbdc588..769687f48 100644 --- a/src/pages/api/chat/gpt3.ts +++ b/src/pages/api/chat/gpt3.ts @@ -62,7 +62,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // 计算温度 const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); if (!modelConstantsData) { - throw new Error('模型异常'); + throw new Error('模型异常,请用 chatgpt 模型'); } const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index 7a71b83cf..7f95376b3 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -37,6 +37,7 @@ import SlideBar from './components/SlideBar'; import Empty from './components/Empty'; import Icon from '@/components/Icon'; import { encode } from 'gpt-token-utils'; +import { modelList } from '@/constants/model'; const Markdown = dynamic(() => import('@/components/Markdown')); @@ -200,6 +201,18 @@ const Chat = ({ chatId }: { chatId: string }) => { return; } + // 长度校验 + const tokens = encode(val).length; + const model = modelList.find((item) => item.model === chatData.modelName); + + if (model && tokens >= model.maxToken) { + toast({ + title: '单次输入超出 4000 tokens', + status: 'warning' + }); + return; + } + const newChatList: ChatSiteItemType[] = [ ...chatData.history, { @@ -252,15 +265,14 @@ const Chat = ({ chatId }: { chatId: string }) => { } }, [ inputVal, - chatData?.modelId, - chatData.history, + chatData, isChatting, resetInputVal, scrollToBottom, + toast, gptChatPrompt, pushChatHistory, - chatId, - toast + chatId ]); // 删除一句话