diff --git a/src/api/chat.ts b/src/api/chat.ts index 2d49c674f..8532db3ce 100644 --- a/src/api/chat.ts +++ b/src/api/chat.ts @@ -36,28 +36,10 @@ export const postGPT3SendPrompt = ({ }); /** - * 预发 prompt 进行存储 + * 存储一轮对话 */ -export const postChatGptPrompt = ({ - prompt, - windowId, - chatId -}: { - prompt: ChatSiteItemType; - windowId: string; - chatId: string; -}) => - POST(`/chat/preChat`, { - windowId, - prompt: { - obj: prompt.obj, - value: prompt.value - }, - chatId - }); -/* 获取 Chat 的 Event 对象,进行持续通信 */ -export const getChatGPTSendEvent = (chatId: string, windowId: string) => - new EventSource(`/api/chat/chatGpt?chatId=${chatId}&windowId=${windowId}&date=${Date.now()}`); +export const postSaveChat = (data: { windowId: string; prompts: ChatItemType[] }) => + POST('/chat/saveChat', data); /** * 删除最后一句 diff --git a/src/api/fetch.ts b/src/api/fetch.ts new file mode 100644 index 000000000..23d6b3538 --- /dev/null +++ b/src/api/fetch.ts @@ -0,0 +1,47 @@ +interface StreamFetchProps { + url: string; + data: any; + onMessage: (text: string) => void; +} +export const streamFetch = ({ url, data, onMessage }: StreamFetchProps) => + new Promise(async (resolve, reject) => { + try { + const res = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(data) + }); + const reader = res.body?.getReader(); + if (!reader) return; + const decoder = new TextDecoder(); + let responseText = ''; + + const read = async () => { + const { done, value } = await reader?.read(); + if (done) { + if (res.status === 200) { + resolve(responseText); + } else { + try { + const parseError = JSON.parse(responseText); + reject(parseError?.message || '请求异常'); + } catch (err) { + reject('请求异常'); + } + } + + return; + } + const text = decoder.decode(value).replace(//g, '\n'); + res.status === 200 && onMessage(text); + responseText += text; + read(); + }; + read(); + } catch (err: any) { + console.log(err, '===='); + reject(typeof err === 'string' ? err : err?.message || '请求异常'); + } + }); diff --git a/src/pages/api/chat/chatGpt.ts b/src/pages/api/chat/chatGpt.ts index abf876219..9fb3fdb8d 100644 --- a/src/pages/api/chat/chatGpt.ts +++ b/src/pages/api/chat/chatGpt.ts @@ -6,26 +6,19 @@ import { getOpenAIApi, authChat } from '@/service/utils/chat'; import { httpsAgent } from '@/service/utils/tools'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatItemType } from '@/types/chat'; -import { openaiError } from '@/service/errorCode'; +import { jsonRes } from '@/service/response'; +import { PassThrough } from 'stream'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { - res.setHeader('Content-Type', 'text/event-stream;charset-utf-8'); - res.setHeader('Access-Control-Allow-Origin', '*'); - res.setHeader('X-Accel-Buffering', 'no'); - res.setHeader('Cache-Control', 'no-cache, no-transform'); - - res.on('close', () => { - res.end(); - }); - req.on('error', () => { - res.end(); - }); - - const { chatId, windowId } = req.query as { chatId: string; windowId: string }; + const { chatId, windowId, prompt } = req.body as { + prompt: ChatItemType; + windowId: string; + chatId: string; + }; try { - if (!windowId || !chatId) { + if (!windowId || !chatId || !prompt) { throw new Error('缺少参数'); } @@ -35,15 +28,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const model: ModelType = chat.modelId; - const map = { - Human: ChatCompletionRequestMessageRoleEnum.User, - AI: ChatCompletionRequestMessageRoleEnum.Assistant, - SYSTEM: ChatCompletionRequestMessageRoleEnum.System - }; // 读取对话内容 const prompts: ChatItemType[] = (await ChatWindow.findById(windowId)).content; + prompts.push(prompt); - // 长度过滤 + // 上下文长度过滤 const maxContext = model.security.contextMaxLen; const filterPrompts = prompts.length > maxContext + 2 @@ -51,6 +40,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) : prompts.slice(0, prompts.length); // 格式化文本内容 + const map = { + Human: ChatCompletionRequestMessageRoleEnum.User, + AI: ChatCompletionRequestMessageRoleEnum.Assistant, + SYSTEM: ChatCompletionRequestMessageRoleEnum.System + }; const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map( (item: ChatItemType) => ({ role: map[item.obj], @@ -62,9 +56,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) role: ChatCompletionRequestMessageRoleEnum.System, content: '如果你想返回代码,请务必声明代码的类型!并且在代码块前加一个换行符。' }); + // 获取 chatAPI const chatAPI = getOpenAIApi(userApiKey); - + // 发出请求 const chatResponse = await chatAPI.createChatCompletion( { model: model.service.chatModel, @@ -84,58 +79,35 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) 'response success' ); - let AIResponse = ''; + // 创建响应流 + const pass = new PassThrough(); + pass.pipe(res); - // 解析数据 - const decoder = new TextDecoder(); const onParse = async (event: ParsedEvent | ReconnectInterval) => { - if (event.type === 'event') { - const data = event.data; - if (data === '[DONE]') { - // 存入库 - await ChatWindow.findByIdAndUpdate(windowId, { - $push: { - content: { - obj: 'AI', - value: AIResponse - } - }, - updateTime: Date.now() - }); - res.write('event: done\ndata: \n\n'); - return; - } - try { - const json = JSON.parse(data); - const content: string = json?.choices?.[0].delta.content || '\n'; - // console.log('content:', content) - res.write(`event: responseData\ndata: ${content.replace(/\n/g, '
')}\n\n`); - AIResponse += content; - } catch (error) { - error; - } + if (event.type !== 'event') return; + const data = event.data; + if (data === '[DONE]') return; + try { + const json = JSON.parse(data); + const content: string = json?.choices?.[0].delta.content || ''; + if (!content) return; + // console.log('content:', content) + pass.push(content.replace(/\n/g, '
')); + } catch (error) { + error; } }; for await (const chunk of chatResponse.data as any) { const parser = createParser(onParse); - parser.feed(decoder.decode(chunk)); + parser.feed(decodeURIComponent(chunk)); } + pass.push(null); } catch (err: any) { - console.log('error->', err?.response, '==='); - let errorText = 'OpenAI 服务器访问超时'; - if (err.code === 'ECONNRESET' || err?.response?.status === 502) { - errorText = '服务器代理出错'; - } else if (err?.response?.statusText && openaiError[err.response.statusText]) { - errorText = openaiError[err.response.statusText]; - } - console.log('error->', errorText); - res.write(`event: serviceError\ndata: ${errorText}\n\n`); - // 删除最一条数据库记录, 也就是预发送的那一条 - await ChatWindow.findByIdAndUpdate(windowId, { - $pop: { content: 1 }, - updateTime: Date.now() + res.status(500); + jsonRes(res, { + code: 500, + error: err }); - res.end(); } } diff --git a/src/pages/api/chat/preChat.ts b/src/pages/api/chat/saveChat.ts similarity index 52% rename from src/pages/api/chat/preChat.ts rename to src/pages/api/chat/saveChat.ts index 7576c1cd5..07acac593 100644 --- a/src/pages/api/chat/preChat.ts +++ b/src/pages/api/chat/saveChat.ts @@ -2,34 +2,31 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { ChatItemType } from '@/types/chat'; import { connectToDatabase, ChatWindow } from '@/service/mongo'; -import type { ModelType } from '@/types/model'; -import { authChat } from '@/service/utils/chat'; -/* 聊天预请求,存储聊天内容 */ +/* 聊天内容存存储 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { windowId, prompt, chatId } = req.body as { + const { windowId, prompts } = req.body as { windowId: string; - prompt: ChatItemType; - chatId: string; + prompts: ChatItemType[]; }; - if (!windowId || !prompt || !chatId) { + if (!windowId || !prompts) { throw new Error('缺少参数'); } await connectToDatabase(); - const { chat } = await authChat(chatId); - - // 长度校验 - const model: ModelType = chat.modelId; - if (prompt.value.length > model.security.contentMaxLen) { - throw new Error('输入内容超长'); - } - + // 存入库 await ChatWindow.findByIdAndUpdate(windowId, { - $push: { content: prompt }, + $push: { + content: { + $each: prompts.map((item) => ({ + obj: item.obj, + value: item.value + })) + } + }, updateTime: Date.now() }); diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index faf12bf7f..fbfdfee11 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -1,13 +1,7 @@ import React, { useCallback, useState, useRef, useMemo } from 'react'; import { useRouter } from 'next/router'; import Image from 'next/image'; -import { - getInitChatSiteInfo, - postGPT3SendPrompt, - getChatGPTSendEvent, - postChatGptPrompt, - delLastMessage -} from '@/api/chat'; +import { getInitChatSiteInfo, postGPT3SendPrompt, delLastMessage, postSaveChat } from '@/api/chat'; import { ChatSiteItemType, ChatSiteType } from '@/types/chat'; import { Textarea, Box, Flex, Button } from '@chakra-ui/react'; import { useToast } from '@/hooks/useToast'; @@ -17,6 +11,7 @@ import { useQuery } from '@tanstack/react-query'; import { OpenAiModelEnum } from '@/constants/model'; import dynamic from 'next/dynamic'; import { useGlobalStore } from '@/store/global'; +import { streamFetch } from '@/api/fetch'; const Markdown = dynamic(() => import('@/components/Markdown')); @@ -128,69 +123,64 @@ const Chat = ({ chatId, windowId }: { chatId: string; windowId?: string }) => { const chatGPTPrompt = useCallback( async (newChatList: ChatSiteItemType[]) => { if (!windowId) return; - /* 预请求,把消息存入库 */ - await postChatGptPrompt({ - windowId, - prompt: newChatList[newChatList.length - 1], - chatId - }); - - return new Promise((resolve, reject) => { - const event = getChatGPTSendEvent(chatId, windowId); - // 30s 收不到消息就报错 - let timer = setTimeout(() => { - event.close(); - reject('服务器超时'); - }, 30000); - event.addEventListener('responseData', ({ data }) => { - /* 重置定时器 */ - clearTimeout(timer); - timer = setTimeout(() => { - event.close(); - reject('服务器超时'); - }, 30000); - - const msg = data.replace(//g, '\n'); + const prompt = { + obj: newChatList[newChatList.length - 1].obj, + value: newChatList[newChatList.length - 1].value + }; + // 流请求,获取数据 + const res = await streamFetch({ + url: '/api/chat/chatGpt', + data: { + windowId, + prompt, + chatId + }, + onMessage: (text: string) => { setChatList((state) => state.map((item, index) => { if (index !== state.length - 1) return item; return { ...item, - value: item.value + msg + value: item.value + text }; }) ); - }); - event.addEventListener('done', () => { - console.log('done'); - clearTimeout(timer); - event.close(); - setChatList((state) => - state.map((item, index) => { - if (index !== state.length - 1) return item; - return { - ...item, - status: 'finish' - }; - }) - ); - resolve(''); - }); - event.addEventListener('serviceError', ({ data: err }) => { - clearTimeout(timer); - event.close(); - console.log('error->', err, '==='); - reject(typeof err === 'string' ? err : '对话出现不知名错误~'); - }); - event.onerror = (err) => { - clearTimeout(timer); - event.close(); - console.log('error->', err); - reject(typeof err === 'string' ? err : '对话出现不知名错误~'); - }; + } }); + + // 保存对话信息 + try { + await postSaveChat({ + windowId, + prompts: [ + prompt, + { + obj: 'AI', + value: res as string + } + ] + }); + } catch (err) { + toast({ + title: '存储对话出现异常, 继续对话会导致上下文丢失,请刷新页面', + status: 'warning', + duration: 3000, + isClosable: true + }); + } + + // 设置完成状态 + setChatList((state) => + state.map((item, index) => { + if (index !== state.length - 1) return item; + return { + ...item, + status: 'finish' + }; + }) + ); }, - [chatId, windowId] + [chatId, toast, windowId] ); /** diff --git a/src/service/errorCode.ts b/src/service/errorCode.ts index 9f1db2926..607a1116a 100644 --- a/src/service/errorCode.ts +++ b/src/service/errorCode.ts @@ -4,3 +4,7 @@ export const openaiError: Record = { rate_limit_reached: '同时访问用户过多,请稍后再试', 'Bad Request': '上下文太多了,请重开对话~' }; +export const proxyError: Record = { + ECONNABORTED: true, + ECONNRESET: true +}; diff --git a/src/service/response.ts b/src/service/response.ts index 58201b365..fa04a0bfc 100644 --- a/src/service/response.ts +++ b/src/service/response.ts @@ -1,5 +1,5 @@ import { NextApiResponse } from 'next'; -import { openaiError } from './errorCode'; +import { openaiError, proxyError } from './errorCode'; export interface ResponseType { code: number; @@ -23,12 +23,13 @@ export const jsonRes = ( msg = error?.message || '请求错误'; if (typeof error === 'string') { msg = error; - } else if (error?.response?.data?.message in openaiError) { - msg = openaiError[error?.response?.data?.message]; + } else if (proxyError[error?.code]) { + msg = '服务器代理出错'; + } else if (openaiError[error?.response?.statusText]) { + msg = openaiError[error.response.statusText]; } - console.log('error->', error); - console.log('error->', msg); + console.log('error->', error.code, error?.response?.statusText, msg); } res.json({