From 5bf95bd846bbcbd5af0b3b832ed8038c8580c20a Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Wed, 17 May 2023 22:09:38 +0800 Subject: [PATCH] feat: model related kb --- src/api/model.ts | 1 - src/constants/model.ts | 11 +- src/constants/theme.ts | 4 +- src/pages/api/chat/chat.ts | 2 +- src/pages/api/chat/shareChat/chat.ts | 2 +- src/pages/api/model/update.ts | 7 +- src/pages/api/openapi/chat/chat.ts | 2 +- .../detail/components/ModelEditForm.tsx | 131 +++++++++++++++--- src/pages/model/components/detail/index.tsx | 65 +++++---- src/service/models/model.ts | 35 +---- src/service/plugins/searchKb.ts | 2 +- src/service/utils/auth.ts | 9 +- src/service/utils/chat/openai.ts | 3 +- src/store/user.ts | 8 +- src/types/model.d.ts | 4 +- src/types/mongoSchema.d.ts | 9 +- 16 files changed, 178 insertions(+), 117 deletions(-) diff --git a/src/api/model.ts b/src/api/model.ts index c9a6042b3..a5a3443bd 100644 --- a/src/api/model.ts +++ b/src/api/model.ts @@ -2,7 +2,6 @@ import { GET, POST, DELETE, PUT } from './request'; import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelUpdateParams, ShareModelItem } from '@/types/model'; import { RequestPaging } from '../types/index'; -import { Obj2Query } from '@/utils/tools'; import type { ModelListResponse } from './response/model'; /** diff --git a/src/constants/model.ts b/src/constants/model.ts index 2ee2b579e..6bb9e6d20 100644 --- a/src/constants/model.ts +++ b/src/constants/model.ts @@ -1,6 +1,6 @@ import { getSystemModelList } from '@/api/system'; -import type { ModelSchema } from '@/types/mongoSchema'; import type { ShareChatEditType } from '@/types/model'; +import type { ModelSchema } from '@/types/mongoSchema'; export const embeddingModel = 'text-embedding-ada-002'; export type EmbeddingModelType = 'text-embedding-ada-002'; @@ -142,7 +142,7 @@ export const defaultModel: ModelSchema = { status: ModelStatusEnum.pending, updateTime: Date.now(), chat: { - useKb: false, + relatedKbs: [], searchMode: ModelVectorSearchModeEnum.hightSimilarity, systemPrompt: '', temperature: 0, @@ -153,13 +153,6 @@ export const defaultModel: ModelSchema = { isShareDetail: false, intro: '', collection: 0 - }, - security: { - domain: ['*'], - contextMaxLen: 1, - contentMaxLen: 1, - expiredTime: 9999, - maxLoadAmount: 1 } }; diff --git a/src/constants/theme.ts b/src/constants/theme.ts index 46d290d05..aa1eb4b51 100644 --- a/src/constants/theme.ts +++ b/src/constants/theme.ts @@ -1,6 +1,6 @@ import { extendTheme, defineStyleConfig, ComponentStyleConfig } from '@chakra-ui/react'; // @ts-ignore -import { modalAnatomy, switchAnatomy, selectAnatomy } from '@chakra-ui/anatomy'; +import { modalAnatomy, switchAnatomy, selectAnatomy, checkboxAnatomy } from '@chakra-ui/anatomy'; // @ts-ignore import { createMultiStyleConfigHelpers } from '@chakra-ui/styled-system'; @@ -11,6 +11,8 @@ const { definePartsStyle: switchPart, defineMultiStyleConfig: switchMultiStyle } createMultiStyleConfigHelpers(switchAnatomy.keys); const { definePartsStyle: selectPart, defineMultiStyleConfig: selectMultiStyle } = createMultiStyleConfigHelpers(selectAnatomy.keys); +const { definePartsStyle: checkboxPart, defineMultiStyleConfig: checkboxMultiStyle } = + createMultiStyleConfigHelpers(checkboxAnatomy.keys); // modal 弹窗 const ModalTheme = defineMultiStyleConfig({ diff --git a/src/pages/api/chat/chat.ts b/src/pages/api/chat/chat.ts index 703266ed5..6acc914ff 100644 --- a/src/pages/api/chat/chat.ts +++ b/src/pages/api/chat/chat.ts @@ -54,7 +54,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const prompts = [...content, prompt]; // 使用了知识库搜索 - if (model.chat.useKb) { + if (model.chat.relatedKbs.length > 0) { const { code, searchPrompts } = await searchKb({ userOpenAiKey, prompts, diff --git a/src/pages/api/chat/shareChat/chat.ts b/src/pages/api/chat/shareChat/chat.ts index 3014c565c..1602c81c3 100644 --- a/src/pages/api/chat/shareChat/chat.ts +++ b/src/pages/api/chat/shareChat/chat.ts @@ -50,7 +50,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const modelConstantsData = ChatModelMap[model.chat.chatModel]; // 使用了知识库搜索 - if (model.chat.useKb) { + if (model.chat.relatedKbs.length > 0) { const { code, searchPrompts } = await searchKb({ userOpenAiKey, prompts, diff --git a/src/pages/api/model/update.ts b/src/pages/api/model/update.ts index 41e301000..e4101ce46 100644 --- a/src/pages/api/model/update.ts +++ b/src/pages/api/model/update.ts @@ -9,10 +9,10 @@ import { authModel } from '@/service/utils/auth'; /* 获取我的模型 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { name, avatar, chat, share, security } = req.body as ModelUpdateParams; + const { name, avatar, chat, share } = req.body as ModelUpdateParams; const { modelId } = req.query as { modelId: string }; - if (!name || !chat || !security || !modelId) { + if (!name || !chat || !modelId) { throw new Error('参数错误'); } @@ -38,8 +38,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< chat, 'share.isShare': share.isShare, 'share.isShareDetail': share.isShareDetail, - 'share.intro': share.intro, - security + 'share.intro': share.intro } ); diff --git a/src/pages/api/openapi/chat/chat.ts b/src/pages/api/openapi/chat/chat.ts index a49891ad2..e4dbcc07f 100644 --- a/src/pages/api/openapi/chat/chat.ts +++ b/src/pages/api/openapi/chat/chat.ts @@ -70,7 +70,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const modelConstantsData = ChatModelMap[model.chat.chatModel]; // 使用了知识库搜索 - if (model.chat.useKb) { + if (model.chat.relatedKbs.length > 0) { const similarity = ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22; const { code, searchPrompts } = await searchKb({ diff --git a/src/pages/model/components/detail/components/ModelEditForm.tsx b/src/pages/model/components/detail/components/ModelEditForm.tsx index 922e44efb..c41812016 100644 --- a/src/pages/model/components/detail/components/ModelEditForm.tsx +++ b/src/pages/model/components/detail/components/ModelEditForm.tsx @@ -32,11 +32,9 @@ import { Th, Td, TableContainer, - IconButton + Checkbox } from '@chakra-ui/react'; -import { DeleteIcon } from '@chakra-ui/icons'; import { QuestionOutlineIcon } from '@chakra-ui/icons'; -import type { ModelSchema } from '@/types/mongoSchema'; import { useForm, UseFormReturn } from 'react-hook-form'; import { ChatModelMap, ModelVectorSearchModeMap, getChatModelList } from '@/constants/model'; import { formatPrice } from '@/utils/user'; @@ -49,9 +47,12 @@ import { getShareChatList, createShareChat, delShareChatById } from '@/api/chat' import { useRouter } from 'next/router'; import { defaultShareChat } from '@/constants/model'; import type { ShareChatEditType } from '@/types/model'; +import type { ModelSchema } from '@/types/mongoSchema'; import { formatTimeToChatTime, useCopyData } from '@/utils/tools'; import MyIcon from '@/components/Icon'; import { useGlobalStore } from '@/store/global'; +import { useUserStore } from '@/store/user'; +import type { KbItemType } from '@/types/plugin'; const ModelEditForm = ({ formHooks, @@ -62,10 +63,11 @@ const ModelEditForm = ({ isOwner: boolean; handleDelModel: () => void; }) => { - const { toast } = useToast(); const { modelId } = useRouter().query as { modelId: string }; - const { setLoading } = useGlobalStore(); const [refresh, setRefresh] = useState(false); + const { toast } = useToast(); + const { setLoading } = useGlobalStore(); + const { loadKbList } = useUserStore(); const { openConfirm, ConfirmChild } = useConfirm({ content: '确认删除该AI助手?' @@ -86,6 +88,11 @@ const ModelEditForm = ({ onOpen: onOpenCreateShareChat, onClose: onCloseCreateShareChat } = useDisclosure(); + const { + isOpen: isOpenKbSelect, + onOpen: onOpenKbSelect, + onClose: onCloseKbSelect + } = useDisclosure(); const { File, onOpen: onOpenSelectFile } = useSelectFile({ fileType: '.jpg,.png', multiple: false @@ -153,11 +160,41 @@ ${e.password ? `密码为: ${e.password}` : ''}`; ] ); + // format share used token const formatTokens = (tokens: number) => { if (tokens < 10000) return tokens; return `${(tokens / 10000).toFixed(2)}万`; }; + // init kb select list + const { data: kbList = [] } = useQuery(['loadKbList'], () => loadKbList()); + const RenderSelectedKbList = useCallback(() => { + const kbs = getValues('chat.relatedKbs').map((id) => kbList.find((kb) => kb._id === id)); + + return ( + <> + {kbs.map((item) => + item ? ( + + + + + {item.name} + + + + ) : null + )} + + ); + }, [getValues, kbList]); + return ( <> {/* basic info */} @@ -292,18 +329,7 @@ ${e.password ? `密码为: ${e.password}` : ''}`; - - 知识库搜索 - { - setValue('chat.useKb', !getValues('chat.useKb')); - setRefresh(!refresh); - }} - /> - - {getValues('chat.useKb') && ( + {getValues('chat.relatedKbs').length > 0 && ( 搜索模式  @@ -339,7 +365,9 @@ ${e.password ? `密码为: ${e.password}` : ''}`; 分享设置 - 模型分享: + + 模型分享: + @@ -350,7 +378,8 @@ ${e.password ? `密码为: ${e.password}` : ''}`; setRefresh(!refresh); }} /> - + + 分享模型细节: @@ -376,8 +405,22 @@ ${e.password ? `密码为: ${e.password}` : ''}`; - {/* shareChat */} + + 关联的知识库 + + + + + {/* shareChat */} + 免登录聊天窗口 @@ -410,7 +453,7 @@ ${e.password ? `密码为: ${e.password}` : ''}`; 最大上下文 tokens消耗 最后使用时间 - + 操作 @@ -539,6 +582,52 @@ ${e.password ? `密码为: ${e.password}` : ''}`; + {/* select kb modal */} + + + + 选择关联的知识库 + + + {kbList.map((item) => ( + + { + const ids = getValues('chat.relatedKbs'); + // toggle to true + if (e.target.checked) { + setValue('chat.relatedKbs', ids.concat(item._id)); + } else { + const i = ids.findIndex((id) => id === item._id); + ids.splice(i, 1); + setValue('chat.relatedKbs', ids); + } + setRefresh(!refresh); + }} + > + + + + {item.name} + + + + + ))} + + + + + + + diff --git a/src/pages/model/components/detail/index.tsx b/src/pages/model/components/detail/index.tsx index 772f476b5..a9138b4e7 100644 --- a/src/pages/model/components/detail/index.tsx +++ b/src/pages/model/components/detail/index.tsx @@ -2,10 +2,9 @@ import React, { useCallback, useState, useMemo, useEffect } from 'react'; import { useRouter } from 'next/router'; import { delModelById, putModelById } from '@/api/model'; import type { ModelSchema } from '@/types/mongoSchema'; -import { Card, Box, Flex, Button, Tag, Grid } from '@chakra-ui/react'; +import { Card, Box, Flex, Button, Grid } from '@chakra-ui/react'; import { useToast } from '@/hooks/useToast'; import { useForm } from 'react-hook-form'; -import { formatModelStatus } from '@/constants/model'; import { useQuery } from '@tanstack/react-query'; import { useUserStore } from '@/store/user'; import { useLoading } from '@/hooks/useLoading'; @@ -18,7 +17,7 @@ const ModelDetail = ({ modelId, isPc }: { modelId: string; isPc: boolean }) => { const { Loading, setIsLoading } = useLoading(); const [btnLoading, setBtnLoading] = useState(false); - const formHooks = useForm({ + const formHooks = useForm({ defaultValues: modelDetail }); @@ -84,13 +83,9 @@ const ModelDetail = ({ modelId, isPc }: { modelId: string; isPc: boolean }) => { name: data.name, avatar: data.avatar || '/icon/logo.png', chat: data.chat, - share: data.share, - security: data.security - }); - toast({ - title: '更新成功', - status: 'success' + share: data.share }); + refreshModel.updateModelDetail(data); } catch (err: any) { toast({ @@ -120,18 +115,16 @@ const ModelDetail = ({ modelId, isPc }: { modelId: string; isPc: boolean }) => { }); }, [formHooks.formState.errors, toast]); + const saveUpdateModel = useCallback( + () => formHooks.handleSubmit(saveSubmitSuccess, saveSubmitError)(), + [formHooks, saveSubmitError, saveSubmitSuccess] + ); + useEffect(() => { - router.prefetch('/chat'); - - window.onbeforeunload = (e) => { - e.preventDefault(); - e.returnValue = '内容已修改,确认离开页面吗?'; - }; - return () => { - window.onbeforeunload = null; + saveUpdateModel(); }; - }, [router]); + }, []); return canRead ? ( @@ -142,13 +135,6 @@ const ModelDetail = ({ modelId, isPc }: { modelId: string; isPc: boolean }) => { {modelDetail.name} - - {formatModelStatus[modelDetail.status].text} - @@ -169,9 +166,6 @@ const ModelDetail = ({ modelId, isPc }: { modelId: string; isPc: boolean }) => { {modelDetail.name} - - {formatModelStatus[modelDetail.status].text} - diff --git a/src/service/models/model.ts b/src/service/models/model.ts index e415ad49d..20a8caf46 100644 --- a/src/service/models/model.ts +++ b/src/service/models/model.ts @@ -31,10 +31,10 @@ const ModelSchema = new Schema({ default: () => new Date() }, chat: { - useKb: { - // use knowledge base to search - type: Boolean, - default: false + relatedKbs: { + type: [Schema.Types.ObjectId], + ref: 'kb', + default: [] }, searchMode: { // knowledge base search mode @@ -79,33 +79,6 @@ const ModelSchema = new Schema({ type: Number, default: 0 } - }, - security: { - type: { - domain: { - type: [String], - default: ['*'] - }, - contextMaxLen: { - type: Number, - default: 20 - }, - contentMaxLen: { - type: Number, - default: 4000 - }, - expiredTime: { - type: Number, - default: 1, - set: (val: number) => val * (60 * 60 * 1000) - }, - maxLoadAmount: { - // 负数代表不限制 - type: Number, - default: -1 - } - }, - default: {} } }); diff --git a/src/service/plugins/searchKb.ts b/src/service/plugins/searchKb.ts index 00159f9c5..0c5c7442f 100644 --- a/src/service/plugins/searchKb.ts +++ b/src/service/plugins/searchKb.ts @@ -48,7 +48,7 @@ export const searchKb = async ({ where: [ ['status', ModelDataStatusEnum.ready], 'AND', - ['model_id', model._id], + `kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`, 'AND', `vector <=> '[${promptVector}]' < ${similarity}` ], diff --git a/src/service/utils/auth.ts b/src/service/utils/auth.ts index 0e86563f2..71c4f6b24 100644 --- a/src/service/utils/auth.ts +++ b/src/service/utils/auth.ts @@ -34,6 +34,13 @@ export const authToken = (req: NextApiRequest): Promise => { }); }; +export const getOpenAiKey = () => { + // 纯字符串类型 + const keys = process.env.OPENAIKEY?.split(',') || []; + const i = Math.floor(Math.random() * keys.length); + return keys[i] || (process.env.OPENAIKEY as string); +}; + /* 获取 api 请求的 key */ export const getApiKey = async ({ model, @@ -52,7 +59,7 @@ export const getApiKey = async ({ const keyMap = { [OpenAiChatEnum.GPT35]: { userOpenAiKey: user.openaiKey || '', - systemAuthKey: process.env.OPENAIKEY as string + systemAuthKey: getOpenAiKey() as string }, [OpenAiChatEnum.GPT4]: { userOpenAiKey: user.openaiKey || '', diff --git a/src/service/utils/chat/openai.ts b/src/service/utils/chat/openai.ts index 4ee556b93..98b2347ae 100644 --- a/src/service/utils/chat/openai.ts +++ b/src/service/utils/chat/openai.ts @@ -7,6 +7,7 @@ import { adaptChatItem_openAI } from '@/utils/chat/openai'; import { modelToolMap } from '@/utils/chat'; import { ChatCompletionType, ChatContextFilter, StreamResponseType } from './index'; import { ChatRoleEnum } from '@/constants/chat'; +import { getOpenAiKey } from '../auth'; export const getOpenAIApi = (apiKey: string) => { const configuration = new Configuration({ @@ -27,7 +28,7 @@ export const openaiCreateEmbedding = async ({ userId: string; textArr: string[]; }) => { - const systemAuthKey = process.env.OPENAIKEY as string; + const systemAuthKey = getOpenAiKey(); // 获取 chatAPI const chatAPI = getOpenAIApi(userOpenAiKey || systemAuthKey); diff --git a/src/store/user.ts b/src/store/user.ts index 62ccbc2e1..c09b1145a 100644 --- a/src/store/user.ts +++ b/src/store/user.ts @@ -2,7 +2,6 @@ import { create } from 'zustand'; import { devtools, persist } from 'zustand/middleware'; import { immer } from 'zustand/middleware/immer'; import type { UserType, UserUpdateParams } from '@/types/user'; -import type { ModelSchema } from '@/types/mongoSchema'; import { getMyModels, getModelById } from '@/api/model'; import { formatPrice } from '@/utils/user'; import { getTokenLogin } from '@/api/user'; @@ -11,6 +10,7 @@ import { ModelListItemType } from '@/types/model'; import { KbItemType } from '@/types/plugin'; import { getKbList } from '@/api/plugins/kb'; import { defaultKbDetail } from '@/constants/kb'; +import type { ModelSchema } from '@/types/mongoSchema'; type State = { userInfo: UserType | null; @@ -34,7 +34,7 @@ type State = { lastKbId: string; setLastKbId: (id: string) => void; myKbList: KbItemType[]; - loadKbList: (init?: boolean) => Promise; + loadKbList: (init?: boolean) => Promise; KbDetail: KbItemType; getKbDetail: (id: string) => KbItemType; }; @@ -123,12 +123,12 @@ export const useUserStore = create()( }, myKbList: [], async loadKbList(init = false) { - if (get().myKbList.length > 0 && !init) return null; + if (get().myKbList.length > 0 && !init) return get().myKbList; const res = await getKbList(); set((state) => { state.myKbList = res; }); - return null; + return res; }, KbDetail: defaultKbDetail, getKbDetail(id: string) { diff --git a/src/types/model.d.ts b/src/types/model.d.ts index 26b22aed7..b5402924e 100644 --- a/src/types/model.d.ts +++ b/src/types/model.d.ts @@ -1,5 +1,6 @@ import { ModelStatusEnum } from '@/constants/model'; -import type { ModelSchema } from './mongoSchema'; +import type { ModelSchema, kbSchema } from './mongoSchema'; +import { ChatModelType, ModelVectorSearchModeEnum } from '@/constants/model'; export type ModelListItemType = { _id: string; @@ -13,7 +14,6 @@ export interface ModelUpdateParams { avatar: string; chat: ModelSchema['chat']; share: ModelSchema['share']; - security: ModelSchema['security']; } export interface ShareModelItem { diff --git a/src/types/mongoSchema.d.ts b/src/types/mongoSchema.d.ts index 5d9e54fe1..a2f8de210 100644 --- a/src/types/mongoSchema.d.ts +++ b/src/types/mongoSchema.d.ts @@ -38,7 +38,7 @@ export interface ModelSchema { status: `${ModelStatusEnum}`; updateTime: number; chat: { - useKb: boolean; + relatedKbs: string[]; searchMode: `${ModelVectorSearchModeEnum}`; systemPrompt: string; temperature: number; @@ -50,13 +50,6 @@ export interface ModelSchema { intro: string; collection: number; }; - security: { - domain: string[]; - contextMaxLen: number; - contentMaxLen: number; - expiredTime: number; - maxLoadAmount: number; - }; } export interface ModelPopulate extends ModelSchema {