From 274ece1d91e741add21cf2a45f644204ce6bc625 Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Sat, 25 Mar 2023 20:43:03 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20gpt3=E6=B5=81=E5=93=8D=E5=BA=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/constants/common.ts | 2 + src/pages/api/chat/chatGpt.ts | 6 +- src/pages/api/chat/gpt3.ts | 151 ++++++++++++++++----- src/pages/api/user/getPayCode.ts | 2 +- src/pages/chat/index.tsx | 47 +++---- src/pages/model/components/CreateModel.tsx | 11 +- src/pages/number/components/PayModal.tsx | 2 +- src/service/events/pushChatBill.ts | 4 +- src/service/models/model.ts | 4 +- src/service/models/user.ts | 2 +- src/service/response.ts | 2 +- src/utils/user.ts | 6 +- 12 files changed, 163 insertions(+), 76 deletions(-) diff --git a/src/constants/common.ts b/src/constants/common.ts index 112bda376..2a332ed24 100644 --- a/src/constants/common.ts +++ b/src/constants/common.ts @@ -3,6 +3,8 @@ export enum EmailTypeEnum { findPassword = 'findPassword' } +export const PRICE_SCALE = 100000; + export const introPage = ` ## 欢迎使用 Fast GPT diff --git a/src/pages/api/chat/chatGpt.ts b/src/pages/api/chat/chatGpt.ts index 959751f09..455395eb1 100644 --- a/src/pages/api/chat/chatGpt.ts +++ b/src/pages/api/chat/chatGpt.ts @@ -89,6 +89,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) temperature: temperature, // max_tokens: modelConstantsData.maxToken, messages: formatPrompts, + frequency_penalty: 0.5, // 越大,重复内容越少 + presence_penalty: -0.5, // 越大,越容易出现新内容 stream: true }, { @@ -117,7 +119,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) try { const json = JSON.parse(data); const content: string = json?.choices?.[0].delta.content || ''; - if (!content) return; + if (!content || (responseContent === '' && content === '\n')) return; + responseContent += content; // console.log('content:', content) !stream.destroyed && stream.push(content.replace(/\n/g, '
')); @@ -144,7 +147,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) stream.destroy(); const promptsContent = formatPrompts.map((item) => item.content).join(''); - console.log(`responseLen: ${responseContent.length}`, `promptLen: ${promptsContent.length}`); // 只有使用平台的 key 才计费 !userApiKey && pushBill({ diff --git a/src/pages/api/chat/gpt3.ts b/src/pages/api/chat/gpt3.ts index d90c3f72e..2744d64f3 100644 --- a/src/pages/api/chat/gpt3.ts +++ b/src/pages/api/chat/gpt3.ts @@ -1,20 +1,38 @@ -// Next.js API route support: https://nextjs.org/docs/api-routes/introduction import type { NextApiRequest, NextApiResponse } from 'next'; -import { jsonRes } from '@/service/response'; +import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser'; import { connectToDatabase } from '@/service/mongo'; import { getOpenAIApi, authChat } from '@/service/utils/chat'; -import { ChatItemType } from '@/types/chat'; import { httpsAgent } from '@/service/utils/tools'; +import { ChatItemType } from '@/types/chat'; +import { jsonRes } from '@/service/response'; +import type { ModelSchema } from '@/types/mongoSchema'; +import { PassThrough } from 'stream'; import { modelList } from '@/constants/model'; import { pushBill } from '@/service/events/pushChatBill'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { - try { - const { prompt, chatId } = req.body as { prompt: ChatItemType[]; chatId: string }; - const { authorization } = req.headers; + let step = 0; // step=1时,表示开始了流响应 + const stream = new PassThrough(); + stream.on('error', () => { + console.log('error: ', 'stream error'); + stream.destroy(); + }); + res.on('close', () => { + stream.destroy(); + }); + res.on('error', () => { + console.log('error: ', 'request error'); + stream.destroy(); + }); - if (!prompt || !chatId) { + try { + const { chatId, prompt } = req.body as { + prompt: ChatItemType; + chatId: string; + }; + const { authorization } = req.headers; + if (!chatId || !prompt) { throw new Error('缺少参数'); } @@ -22,13 +40,29 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization); - const model = chat.modelId; + const model: ModelSchema = chat.modelId; - // 获取 chatAPI - const chatAPI = getOpenAIApi(userApiKey || systemKey); + // 读取对话内容 + const prompts = [...chat.content, prompt]; - // prompt处理 - const formatPrompts = prompt.map((item) => `${item.value}\n\n###\n\n`).join(''); + // 上下文长度过滤 + const maxContext = model.security.contextMaxLen; + const filterPrompts = + prompts.length > maxContext ? prompts.slice(prompts.length - maxContext) : prompts; + + // 格式化文本内容 + const map = { + Human: 'Human', + AI: 'AI', + SYSTEM: 'SYSTEM' + }; + const formatPrompts: string[] = filterPrompts.map((item: ChatItemType) => item.value); + // 如果有系统提示词,自动插入 + if (model.systemPrompt) { + formatPrompts.unshift(`${model.systemPrompt}`); + } + + const promptText = formatPrompts.join(''); // 计算温度 const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); @@ -37,42 +71,95 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); - // 发送请求 - const response = await chatAPI.createCompletion( + // 获取 chatAPI + const chatAPI = getOpenAIApi(userApiKey || systemKey); + let startTime = Date.now(); + + // 发出请求 + const chatResponse = await chatAPI.createCompletion( { - model: model.service.modelName, - prompt: formatPrompts, + model: model.service.chatModel, temperature: temperature, - // max_tokens: modelConstantsData.maxToken, - top_p: 1, - frequency_penalty: 0, - presence_penalty: 0.6, - stop: ['###'] + prompt: promptText, + stream: true, + max_tokens: modelConstantsData.maxToken, + presence_penalty: 0, // 越大,越容易出现新内容 + frequency_penalty: 0, // 越大,重复内容越少 + stop: ['。!?.!.', ``] }, { + timeout: 40000, + responseType: 'stream', httpsAgent } ); - const responseContent = response.data.choices[0]?.text || ''; + console.log('api response time:', `${(Date.now() - startTime) / 1000}s`); + + // 创建响应流 + res.setHeader('Content-Type', 'text/event-stream;charset-utf-8'); + res.setHeader('Access-Control-Allow-Origin', '*'); + res.setHeader('X-Accel-Buffering', 'no'); + res.setHeader('Cache-Control', 'no-cache, no-transform'); + step = 1; + + let responseContent = ''; + stream.pipe(res); + + const onParse = async (event: ParsedEvent | ReconnectInterval) => { + if (event.type !== 'event') return; + const data = event.data; + if (data === '[DONE]') return; + try { + const json = JSON.parse(data); + const content: string = json?.choices?.[0].text || ''; + if (!content || (responseContent === '' && content === '\n')) return; + + responseContent += content; + // console.log('content:', content); + !stream.destroyed && stream.push(content.replace(/\n/g, '
')); + } catch (error) { + error; + } + }; + + const decoder = new TextDecoder(); + try { + for await (const chunk of chatResponse.data as any) { + if (stream.destroyed) { + // 流被中断了,直接忽略后面的内容 + break; + } + const parser = createParser(onParse); + parser.feed(decoder.decode(chunk)); + } + } catch (error) { + console.log('pipe error', error); + } + // close stream + !stream.destroyed && stream.push(null); + stream.destroy(); - console.log(`responseLen: ${responseContent.length}`, `promptLen: ${formatPrompts.length}`); // 只有使用平台的 key 才计费 !userApiKey && pushBill({ modelName: model.service.modelName, userId, chatId, - text: formatPrompts + responseContent + text: promptText + responseContent }); - - jsonRes(res, { - data: responseContent - }); } catch (err: any) { - jsonRes(res, { - code: 500, - error: err - }); + // console.log(err?.response); + if (step === 1) { + // 直接结束流 + console.log('error,结束'); + stream.destroy(); + } else { + res.status(500); + jsonRes(res, { + code: 500, + error: err + }); + } } } diff --git a/src/pages/api/user/getPayCode.ts b/src/pages/api/user/getPayCode.ts index 5ade8a67c..57f65a94a 100644 --- a/src/pages/api/user/getPayCode.ts +++ b/src/pages/api/user/getPayCode.ts @@ -5,7 +5,7 @@ import axios from 'axios'; import { authToken } from '@/service/utils/tools'; import { customAlphabet } from 'nanoid'; import { connectToDatabase, Pay } from '@/service/mongo'; -import { PRICE_SCALE } from '@/utils/user'; +import { PRICE_SCALE } from '@/constants/common'; const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 20); diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index 16b56412c..d4aafc0ae 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -197,16 +197,22 @@ const Chat = ({ chatId }: { chatId: string }) => { [chatId] ); - // chatGPT - const chatGPTPrompt = useCallback( - async (newChatList: ChatSiteItemType[]) => { + // gpt 对话 + const gptChatPrompt = useCallback( + async (prompts: ChatSiteItemType) => { + const urlMap: Record = { + [ChatModelNameEnum.GPT35]: '/api/chat/chatGpt', + [ChatModelNameEnum.GPT3]: '/api/chat/gpt3' + }; + if (!urlMap[chatData.chatModel]) return Promise.reject('找不到模型'); + const prompt = { - obj: newChatList[newChatList.length - 1].obj, - value: newChatList[newChatList.length - 1].value + obj: prompts.obj, + value: prompts.value }; // 流请求,获取数据 const res = await streamFetch({ - url: '/api/chat/chatGpt', + url: urlMap[chatData.chatModel], data: { prompt, chatId @@ -240,7 +246,7 @@ const Chat = ({ chatId }: { chatId: string }) => { }); } catch (err) { toast({ - title: '存储对话出现异常, 继续对话会导致上下文丢失,请刷新页面', + title: '对话出现异常, 继续对话会导致上下文丢失,请刷新页面', status: 'warning', duration: 3000, isClosable: true @@ -259,7 +265,7 @@ const Chat = ({ chatId }: { chatId: string }) => { }) })); }, - [chatId, toast] + [chatData.chatModel, chatId, toast] ); /** @@ -272,7 +278,7 @@ const Chat = ({ chatId }: { chatId: string }) => { .trim() .split('\n') .filter((val) => val) - .join('\n\n'); + .join('\n'); if (!chatData?.modelId || !val || !ChatBox.current || isChatting) { return; } @@ -301,22 +307,8 @@ const Chat = ({ chatId }: { chatId: string }) => { resetInputVal(''); scrollToBottom(); - const fnMap: { [key: string]: any } = { - [ChatModelNameEnum.GPT35]: chatGPTPrompt, - [ChatModelNameEnum.GPT3]: gpt3ChatPrompt - }; - try { - /* 对长度进行限制 */ - const maxContext = chatData.secret.contextMaxLen; - const requestPrompt = - newChatList.length > maxContext + 1 - ? newChatList.slice(newChatList.length - maxContext - 1, -1) - : newChatList.slice(0, -1); - - if (typeof fnMap[chatData.chatModel] === 'function') { - await fnMap[chatData.chatModel](requestPrompt); - } + await gptChatPrompt(newChatList[newChatList.length - 2]); // 如果是 Human 第一次发送,插入历史记录 const humanChat = newChatList.filter((item) => item.obj === 'Human'); @@ -343,15 +335,12 @@ const Chat = ({ chatId }: { chatId: string }) => { } }, [ inputVal, - chatData.modelId, + chatData?.modelId, chatData.history, - chatData.secret.contextMaxLen, - chatData.chatModel, isChatting, resetInputVal, scrollToBottom, - chatGPTPrompt, - gpt3ChatPrompt, + gptChatPrompt, pushChatHistory, chatId, toast diff --git a/src/pages/model/components/CreateModel.tsx b/src/pages/model/components/CreateModel.tsx index e8b5c5a73..96c0e0010 100644 --- a/src/pages/model/components/CreateModel.tsx +++ b/src/pages/model/components/CreateModel.tsx @@ -34,6 +34,7 @@ const CreateModel = ({ onSuccess: Dispatch; }) => { const [requesting, setRequesting] = useState(false); + const [refresh, setRefresh] = useState(false); const toast = useToast({ duration: 2000, position: 'top' @@ -95,7 +96,10 @@ const CreateModel = ({