From af35e17fdb71b17acd7c20aca8408c2c571855e4 Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Wed, 22 Mar 2023 22:09:40 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E4=B8=AD=E6=96=AD?= =?UTF-8?q?=E6=B5=81.fix:=20=E4=B8=AD=E6=96=AD=E6=B5=81=E5=AF=BC=E8=87=B4?= =?UTF-8?q?=E7=9A=84=E6=9C=8D=E5=8A=A1=E7=AB=AF=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/fetch.ts | 6 ++-- src/pages/api/chat/chatGpt.ts | 52 +++++++++++++++++++++++++---------- src/pages/chat/index.tsx | 15 ++++++++-- src/service/events/bill.ts | 48 +++++++++++++++++--------------- 4 files changed, 81 insertions(+), 40 deletions(-) diff --git a/src/api/fetch.ts b/src/api/fetch.ts index 89ebba06a..92949af6a 100644 --- a/src/api/fetch.ts +++ b/src/api/fetch.ts @@ -3,8 +3,9 @@ interface StreamFetchProps { url: string; data: any; onMessage: (text: string) => void; + abortSignal: AbortController; } -export const streamFetch = ({ url, data, onMessage }: StreamFetchProps) => +export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchProps) => new Promise(async (resolve, reject) => { try { const res = await fetch(url, { @@ -13,7 +14,8 @@ export const streamFetch = ({ url, data, onMessage }: StreamFetchProps) => 'Content-Type': 'application/json', Authorization: getToken() || '' }, - body: JSON.stringify(data) + body: JSON.stringify(data), + signal: abortSignal.signal }); const reader = res.body?.getReader(); if (!reader) return; diff --git a/src/pages/api/chat/chatGpt.ts b/src/pages/api/chat/chatGpt.ts index f041d5e36..6e5a32043 100644 --- a/src/pages/api/chat/chatGpt.ts +++ b/src/pages/api/chat/chatGpt.ts @@ -13,13 +13,27 @@ import { pushBill } from '@/service/events/bill'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { - const { chatId, prompt } = req.body as { - prompt: ChatItemType; - chatId: string; - }; - const { authorization } = req.headers; + let step = 0; // step=1时,表示开始了流响应 + const stream = new PassThrough(); + stream.on('error', () => { + console.log('error: ', 'stream error'); + stream.destroy(); + }); + res.on('close', () => { + console.log('stream request close'); + stream.destroy(); + }); + res.on('error', () => { + console.log('error: ', 'request error'); + stream.destroy(); + }); try { + const { chatId, prompt } = req.body as { + prompt: ChatItemType; + chatId: string; + }; + const { authorization } = req.headers; if (!chatId || !prompt) { throw new Error('缺少参数'); } @@ -92,10 +106,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) res.setHeader('Access-Control-Allow-Origin', '*'); res.setHeader('X-Accel-Buffering', 'no'); res.setHeader('Cache-Control', 'no-cache, no-transform'); + step = 1; let responseContent = ''; - const pass = new PassThrough(); - pass.pipe(res); + stream.pipe(res); const onParse = async (event: ParsedEvent | ReconnectInterval) => { if (event.type !== 'event') return; @@ -107,7 +121,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) if (!content) return; responseContent += content; // console.log('content:', content) - pass.push(content.replace(/\n/g, '
')); + stream.push(content.replace(/\n/g, '
')); } catch (error) { error; } @@ -116,13 +130,17 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const decoder = new TextDecoder(); try { for await (const chunk of chatResponse.data as any) { + if (stream.destroyed) { + // 流被中断了,直接忽略后面的内容 + break; + } const parser = createParser(onParse); parser.feed(decoder.decode(chunk)); } } catch (error) { console.log('pipe error', error); } - pass.push(null); + stream.push(null); const promptsLen = formatPrompts.reduce((sum, item) => sum + item.content.length, 0); console.log(`responseLen: ${responseContent.length}`, `promptLen: ${promptsLen}`); @@ -135,10 +153,16 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) textLen: promptsLen + responseContent.length }); } catch (err: any) { - res.status(500); - jsonRes(res, { - code: 500, - error: err - }); + if (step === 1) { + console.log('error,结束'); + // 直接结束流 + stream.destroy(); + } else { + res.status(500); + jsonRes(res, { + code: 500, + error: err + }); + } } } diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index ab1b53733..16b56412c 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -1,4 +1,4 @@ -import React, { useCallback, useState, useRef, useMemo } from 'react'; +import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react'; import { useRouter } from 'next/router'; import Image from 'next/image'; import { @@ -88,6 +88,16 @@ const Chat = ({ chatId }: { chatId: string }) => { }, [chatData]); const { pushChatHistory } = useChatStore(); + // 中断请求 + const controller = useRef(new AbortController()); + useEffect(() => { + controller.current = new AbortController(); + return () => { + console.log('close========'); + // eslint-disable-next-line react-hooks/exhaustive-deps + controller.current?.abort(); + }; + }, [chatId]); // 滚动到底部 const scrollToBottom = useCallback(() => { @@ -212,7 +222,8 @@ const Chat = ({ chatId }: { chatId: string }) => { }; }) })); - } + }, + abortSignal: controller.current }); // 保存对话信息 diff --git a/src/service/events/bill.ts b/src/service/events/bill.ts index db2ec25ee..d75e9a955 100644 --- a/src/service/events/bill.ts +++ b/src/service/events/bill.ts @@ -12,30 +12,34 @@ export const pushBill = async ({ chatId: string; textLen: number; }) => { - await connectToDatabase(); - - const modelItem = ModelList.find((item) => item.model === modelName); - - if (!modelItem) return; - - const price = modelItem.price * textLen; - - let billId; try { - // 插入 Bill 记录 - const res = await Bill.create({ - userId, - chatId, - textLen, - price - }); - billId = res._id; + await connectToDatabase(); - // 扣费 - await User.findByIdAndUpdate(userId, { - $inc: { balance: -price } - }); + const modelItem = ModelList.find((item) => item.model === modelName); + + if (!modelItem) return; + + const price = modelItem.price * textLen; + + let billId; + try { + // 插入 Bill 记录 + const res = await Bill.create({ + userId, + chatId, + textLen, + price + }); + billId = res._id; + + // 扣费 + await User.findByIdAndUpdate(userId, { + $inc: { balance: -price } + }); + } catch (error) { + billId && Bill.findByIdAndDelete(billId); + } } catch (error) { - billId && Bill.findByIdAndDelete(billId); + console.log(error); } };