From af35e17fdb71b17acd7c20aca8408c2c571855e4 Mon Sep 17 00:00:00 2001
From: archer <545436317@qq.com>
Date: Wed, 22 Mar 2023 22:09:40 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E4=B8=AD=E6=96=AD?=
=?UTF-8?q?=E6=B5=81.fix:=20=E4=B8=AD=E6=96=AD=E6=B5=81=E5=AF=BC=E8=87=B4?=
=?UTF-8?q?=E7=9A=84=E6=9C=8D=E5=8A=A1=E7=AB=AF=E9=94=99=E8=AF=AF?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/api/fetch.ts | 6 ++--
src/pages/api/chat/chatGpt.ts | 52 +++++++++++++++++++++++++----------
src/pages/chat/index.tsx | 15 ++++++++--
src/service/events/bill.ts | 48 +++++++++++++++++---------------
4 files changed, 81 insertions(+), 40 deletions(-)
diff --git a/src/api/fetch.ts b/src/api/fetch.ts
index 89ebba06a..92949af6a 100644
--- a/src/api/fetch.ts
+++ b/src/api/fetch.ts
@@ -3,8 +3,9 @@ interface StreamFetchProps {
url: string;
data: any;
onMessage: (text: string) => void;
+ abortSignal: AbortController;
}
-export const streamFetch = ({ url, data, onMessage }: StreamFetchProps) =>
+export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchProps) =>
new Promise(async (resolve, reject) => {
try {
const res = await fetch(url, {
@@ -13,7 +14,8 @@ export const streamFetch = ({ url, data, onMessage }: StreamFetchProps) =>
'Content-Type': 'application/json',
Authorization: getToken() || ''
},
- body: JSON.stringify(data)
+ body: JSON.stringify(data),
+ signal: abortSignal.signal
});
const reader = res.body?.getReader();
if (!reader) return;
diff --git a/src/pages/api/chat/chatGpt.ts b/src/pages/api/chat/chatGpt.ts
index f041d5e36..6e5a32043 100644
--- a/src/pages/api/chat/chatGpt.ts
+++ b/src/pages/api/chat/chatGpt.ts
@@ -13,13 +13,27 @@ import { pushBill } from '@/service/events/bill';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
- const { chatId, prompt } = req.body as {
- prompt: ChatItemType;
- chatId: string;
- };
- const { authorization } = req.headers;
+ let step = 0; // step=1时,表示开始了流响应
+ const stream = new PassThrough();
+ stream.on('error', () => {
+ console.log('error: ', 'stream error');
+ stream.destroy();
+ });
+ res.on('close', () => {
+ console.log('stream request close');
+ stream.destroy();
+ });
+ res.on('error', () => {
+ console.log('error: ', 'request error');
+ stream.destroy();
+ });
try {
+ const { chatId, prompt } = req.body as {
+ prompt: ChatItemType;
+ chatId: string;
+ };
+ const { authorization } = req.headers;
if (!chatId || !prompt) {
throw new Error('缺少参数');
}
@@ -92,10 +106,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
res.setHeader('Access-Control-Allow-Origin', '*');
res.setHeader('X-Accel-Buffering', 'no');
res.setHeader('Cache-Control', 'no-cache, no-transform');
+ step = 1;
let responseContent = '';
- const pass = new PassThrough();
- pass.pipe(res);
+ stream.pipe(res);
const onParse = async (event: ParsedEvent | ReconnectInterval) => {
if (event.type !== 'event') return;
@@ -107,7 +121,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
if (!content) return;
responseContent += content;
// console.log('content:', content)
- pass.push(content.replace(/\n/g, '
'));
+ stream.push(content.replace(/\n/g, '
'));
} catch (error) {
error;
}
@@ -116,13 +130,17 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const decoder = new TextDecoder();
try {
for await (const chunk of chatResponse.data as any) {
+ if (stream.destroyed) {
+ // 流被中断了,直接忽略后面的内容
+ break;
+ }
const parser = createParser(onParse);
parser.feed(decoder.decode(chunk));
}
} catch (error) {
console.log('pipe error', error);
}
- pass.push(null);
+ stream.push(null);
const promptsLen = formatPrompts.reduce((sum, item) => sum + item.content.length, 0);
console.log(`responseLen: ${responseContent.length}`, `promptLen: ${promptsLen}`);
@@ -135,10 +153,16 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
textLen: promptsLen + responseContent.length
});
} catch (err: any) {
- res.status(500);
- jsonRes(res, {
- code: 500,
- error: err
- });
+ if (step === 1) {
+ console.log('error,结束');
+ // 直接结束流
+ stream.destroy();
+ } else {
+ res.status(500);
+ jsonRes(res, {
+ code: 500,
+ error: err
+ });
+ }
}
}
diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx
index ab1b53733..16b56412c 100644
--- a/src/pages/chat/index.tsx
+++ b/src/pages/chat/index.tsx
@@ -1,4 +1,4 @@
-import React, { useCallback, useState, useRef, useMemo } from 'react';
+import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react';
import { useRouter } from 'next/router';
import Image from 'next/image';
import {
@@ -88,6 +88,16 @@ const Chat = ({ chatId }: { chatId: string }) => {
}, [chatData]);
const { pushChatHistory } = useChatStore();
+ // 中断请求
+ const controller = useRef(new AbortController());
+ useEffect(() => {
+ controller.current = new AbortController();
+ return () => {
+ console.log('close========');
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ controller.current?.abort();
+ };
+ }, [chatId]);
// 滚动到底部
const scrollToBottom = useCallback(() => {
@@ -212,7 +222,8 @@ const Chat = ({ chatId }: { chatId: string }) => {
};
})
}));
- }
+ },
+ abortSignal: controller.current
});
// 保存对话信息
diff --git a/src/service/events/bill.ts b/src/service/events/bill.ts
index db2ec25ee..d75e9a955 100644
--- a/src/service/events/bill.ts
+++ b/src/service/events/bill.ts
@@ -12,30 +12,34 @@ export const pushBill = async ({
chatId: string;
textLen: number;
}) => {
- await connectToDatabase();
-
- const modelItem = ModelList.find((item) => item.model === modelName);
-
- if (!modelItem) return;
-
- const price = modelItem.price * textLen;
-
- let billId;
try {
- // 插入 Bill 记录
- const res = await Bill.create({
- userId,
- chatId,
- textLen,
- price
- });
- billId = res._id;
+ await connectToDatabase();
- // 扣费
- await User.findByIdAndUpdate(userId, {
- $inc: { balance: -price }
- });
+ const modelItem = ModelList.find((item) => item.model === modelName);
+
+ if (!modelItem) return;
+
+ const price = modelItem.price * textLen;
+
+ let billId;
+ try {
+ // 插入 Bill 记录
+ const res = await Bill.create({
+ userId,
+ chatId,
+ textLen,
+ price
+ });
+ billId = res._id;
+
+ // 扣费
+ await User.findByIdAndUpdate(userId, {
+ $inc: { balance: -price }
+ });
+ } catch (error) {
+ billId && Bill.findByIdAndDelete(billId);
+ }
} catch (error) {
- billId && Bill.findByIdAndDelete(billId);
+ console.log(error);
}
};