From 3ea2cf1dcbe4d656c6dd131c7954c7f86d3d2c64 Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Fri, 21 Apr 2023 23:30:26 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20chat=E4=B8=8A=E4=B8=8B=E6=96=87?= =?UTF-8?q?=E6=88=AA=E6=96=AD;QA=E6=8F=90=E7=A4=BA=E8=AF=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/constants/model.ts | 2 +- src/pages/api/chat/chatGpt.ts | 25 +++++++++++++------ src/pages/api/chat/vectorGpt.ts | 16 ++++++------ src/pages/api/model/data/splitData.ts | 4 +-- src/pages/api/openapi/chat/chatGpt.ts | 13 +++++++--- src/pages/api/openapi/chat/vectorGpt.ts | 22 +++++++++++----- src/pages/chat/index.tsx | 2 +- .../detail/components/SelectFileModal.tsx | 2 +- .../detail/components/SelectUrlModal.tsx | 2 +- src/service/events/generateQA.ts | 9 ++++--- 10 files changed, 63 insertions(+), 34 deletions(-) diff --git a/src/constants/model.ts b/src/constants/model.ts index 669be9850..d700485a4 100644 --- a/src/constants/model.ts +++ b/src/constants/model.ts @@ -35,7 +35,7 @@ export const modelList: ModelConstantsData[] = [ model: ChatModelNameEnum.GPT35, trainName: '', maxToken: 4000, - contextMaxToken: 7500, + contextMaxToken: 7000, maxTemperature: 1.5, price: 3 }, diff --git a/src/pages/api/chat/chatGpt.ts b/src/pages/api/chat/chatGpt.ts index 213f5e2f3..b3bdf5f46 100644 --- a/src/pages/api/chat/chatGpt.ts +++ b/src/pages/api/chat/chatGpt.ts @@ -61,7 +61,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } // 控制在 tokens 数量,防止超出 - // const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken); + const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken); // 格式化文本内容成 chatgpt 格式 const map = { @@ -69,14 +69,25 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) AI: ChatCompletionRequestMessageRoleEnum.Assistant, SYSTEM: ChatCompletionRequestMessageRoleEnum.System }; - const formatPrompts: ChatCompletionRequestMessage[] = prompts.map((item: ChatItemType) => ({ - role: map[item.obj], - content: item.value - })); - // console.log(formatPrompts); + const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map( + (item: ChatItemType) => ({ + role: map[item.obj], + content: item.value + }) + ); + // 计算温度 const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); - + // console.log({ + // model: model.service.chatModel, + // temperature: temperature, + // // max_tokens: modelConstantsData.maxToken, + // messages: formatPrompts, + // frequency_penalty: 0.5, // 越大,重复内容越少 + // presence_penalty: -0.5, // 越大,越容易出现新内容 + // stream: true, + // stop: ['.!?。'] + // }); // 获取 chatAPI const chatAPI = getOpenAIApi(userApiKey || systemKey); // 发出请求 diff --git a/src/pages/api/chat/vectorGpt.ts b/src/pages/api/chat/vectorGpt.ts index 0524765a9..0018fec97 100644 --- a/src/pages/api/chat/vectorGpt.ts +++ b/src/pages/api/chat/vectorGpt.ts @@ -1,7 +1,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase } from '@/service/mongo'; import { authChat } from '@/service/utils/chat'; -import { httpsAgent, systemPromptFilter } from '@/service/utils/tools'; +import { httpsAgent, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatItemType } from '@/types/chat'; import { jsonRes } from '@/service/response'; @@ -79,7 +79,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) `vector <=> '[${promptVector}]' < ${similarity}` ], order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }], - limit: 30 + limit: 20 }); const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`); @@ -116,7 +116,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } // 控制在 tokens 数量,防止超出 - // const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken); + const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken); // 格式化文本内容成 chatgpt 格式 const map = { @@ -124,10 +124,12 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) AI: ChatCompletionRequestMessageRoleEnum.Assistant, SYSTEM: ChatCompletionRequestMessageRoleEnum.System }; - const formatPrompts: ChatCompletionRequestMessage[] = prompts.map((item: ChatItemType) => ({ - role: map[item.obj], - content: item.value - })); + const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map( + (item: ChatItemType) => ({ + role: map[item.obj], + content: item.value + }) + ); // console.log(formatPrompts); // 计算温度 const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); diff --git a/src/pages/api/model/data/splitData.ts b/src/pages/api/model/data/splitData.ts index 392114820..423af8363 100644 --- a/src/pages/api/model/data/splitData.ts +++ b/src/pages/api/model/data/splitData.ts @@ -41,11 +41,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const tokens = encode(splitText + chunk).length; if (tokens >= 4000) { // 超过 4000,不要这块内容 - textList.push(splitText); + splitText && textList.push(splitText); splitText = chunk; } else if (tokens >= 3000) { // 超过 3000,取内容 - textList.push(splitText + chunk); + splitText && textList.push(splitText + chunk); splitText = ''; } else { //没超过 3000,继续添加 diff --git a/src/pages/api/openapi/chat/chatGpt.ts b/src/pages/api/openapi/chat/chatGpt.ts index eabd17e4d..e96d37eea 100644 --- a/src/pages/api/openapi/chat/chatGpt.ts +++ b/src/pages/api/openapi/chat/chatGpt.ts @@ -74,16 +74,21 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) }); } + // 控制在 tokens 数量,防止超出 + const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken); + // 格式化文本内容成 chatgpt 格式 const map = { Human: ChatCompletionRequestMessageRoleEnum.User, AI: ChatCompletionRequestMessageRoleEnum.Assistant, SYSTEM: ChatCompletionRequestMessageRoleEnum.System }; - const formatPrompts: ChatCompletionRequestMessage[] = prompts.map((item: ChatItemType) => ({ - role: map[item.obj], - content: item.value - })); + const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map( + (item: ChatItemType) => ({ + role: map[item.obj], + content: item.value + }) + ); // console.log(formatPrompts); // 计算温度 const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); diff --git a/src/pages/api/openapi/chat/vectorGpt.ts b/src/pages/api/openapi/chat/vectorGpt.ts index 20a5ae6fc..9dbc291c7 100644 --- a/src/pages/api/openapi/chat/vectorGpt.ts +++ b/src/pages/api/openapi/chat/vectorGpt.ts @@ -1,6 +1,11 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase, Model } from '@/service/mongo'; -import { httpsAgent, systemPromptFilter, authOpenApiKey } from '@/service/utils/tools'; +import { + httpsAgent, + systemPromptFilter, + authOpenApiKey, + openaiChatFilter +} from '@/service/utils/tools'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatItemType } from '@/types/chat'; import { jsonRes } from '@/service/response'; @@ -93,7 +98,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) `vector <=> '[${promptVector}]' < ${similarity}` ], order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }], - limit: 30 + limit: 20 }); const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`); @@ -134,16 +139,21 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) }); } + // 控制在 tokens 数量,防止超出 + const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken); + // 格式化文本内容成 chatgpt 格式 const map = { Human: ChatCompletionRequestMessageRoleEnum.User, AI: ChatCompletionRequestMessageRoleEnum.Assistant, SYSTEM: ChatCompletionRequestMessageRoleEnum.System }; - const formatPrompts: ChatCompletionRequestMessage[] = prompts.map((item: ChatItemType) => ({ - role: map[item.obj], - content: item.value - })); + const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map( + (item: ChatItemType) => ({ + role: map[item.obj], + content: item.value + }) + ); // console.log(formatPrompts); // 计算温度 const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index a731990b1..ec8f83e63 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -88,7 +88,7 @@ const Chat = ({ chatId }: { chatId: string }) => { throttle(() => { if (!ChatBox.current) return; const isBottom = - ChatBox.current.scrollTop + ChatBox.current.clientHeight + 80 >= + ChatBox.current.scrollTop + ChatBox.current.clientHeight + 150 >= ChatBox.current.scrollHeight; isBottom && scrollToBottom('auto'); diff --git a/src/pages/model/detail/components/SelectFileModal.tsx b/src/pages/model/detail/components/SelectFileModal.tsx index 88b9e0bce..336b8b07c 100644 --- a/src/pages/model/detail/components/SelectFileModal.tsx +++ b/src/pages/model/detail/components/SelectFileModal.tsx @@ -86,7 +86,7 @@ const SelectFileModal = ({ await postModelDataSplitData({ modelId, text: fileText.replace(/\\n/g, '\n').replace(/\n+/g, '\n'), - prompt: `下面是${prompt || '一段长文本'}` + prompt: `下面是"${prompt || '一段长文本'}"` }); toast({ title: '导入数据成功,需要一段拆解和训练', diff --git a/src/pages/model/detail/components/SelectUrlModal.tsx b/src/pages/model/detail/components/SelectUrlModal.tsx index 1d011c0f2..8377ec1b2 100644 --- a/src/pages/model/detail/components/SelectUrlModal.tsx +++ b/src/pages/model/detail/components/SelectUrlModal.tsx @@ -45,7 +45,7 @@ const SelectUrlModal = ({ await postModelDataSplitData({ modelId, text: webText, - prompt: `下面是${prompt || '一段长文本'}` + prompt: `下面是"${prompt || '一段长文本'}"` }); toast({ title: '导入数据成功,需要一段拆解和训练', diff --git a/src/service/events/generateQA.ts b/src/service/events/generateQA.ts index d6ae0c7e4..7a9532d81 100644 --- a/src/service/events/generateQA.ts +++ b/src/service/events/generateQA.ts @@ -69,9 +69,9 @@ export async function generateQA(next = false): Promise { const chatAPI = getOpenAIApi(userApiKey || systemKey); const systemPrompt: ChatCompletionRequestMessage = { role: 'system', - content: `${ - dataItem.prompt || '下面是一段长文本' - },请从中提取出5至30个问题和答案,并按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n` + content: `你是出题官.${ + dataItem.prompt || '下面是"一段长文本"' + },从中选出5至20个题目和答案,题目包含问答题,计算题,代码题等.答案要详细.按格式返回: Q1:\nA1:\nQ2:\nA2:\n` }; // 请求 chatgpt 获取回答 @@ -114,7 +114,8 @@ export async function generateQA(next = false): Promise { }; }) .catch((err) => { - console.log('QA 拆分错误', err); + console.log('QA拆分错误'); + console.log(err.response?.status, err.response?.statusText, err.response?.data); return Promise.reject(err); }) )