diff --git a/src/api/chat.ts b/src/api/chat.ts index bfdf7c6ee..b788ac57f 100644 --- a/src/api/chat.ts +++ b/src/api/chat.ts @@ -2,16 +2,11 @@ import { GET, POST, DELETE } from './request'; import type { ChatItemType, ChatSiteItemType } from '@/types/chat'; import type { InitChatResponse } from './response/chat'; -/** - * 获取一个聊天框的ID - */ -export const getChatSiteId = (modelId: string) => GET(`/chat/generate?modelId=${modelId}`); - /** * 获取初始化聊天内容 */ -export const getInitChatSiteInfo = (chatId: string) => - GET(`/chat/init?chatId=${chatId}`); +export const getInitChatSiteInfo = (modelId: string, chatId: '' | string) => + GET(`/chat/init?modelId=${modelId}&chatId=${chatId}`); /** * 发送 GPT3 prompt @@ -34,8 +29,11 @@ export const postGPT3SendPrompt = ({ /** * 存储一轮对话 */ -export const postSaveChat = (data: { chatId: string; prompts: ChatItemType[] }) => - POST('/chat/saveChat', data); +export const postSaveChat = (data: { + modelId: string; + chatId: '' | string; + prompts: ChatItemType[]; +}) => POST('/chat/saveChat', data); /** * 删除一句对话 diff --git a/src/constants/model.ts b/src/constants/model.ts index d700485a4..90610fde8 100644 --- a/src/constants/model.ts +++ b/src/constants/model.ts @@ -1,4 +1,4 @@ -import type { ServiceName, ModelDataType, ModelSchema } from '@/types/mongoSchema'; +import type { ModelDataType, ModelSchema } from '@/types/mongoSchema'; export enum ModelDataStatusEnum { ready = 'ready', @@ -18,7 +18,6 @@ export const ChatModelNameMap = { }; export type ModelConstantsData = { - serviceCompany: `${ServiceName}`; name: string; model: `${ChatModelNameEnum}`; trainName: string; // 空字符串代表不能训练 @@ -30,7 +29,6 @@ export type ModelConstantsData = { export const modelList: ModelConstantsData[] = [ { - serviceCompany: 'openai', name: 'chatGPT', model: ChatModelNameEnum.GPT35, trainName: '', @@ -40,7 +38,6 @@ export const modelList: ModelConstantsData[] = [ price: 3 }, { - serviceCompany: 'openai', name: '知识库', model: ChatModelNameEnum.VECTOR_GPT, trainName: 'vector', @@ -132,7 +129,6 @@ export const defaultModel: ModelSchema = { mode: ModelVectorSearchModeEnum.hightSimilarity }, service: { - company: 'openai', trainId: '', chatModel: ChatModelNameEnum.GPT35, modelName: ChatModelNameEnum.GPT35 diff --git a/src/pages/api/chat/chatGpt.ts b/src/pages/api/chat/chatGpt.ts index b3bdf5f46..5c97f0019 100644 --- a/src/pages/api/chat/chatGpt.ts +++ b/src/pages/api/chat/chatGpt.ts @@ -1,11 +1,10 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase } from '@/service/mongo'; -import { getOpenAIApi, authChat } from '@/service/utils/chat'; +import { getOpenAIApi, authChat } from '@/service/utils/auth'; import { httpsAgent, openaiChatFilter } from '@/service/utils/tools'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatItemType } from '@/types/chat'; import { jsonRes } from '@/service/response'; -import type { ModelSchema } from '@/types/mongoSchema'; import { PassThrough } from 'stream'; import { modelList } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; @@ -28,29 +27,33 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) }); try { - const { chatId, prompt } = req.body as { + const { chatId, prompt, modelId } = req.body as { prompt: ChatItemType; - chatId: string; + modelId: string; + chatId: '' | string; }; const { authorization } = req.headers; - if (!chatId || !prompt) { + if (!modelId || !prompt) { throw new Error('缺少参数'); } await connectToDatabase(); let startTime = Date.now(); - const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization); + const { model, content, userApiKey, systemKey, userId } = await authChat({ + modelId, + chatId, + authorization + }); - const model: ModelSchema = chat.modelId; const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); if (!modelConstantsData) { throw new Error('模型加载异常'); } // 读取对话内容 - const prompts = [...chat.content, prompt]; + const prompts = [...content, prompt]; // 如果有系统提示词,自动插入 if (model.systemPrompt) { diff --git a/src/pages/api/chat/generate.ts b/src/pages/api/chat/generate.ts deleted file mode 100644 index c425474a6..000000000 --- a/src/pages/api/chat/generate.ts +++ /dev/null @@ -1,54 +0,0 @@ -import type { NextApiRequest, NextApiResponse } from 'next'; -import { jsonRes } from '@/service/response'; -import { connectToDatabase, Model, Chat } from '@/service/mongo'; -import { authToken } from '@/service/utils/tools'; -import type { ModelSchema } from '@/types/mongoSchema'; - -/* 获取我的模型 */ -export default async function handler(req: NextApiRequest, res: NextApiResponse) { - try { - const { modelId } = req.query as { - modelId: string; - }; - const { authorization } = req.headers; - - if (!authorization) { - throw new Error('无权生成对话'); - } - - if (!modelId) { - throw new Error('缺少参数'); - } - - // 凭证校验 - const userId = await authToken(authorization); - - await connectToDatabase(); - - // 校验是否为用户的模型 - const model = await Model.findOne({ - _id: modelId, - userId - }); - - if (!model) { - throw new Error('无权使用该模型'); - } - - // 创建 chat 数据 - const response = await Chat.create({ - userId, - modelId, - content: [] - }); - - jsonRes(res, { - data: response._id // 即聊天框的 ID - }); - } catch (err) { - jsonRes(res, { - code: 500, - error: err - }); - } -} diff --git a/src/pages/api/chat/init.ts b/src/pages/api/chat/init.ts index 5b67ea5b9..7d119095a 100644 --- a/src/pages/api/chat/init.ts +++ b/src/pages/api/chat/init.ts @@ -1,9 +1,10 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { connectToDatabase, Chat } from '@/service/mongo'; -import type { ChatPopulate } from '@/types/mongoSchema'; +import { connectToDatabase, Chat, Model } from '@/service/mongo'; import type { InitChatResponse } from '@/api/response/chat'; import { authToken } from '@/service/utils/tools'; +import { ChatItemType } from '@/types/chat'; +import { authModel } from '@/service/utils/auth'; /* 初始化我的聊天框,需要身份验证 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -11,43 +12,46 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const { authorization } = req.headers; const userId = await authToken(authorization); - const { chatId } = req.query as { chatId: string }; + const { modelId, chatId } = req.query as { modelId: string; chatId: '' | string }; - if (!chatId) { + if (!modelId) { throw new Error('缺少参数'); } await connectToDatabase(); - // 获取 chat 数据 - const chat = await Chat.findOne({ - _id: chatId, - userId - }).populate({ - path: 'modelId', - options: { - strictPopulate: false - } - }); + // 获取 model 数据 + const { model } = await authModel(modelId, userId); - if (!chat) { - throw new Error('聊天框不存在'); + // 历史记录 + let history: ChatItemType[] = []; + + if (chatId) { + // 获取 chat 数据 + const chat = await Chat.findOne({ + _id: chatId, + userId, + modelId + }); + + if (!chat) { + throw new Error('聊天框不存在'); + } + + // filter 被 deleted 的内容 + history = chat.content.filter((item) => item.deleted !== true); } - // filter 掉被 deleted 的内容 - chat.content = chat.content.filter((item) => item.deleted !== true); - - const model = chat.modelId; jsonRes(res, { data: { - chatId: chat._id, - modelId: model._id, + chatId: chatId || '', + modelId: modelId, name: model.name, avatar: model.avatar, intro: model.intro, modelName: model.service.modelName, chatModel: model.service.chatModel, - history: chat.content + history } }); } catch (err) { diff --git a/src/pages/api/chat/saveChat.ts b/src/pages/api/chat/saveChat.ts index 57c51ed8d..bb62e144f 100644 --- a/src/pages/api/chat/saveChat.ts +++ b/src/pages/api/chat/saveChat.ts @@ -2,34 +2,53 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { ChatItemType } from '@/types/chat'; import { connectToDatabase, Chat } from '@/service/mongo'; +import { authModel } from '@/service/utils/auth'; +import { authToken } from '@/service/utils/tools'; /* 聊天内容存存储 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { chatId, prompts } = req.body as { - chatId: string; + const { chatId, modelId, prompts } = req.body as { + chatId: '' | string; + modelId: string; prompts: ChatItemType[]; }; - if (!chatId || !prompts) { + if (!prompts) { throw new Error('缺少参数'); } + const userId = await authToken(req.headers.authorization); + await connectToDatabase(); - // 存入库 - await Chat.findByIdAndUpdate(chatId, { - $push: { - content: { - $each: prompts.map((item) => ({ - obj: item.obj, - value: item.value - })) - } - }, - updateTime: new Date() - }); + const content = prompts.map((item) => ({ + obj: item.obj, + value: item.value + })); + // 没有 chatId, 创建一个对话 + if (!chatId) { + await authModel(modelId, userId); + const { _id } = await Chat.create({ + userId, + modelId, + content + }); + return jsonRes(res, { + data: _id + }); + } else { + // 已经有记录,追加入库 + await Chat.findByIdAndUpdate(chatId, { + $push: { + content: { + $each: content + } + }, + updateTime: new Date() + }); + } jsonRes(res); } catch (err) { jsonRes(res, { diff --git a/src/pages/api/chat/vectorGpt.ts b/src/pages/api/chat/vectorGpt.ts index 0018fec97..81dc8063a 100644 --- a/src/pages/api/chat/vectorGpt.ts +++ b/src/pages/api/chat/vectorGpt.ts @@ -1,6 +1,6 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase } from '@/service/mongo'; -import { authChat } from '@/service/utils/chat'; +import { authChat } from '@/service/utils/auth'; import { httpsAgent, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatItemType } from '@/types/chat'; @@ -35,29 +35,33 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) }); try { - const { chatId, prompt } = req.body as { + const { modelId, chatId, prompt } = req.body as { + modelId: string; + chatId: '' | string; prompt: ChatItemType; - chatId: string; }; const { authorization } = req.headers; - if (!chatId || !prompt) { + if (!modelId || !prompt) { throw new Error('缺少参数'); } await connectToDatabase(); let startTime = Date.now(); - const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization); + const { model, content, userApiKey, systemKey, userId } = await authChat({ + modelId, + chatId, + authorization + }); - const model: ModelSchema = chat.modelId; const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); if (!modelConstantsData) { throw new Error('模型加载异常'); } // 读取对话内容 - const prompts = [...chat.content, prompt]; + const prompts = [...content, prompt]; // 获取提示词的向量 const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({ diff --git a/src/pages/api/model/create.ts b/src/pages/api/model/create.ts index b7def6ae5..928d2163d 100644 --- a/src/pages/api/model/create.ts +++ b/src/pages/api/model/create.ts @@ -47,7 +47,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< userId, status: ModelStatusEnum.running, service: { - company: modelItem.serviceCompany, trainId: '', chatModel: ChatModelNameMap[modelItem.model], // 聊天时用的模型 modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作 diff --git a/src/pages/api/model/data/splitData.ts b/src/pages/api/model/data/splitData.ts index 423af8363..af9d0cf53 100644 --- a/src/pages/api/model/data/splitData.ts +++ b/src/pages/api/model/data/splitData.ts @@ -36,14 +36,14 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const textList: string[] = []; let splitText = ''; - /* 取 3k ~ 4K tokens 内容 */ + /* 取 2.5k ~ 3.5K tokens 内容 */ chunks.forEach((chunk) => { const tokens = encode(splitText + chunk).length; - if (tokens >= 4000) { - // 超过 4000,不要这块内容 + if (tokens >= 3500) { + // 超过 3500,不要这块内容 splitText && textList.push(splitText); splitText = chunk; - } else if (tokens >= 3000) { + } else if (tokens >= 2500) { // 超过 3000,取内容 splitText && textList.push(splitText + chunk); splitText = ''; diff --git a/src/pages/api/openapi/chat/chatGpt.ts b/src/pages/api/openapi/chat/chatGpt.ts index e96d37eea..d162c6c2a 100644 --- a/src/pages/api/openapi/chat/chatGpt.ts +++ b/src/pages/api/openapi/chat/chatGpt.ts @@ -1,6 +1,6 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase, Model } from '@/service/mongo'; -import { getOpenAIApi } from '@/service/utils/chat'; +import { getOpenAIApi } from '@/service/utils/auth'; import { httpsAgent, openaiChatFilter, authOpenApiKey } from '@/service/utils/tools'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatItemType } from '@/types/chat'; diff --git a/src/pages/api/openapi/chat/lafGpt.ts b/src/pages/api/openapi/chat/lafGpt.ts index 04a4c8194..8863e1f03 100644 --- a/src/pages/api/openapi/chat/lafGpt.ts +++ b/src/pages/api/openapi/chat/lafGpt.ts @@ -1,6 +1,6 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase, Model } from '@/service/mongo'; -import { getOpenAIApi } from '@/service/utils/chat'; +import { getOpenAIApi } from '@/service/utils/auth'; import { authOpenApiKey } from '@/service/utils/tools'; import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; diff --git a/src/pages/chat/components/SlideBar.tsx b/src/pages/chat/components/SlideBar.tsx index 42bd53764..d8ad06513 100644 --- a/src/pages/chat/components/SlideBar.tsx +++ b/src/pages/chat/components/SlideBar.tsx @@ -11,13 +11,6 @@ import { Flex, Divider, IconButton, - Modal, - ModalOverlay, - ModalContent, - ModalHeader, - ModalFooter, - ModalBody, - ModalCloseButton, useDisclosure, useColorMode, useColorModeValue @@ -29,8 +22,6 @@ import { useRouter } from 'next/router'; import { getToken } from '@/utils/user'; import MyIcon from '@/components/Icon'; import { useCopyData } from '@/utils/tools'; -import Markdown from '@/components/Markdown'; -import { getChatSiteId } from '@/api/chat'; import WxConcat from '@/components/WxConcat'; import { useMarkdown } from '@/hooks/useMarkdown'; @@ -42,7 +33,7 @@ const SlideBar = ({ }: { chatId: string; modelId: string; - resetChat: () => void; + resetChat: (modelId?: string, chatId?: string) => void; onClose: () => void; }) => { const router = useRouter(); @@ -86,7 +77,7 @@ const SlideBar = ({ : {})} onClick={() => { if (item.chatId === chatId) return; - router.replace(`/chat?chatId=${item.chatId}`); + resetChat(modelId, item.chatId); onClose(); }} > @@ -155,7 +146,7 @@ const SlideBar = ({ mb={4} mx={'auto'} leftIcon={} - onClick={resetChat} + onClick={() => resetChat()} > 新对话 @@ -194,7 +185,7 @@ const SlideBar = ({ : {})} onClick={async () => { if (item._id === modelId) return; - router.replace(`/chat?chatId=${await getChatSiteId(item._id)}`); + resetChat(item._id); onClose(); }} > @@ -260,49 +251,6 @@ const SlideBar = ({ /> - {/* 分享提示modal */} - - - - 分享对话 - - - - - - - - {getToken() && ( - - )} - - - - - {/* wx 联系 */} {isOpenWx && } diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index ec8f83e63..44220e597 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -1,7 +1,7 @@ import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react'; import { useRouter } from 'next/router'; import Image from 'next/image'; -import { getInitChatSiteInfo, getChatSiteId, delChatRecordByIndex, postSaveChat } from '@/api/chat'; +import { getInitChatSiteInfo, delChatRecordByIndex, postSaveChat } from '@/api/chat'; import type { InitChatResponse } from '@/api/response/chat'; import { ChatSiteItemType } from '@/types/chat'; import { @@ -41,18 +41,17 @@ interface ChatType extends InitChatResponse { history: ChatSiteItemType[]; } -const Chat = ({ chatId }: { chatId: string }) => { +const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { + const router = useRouter(); + const ChatBox = useRef(null); const TextareaDom = useRef(null); - const { toast } = useToast(); - const router = useRouter(); - // 中断请求 const controller = useRef(new AbortController()); const [chatData, setChatData] = useState({ - chatId: '', - modelId: '', + chatId, + modelId, name: '', avatar: '', intro: '', @@ -60,6 +59,7 @@ const Chat = ({ chatId }: { chatId: string }) => { modelName: '', history: [] }); // 聊天框整体数据 + const [inputVal, setInputVal] = useState(''); // 输入的内容 const isChatting = useMemo( @@ -68,6 +68,7 @@ const Chat = ({ chatId }: { chatId: string }) => { ); const { isOpen: isOpenSlider, onClose: onCloseSlider, onOpen: onOpenSlider } = useDisclosure(); + const { toast } = useToast(); const { copyData } = useCopyData(); const { isPc, media } = useScreen(); const { setLoading } = useGlobalStore(); @@ -108,19 +109,72 @@ const Chat = ({ chatId }: { chatId: string }) => { }, 100); }, []); - // 重载对话 - const resetChat = useCallback(async () => { - if (!chatData) return; - try { - router.replace(`/chat?chatId=${await getChatSiteId(chatData.modelId)}`); - } catch (error: any) { - toast({ - title: error?.message || '生成新对话失败', - status: 'warning' - }); - } - onCloseSlider(); - }, [chatData, onCloseSlider, router, toast]); + // 获取对话信息 + const loadChatInfo = useCallback( + async ({ + modelId, + chatId, + isLoading = false, + isScroll = false + }: { + modelId: string; + chatId: string; + isLoading?: boolean; + isScroll?: boolean; + }) => { + isLoading && setLoading(true); + try { + const res = await getInitChatSiteInfo(modelId, chatId); + setChatData({ + ...res, + history: res.history.map((item) => ({ + ...item, + status: 'finish' + })) + }); + if (isScroll && res.history.length > 0) { + setTimeout(() => { + scrollToBottom('auto'); + }, 2000); + } + } catch (e: any) { + toast({ + title: e?.message || '获取对话信息异常,请检查地址', + status: 'error', + isClosable: true, + duration: 5000 + }); + router.replace('/model/list'); + } + setLoading(false); + return null; + }, + [router, scrollToBottom, setLoading, toast] + ); + + // 重载新的对话 + const resetChat = useCallback( + async (modelId = chatData.modelId, chatId = '') => { + // 强制中断流 + controller.current?.abort(); + try { + router.replace(`/chat?modelId=${modelId}&chatId=${chatId}`); + loadChatInfo({ + modelId, + chatId, + isLoading: true, + isScroll: true + }); + } catch (error: any) { + toast({ + title: error?.message || '生成新对话失败', + status: 'warning' + }); + } + onCloseSlider(); + }, + [chatData.modelId, loadChatInfo, onCloseSlider, router, toast] + ); // gpt 对话 const gptChatPrompt = useCallback( @@ -132,6 +186,10 @@ const Chat = ({ chatId }: { chatId: string }) => { if (!urlMap[chatData.modelName]) return Promise.reject('找不到模型'); + // create abort obj + const abortSignal = new AbortController(); + controller.current = abortSignal; + const prompt = { obj: prompts.obj, value: prompts.value @@ -141,7 +199,8 @@ const Chat = ({ chatId }: { chatId: string }) => { url: urlMap[chatData.modelName], data: { prompt, - chatId + chatId, + modelId }, onMessage: (text: string) => { setChatData((state) => ({ @@ -156,12 +215,14 @@ const Chat = ({ chatId }: { chatId: string }) => { })); generatingMessage(); }, - abortSignal: controller.current + abortSignal }); + let id = ''; // 保存对话信息 try { - await postSaveChat({ + id = await postSaveChat({ + modelId, chatId, prompts: [ prompt, @@ -171,6 +232,9 @@ const Chat = ({ chatId }: { chatId: string }) => { } ] }); + if (id) { + router.replace(`/chat?modelId=${modelId}&chatId=${id}`); + } } catch (err) { toast({ title: '对话出现异常, 继续对话会导致上下文丢失,请刷新页面', @@ -183,6 +247,7 @@ const Chat = ({ chatId }: { chatId: string }) => { // 设置完成状态 setChatData((state) => ({ ...state, + chatId: id || state.chatId, // 如果有 Id,说明是新创建的对话 history: state.history.map((item, index) => { if (index !== state.history.length - 1) return item; return { @@ -192,7 +257,7 @@ const Chat = ({ chatId }: { chatId: string }) => { }) })); }, - [chatData.modelName, chatId, generatingMessage, toast] + [chatData.modelName, chatId, generatingMessage, modelId, router, toast] ); /** @@ -210,7 +275,7 @@ const Chat = ({ chatId }: { chatId: string }) => { // 去除空行 const val = inputVal.trim().replace(/\n\s*/g, '\n'); - if (!chatData?.modelId || !val) { + if (!val) { toast({ title: '内容为空', status: 'warning' @@ -271,12 +336,12 @@ const Chat = ({ chatId }: { chatId: string }) => { })); } }, [ - inputVal, - chatData, isChatting, + inputVal, + chatData.history, resetInputVal, - scrollToBottom, toast, + scrollToBottom, gptChatPrompt, pushChatHistory, chatId @@ -312,50 +377,22 @@ const Chat = ({ chatId }: { chatId: string }) => { ); // 初始化聊天框 - useQuery( - ['init', chatId], - () => { - setLoading(true); - return getInitChatSiteInfo(chatId); - }, - { - onSuccess(res) { - setChatData({ - ...res, - history: res.history.map((item) => ({ - ...item, - status: 'finish' - })) - }); - if (res.history.length > 0) { - setTimeout(() => { - scrollToBottom('auto'); - }, 2000); - } - }, - onError(e: any) { - toast({ - title: e?.message || '初始化异常,请检查地址', - status: 'error', - isClosable: true, - duration: 5000 - }); - router.push('/model/list'); - }, - onSettled() { - setLoading(false); - } - } + useQuery(['init'], () => + loadChatInfo({ + modelId, + chatId, + isLoading: true, + isScroll: true + }) ); // 更新流中断对象 useEffect(() => { - controller.current = new AbortController(); return () => { // eslint-disable-next-line react-hooks/exhaustive-deps controller.current?.abort(); }; - }, [chatId]); + }, []); return ( { @@ -399,7 +436,7 @@ const Chat = ({ chatId }: { chatId: string }) => { @@ -565,9 +602,10 @@ const Chat = ({ chatId }: { chatId: string }) => { export default Chat; export async function getServerSideProps(context: any) { - const chatId = context?.query?.chatId || 'noid'; + const modelId = context?.query?.modelId || ''; + const chatId = context?.query?.chatId || ''; return { - props: { chatId } + props: { modelId, chatId } }; } diff --git a/src/pages/index.tsx b/src/pages/index.tsx index 8e35ee43b..118421f53 100644 --- a/src/pages/index.tsx +++ b/src/pages/index.tsx @@ -1,5 +1,5 @@ import React, { useEffect } from 'react'; -import { Card } from '@chakra-ui/react'; +import { Card, Box, Link } from '@chakra-ui/react'; import Markdown from '@/components/Markdown'; import { useMarkdown } from '@/hooks/useMarkdown'; import { useRouter } from 'next/router'; @@ -15,9 +15,20 @@ const Home = () => { }, [inviterId]); return ( - - - + <> + + + + + + + {/* + 浙B2-20080101 + */} + + Made by FastGpt Team. + + ); }; diff --git a/src/pages/model/detail/index.tsx b/src/pages/model/detail/index.tsx index 1516681a9..2803c85e5 100644 --- a/src/pages/model/detail/index.tsx +++ b/src/pages/model/detail/index.tsx @@ -1,7 +1,6 @@ import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react'; import { useRouter } from 'next/router'; import { getModelById, delModelById, putModelTrainingStatus, putModelById } from '@/api/model'; -import { getChatSiteId } from '@/api/chat'; import type { ModelSchema } from '@/types/mongoSchema'; import { Card, Box, Flex, Button, Tag, Grid } from '@chakra-ui/react'; import { useToast } from '@/hooks/useToast'; @@ -70,14 +69,12 @@ const ModelDetail = ({ modelId }: { modelId: string }) => { const handlePreviewChat = useCallback(async () => { setLoading(true); try { - const chatId = await getChatSiteId(model._id); - - router.push(`/chat?chatId=${chatId}`); + router.push(`/chat?modelId=${modelId}`); } catch (err) { console.log('error->', err); } setLoading(false); - }, [setLoading, model, router]); + }, [setLoading, router, modelId]); /* 上传数据集,触发微调 */ // const startTraining = useCallback( diff --git a/src/pages/model/list/index.tsx b/src/pages/model/list/index.tsx index 6bc96a5ee..6c4d77fe5 100644 --- a/src/pages/model/list/index.tsx +++ b/src/pages/model/list/index.tsx @@ -1,6 +1,5 @@ import React, { useState, useCallback } from 'react'; import { Box, Button, Flex, Card } from '@chakra-ui/react'; -import { getChatSiteId } from '@/api/chat'; import type { ModelSchema } from '@/types/mongoSchema'; import { useRouter } from 'next/router'; import ModelTable from './components/ModelTable'; diff --git a/src/service/events/generateAbstract.ts b/src/service/events/generateAbstract.ts index beb8185ea..7ca3e322a 100644 --- a/src/service/events/generateAbstract.ts +++ b/src/service/events/generateAbstract.ts @@ -1,5 +1,5 @@ import { DataItem } from '@/service/mongo'; -import { getOpenAIApi } from '@/service/utils/chat'; +import { getOpenAIApi } from '@/service/utils/auth'; import { httpsAgent } from '@/service/utils/tools'; import { getOpenApiKey } from '../utils/openai'; import type { ChatCompletionRequestMessage } from 'openai'; diff --git a/src/service/events/generateQA.ts b/src/service/events/generateQA.ts index 7a9532d81..17cb54eb5 100644 --- a/src/service/events/generateQA.ts +++ b/src/service/events/generateQA.ts @@ -1,5 +1,5 @@ import { SplitData } from '@/service/mongo'; -import { getOpenAIApi } from '@/service/utils/chat'; +import { getOpenAIApi } from '@/service/utils/auth'; import { httpsAgent } from '@/service/utils/tools'; import { getOpenApiKey } from '../utils/openai'; import type { ChatCompletionRequestMessage } from 'openai'; diff --git a/src/service/events/pushBill.ts b/src/service/events/pushBill.ts index 227e1c92a..5d2ab78aa 100644 --- a/src/service/events/pushBill.ts +++ b/src/service/events/pushBill.ts @@ -14,7 +14,7 @@ export const pushChatBill = async ({ isPay: boolean; modelName: string; userId: string; - chatId?: string; + chatId?: '' | string; text: string; }) => { let billId; @@ -42,7 +42,7 @@ export const pushChatBill = async ({ userId, type: 'chat', modelName, - chatId, + chatId: chatId ? chatId : undefined, textLen: text.length, tokenLen: tokens, price diff --git a/src/service/models/model.ts b/src/service/models/model.ts index 1c3a00033..53506b494 100644 --- a/src/service/models/model.ts +++ b/src/service/models/model.ts @@ -53,11 +53,6 @@ const ModelSchema = new Schema({ } }, service: { - company: { - type: String, - required: true, - enum: ['openai'] - }, trainId: { // 训练时需要的 ID, 不能训练的模型没有这个值。 type: String, diff --git a/src/service/utils/auth.ts b/src/service/utils/auth.ts new file mode 100644 index 000000000..32ea80c52 --- /dev/null +++ b/src/service/utils/auth.ts @@ -0,0 +1,70 @@ +import { Configuration, OpenAIApi } from 'openai'; +import { Chat, Model } from '../mongo'; +import type { ModelSchema } from '@/types/mongoSchema'; +import { authToken } from './tools'; +import { getOpenApiKey } from './openai'; +import type { ChatItemType } from '@/types/chat'; + +export const getOpenAIApi = (apiKey: string) => { + const configuration = new Configuration({ + apiKey + }); + + return new OpenAIApi(configuration, undefined); +}; + +// 模型使用权校验 +export const authModel = async (modelId: string, userId: string) => { + // 获取 model 数据 + const model = await Model.findById(modelId); + if (!model) { + return Promise.reject('模型不存在'); + } + // 凭证校验 + if (userId !== String(model.userId)) { + return Promise.reject('无权使用该模型'); + } + return { model }; +}; + +// 获取对话校验 +export const authChat = async ({ + modelId, + chatId, + authorization +}: { + modelId: string; + chatId: '' | string; + authorization?: string; +}) => { + const userId = await authToken(authorization); + + // 获取 model 数据 + const { model } = await authModel(modelId, userId); + + // 聊天内容 + let content: ChatItemType[] = []; + + if (chatId) { + // 获取 chat 数据 + const chat = await Chat.findById(chatId); + + if (!chat) { + return Promise.reject('对话不存在'); + } + + // filter 掉被 deleted 的内容 + content = chat.content.filter((item) => item.deleted !== true); + } + + // 获取 user 的 apiKey + const { userApiKey, systemKey } = await getOpenApiKey(userId); + + return { + userApiKey, + systemKey, + content, + userId, + model + }; +}; diff --git a/src/service/utils/chat.ts b/src/service/utils/chat.ts deleted file mode 100644 index 616200a26..000000000 --- a/src/service/utils/chat.ts +++ /dev/null @@ -1,46 +0,0 @@ -import { Configuration, OpenAIApi } from 'openai'; -import { Chat } from '../mongo'; -import type { ChatPopulate } from '@/types/mongoSchema'; -import { authToken } from './tools'; -import { getOpenApiKey } from './openai'; - -export const getOpenAIApi = (apiKey: string) => { - const configuration = new Configuration({ - apiKey - }); - - return new OpenAIApi(configuration, undefined); -}; - -export const authChat = async (chatId: string, authorization?: string) => { - // 获取 chat 数据 - const chat = await Chat.findById(chatId).populate({ - path: 'modelId', - options: { - strictPopulate: false - } - }); - - if (!chat || !chat.modelId || !chat.userId) { - return Promise.reject('模型不存在'); - } - - // 凭证校验 - const userId = await authToken(authorization); - if (userId !== String(chat.userId._id)) { - return Promise.reject('无权使用该对话'); - } - - // 获取 user 的 apiKey - const { user, userApiKey, systemKey } = await getOpenApiKey(chat.userId as unknown as string); - - // filter 掉被 deleted 的内容 - chat.content = chat.content.filter((item) => item.deleted !== true); - - return { - userApiKey, - systemKey, - chat, - userId: user._id - }; -}; diff --git a/src/service/utils/openai.ts b/src/service/utils/openai.ts index 4c0fcdf8a..753b1f1dd 100644 --- a/src/service/utils/openai.ts +++ b/src/service/utils/openai.ts @@ -1,7 +1,7 @@ import type { NextApiResponse } from 'next'; import type { PassThrough } from 'stream'; import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser'; -import { getOpenAIApi } from '@/service/utils/chat'; +import { getOpenAIApi } from '@/service/utils/auth'; import { httpsAgent } from './tools'; import { User } from '../models/user'; import { formatPrice } from '@/utils/user'; diff --git a/src/store/chat.ts b/src/store/chat.ts index 1b2a42cb7..b06147056 100644 --- a/src/store/chat.ts +++ b/src/store/chat.ts @@ -2,7 +2,6 @@ import { create } from 'zustand'; import { devtools, persist } from 'zustand/middleware'; import { immer } from 'zustand/middleware/immer'; import type { HistoryItem } from '@/types/chat'; -import { getChatSiteId } from '@/api/chat'; type Props = { chatHistory: HistoryItem[]; @@ -10,7 +9,6 @@ type Props = { updateChatHistory: (chatId: string, title: string) => void; removeChatHistoryByWindowId: (chatId: string) => void; clearHistory: () => void; - generateChatWindow: (modelId: string) => Promise; }; export const useChatStore = create()( devtools( @@ -40,9 +38,6 @@ export const useChatStore = create()( set((state) => { state.chatHistory = []; }); - }, - generateChatWindow(modelId: string) { - return getChatSiteId(modelId); } })), { diff --git a/src/types/mongoSchema.d.ts b/src/types/mongoSchema.d.ts index 3103665f0..b1cc8d100 100644 --- a/src/types/mongoSchema.d.ts +++ b/src/types/mongoSchema.d.ts @@ -7,8 +7,6 @@ import { } from '@/constants/model'; import type { DataType } from './data'; -export type ServiceName = 'openai'; - export interface UserModelSchema { _id: string; username: string; @@ -46,7 +44,6 @@ export interface ModelSchema { mode: `${ModelVectorSearchModeEnum}`; }; service: { - company: ServiceName; trainId: string; // 训练的模型,训练后就是训练的模型id chatModel: string; // 聊天时用的模型,训练后就是训练的模型 modelName: `${ChatModelNameEnum}`; // 底层模型名称,不会变 @@ -86,7 +83,6 @@ export interface ModelSplitDataSchema { export interface TrainingSchema { _id: string; - serviceName: ServiceName; tuneId: string; modelId: string; status: `${TrainingStatusEnum}`;