diff --git a/client/data/config.json b/client/data/config.json index 4d344c422..c7b58d8be 100644 --- a/client/data/config.json +++ b/client/data/config.json @@ -45,19 +45,17 @@ "defaultSystem": "" } ], - "QAModels": [ - { - "model": "gpt-3.5-turbo-16k", - "name": "GPT35-16k", - "maxToken": 16000, - "price": 0 - } - ], "VectorModels": [ { "model": "text-embedding-ada-002", "name": "Embedding-2", "price": 0 } - ] + ], + "QAModel": { + "model": "gpt-3.5-turbo-16k", + "name": "GPT35-16k", + "maxToken": 16000, + "price": 0 + } } diff --git a/client/src/api/plugins/kb.ts b/client/src/api/plugins/kb.ts index 137a578f1..9ef075426 100644 --- a/client/src/api/plugins/kb.ts +++ b/client/src/api/plugins/kb.ts @@ -12,20 +12,14 @@ import { } from '@/pages/api/openapi/kb/searchTest'; import { Response as KbDataItemType } from '@/pages/api/plugins/kb/data/getDataById'; import { Props as UpdateDataProps } from '@/pages/api/openapi/kb/updateData'; - -export type KbUpdateParams = { - id: string; - name: string; - tags: string; - avatar: string; -}; +import type { KbUpdateParams, CreateKbParams } from '../request/kb'; /* knowledge base */ export const getKbList = () => GET(`/plugins/kb/list`); export const getKbById = (id: string) => GET(`/plugins/kb/detail?id=${id}`); -export const postCreateKb = (data: { name: string }) => POST(`/plugins/kb/create`, data); +export const postCreateKb = (data: CreateKbParams) => POST(`/plugins/kb/create`, data); export const putKbById = (data: KbUpdateParams) => PUT(`/plugins/kb/update`, data); diff --git a/client/src/api/request/kb.d.ts b/client/src/api/request/kb.d.ts new file mode 100644 index 000000000..1ea227e7c --- /dev/null +++ b/client/src/api/request/kb.d.ts @@ -0,0 +1,12 @@ +export type KbUpdateParams = { + id: string; + name: string; + tags: string; + avatar: string; +}; +export type CreateKbParams = { + name: string; + tags: string[]; + avatar: string; + vectorModel: string; +}; diff --git a/client/src/api/user.ts b/client/src/api/user.ts index 075996d5d..812fd4e47 100644 --- a/client/src/api/user.ts +++ b/client/src/api/user.ts @@ -25,7 +25,7 @@ export const postRegister = ({ username: string; code: string; password: string; - inviterId: string; + inviterId?: string; }) => POST(`/plusApi/user/account/register`, { username, diff --git a/client/src/pages/api/openapi/kb/pushData.ts b/client/src/pages/api/openapi/kb/pushData.ts index a668637d0..50341b2af 100644 --- a/client/src/pages/api/openapi/kb/pushData.ts +++ b/client/src/pages/api/openapi/kb/pushData.ts @@ -1,6 +1,6 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { connectToDatabase, TrainingData } from '@/service/mongo'; +import { connectToDatabase, TrainingData, KB } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; import { authKb } from '@/service/utils/auth'; import { withNextCors } from '@/service/utils/tools'; @@ -14,7 +14,6 @@ export type DateItemType = { a: string; q: string; source?: string }; export type Props = { kbId: string; data: DateItemType[]; - model: string; mode: `${TrainingModeEnum}`; prompt?: string; }; @@ -30,23 +29,12 @@ const modeMaxToken = { export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { kbId, data, mode, prompt, model } = req.body as Props; + const { kbId, data, mode, prompt } = req.body as Props; - if (!kbId || !Array.isArray(data) || !model) { + if (!kbId || !Array.isArray(data)) { throw new Error('缺少参数'); } - // auth model - if (mode === TrainingModeEnum.qa && !global.qaModels.find((item) => item.model === model)) { - throw new Error('不支持的 QA 拆分模型'); - } - if ( - mode === TrainingModeEnum.index && - !global.vectorModels.find((item) => item.model === model) - ) { - throw new Error('不支持的向量生成模型'); - } - await connectToDatabase(); // 凭证校验 @@ -58,8 +46,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex data, userId, mode, - prompt, - model + prompt }) }); } catch (err) { @@ -75,8 +62,7 @@ export async function pushDataToKb({ kbId, data, mode, - prompt, - model + prompt }: { userId: string } & Props): Promise { await authKb({ userId, @@ -152,17 +138,24 @@ export async function pushDataToKb({ .filter((item) => item.status === 'fulfilled') .map((item: any) => item.value); + const vectorModel = await (async () => { + if (mode === TrainingModeEnum.index) { + return (await KB.findById(kbId, 'vectorModel'))?.vectorModel || global.vectorModels[0].model; + } + return global.vectorModels[0].model; + })(); + // 插入记录 await TrainingData.insertMany( insertData.map((item) => ({ q: item.q, a: item.a, - model, source: item.source, userId, kbId, mode, - prompt + prompt, + vectorModel })) ); diff --git a/client/src/pages/api/plugins/kb/create.ts b/client/src/pages/api/plugins/kb/create.ts index f50c25d86..73524c30b 100644 --- a/client/src/pages/api/plugins/kb/create.ts +++ b/client/src/pages/api/plugins/kb/create.ts @@ -2,15 +2,13 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { connectToDatabase, KB } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; +import type { CreateKbParams } from '@/api/request/kb'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { name, tags } = req.body as { - name: string; - tags: string[]; - }; + const { name, tags, avatar, vectorModel } = req.body as CreateKbParams; - if (!name) { + if (!name || !vectorModel) { throw new Error('缺少参数'); } @@ -22,7 +20,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< const { _id } = await KB.create({ name, userId, - tags + tags, + vectorModel, + avatar }); jsonRes(res, { data: _id }); diff --git a/client/src/pages/api/plugins/kb/detail.ts b/client/src/pages/api/plugins/kb/detail.ts index 9e5a6547e..fc5920450 100644 --- a/client/src/pages/api/plugins/kb/detail.ts +++ b/client/src/pages/api/plugins/kb/detail.ts @@ -2,6 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { connectToDatabase, KB } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; +import { getModel } from '@/service/utils/data'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -33,7 +34,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< avatar: data.avatar, name: data.name, userId: data.userId, - model: data.model, + vectorModelName: getModel(data.vectorModel)?.name || 'Unknown', tags: data.tags.join(' ') } }); diff --git a/client/src/pages/api/plugins/kb/list.ts b/client/src/pages/api/plugins/kb/list.ts index f2cff7032..58cbfd3b4 100644 --- a/client/src/pages/api/plugins/kb/list.ts +++ b/client/src/pages/api/plugins/kb/list.ts @@ -3,6 +3,7 @@ import { jsonRes } from '@/service/response'; import { connectToDatabase, KB } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; import { KbListItemType } from '@/types/plugin'; +import { getModel } from '@/service/utils/data'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -15,7 +16,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< { userId }, - '_id avatar name tags' + '_id avatar name tags vectorModel' ).sort({ updateTime: -1 }); const data = await Promise.all( @@ -23,7 +24,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< _id: item._id, avatar: item.avatar, name: item.name, - tags: item.tags + tags: item.tags, + vectorModelName: getModel(item.vectorModel)?.name || 'UnKnow' })) ); diff --git a/client/src/pages/api/system/getInitData.ts b/client/src/pages/api/system/getInitData.ts index 4244e4e40..ca7efe03d 100644 --- a/client/src/pages/api/system/getInitData.ts +++ b/client/src/pages/api/system/getInitData.ts @@ -10,7 +10,7 @@ import { export type InitDateResponse = { chatModels: ChatModelItemType[]; - qaModels: QAModelItemType[]; + qaModel: QAModelItemType; vectorModels: VectorModelItemType[]; feConfigs: FeConfigsType; }; @@ -23,7 +23,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) data: { feConfigs: global.feConfigs, chatModels: global.chatModels, - qaModels: global.qaModels, + qaModel: global.qaModel, vectorModels: global.vectorModels } }); @@ -69,14 +69,13 @@ const defaultChatModels = [ price: 0 } ]; -const defaultQAModels = [ - { - model: 'gpt-3.5-turbo-16k', - name: 'GPT35-16k', - maxToken: 16000, - price: 0 - } -]; +const defaultQAModel = { + model: 'gpt-3.5-turbo-16k', + name: 'GPT35-16k', + maxToken: 16000, + price: 0 +}; + const defaultVectorModels = [ { model: 'text-embedding-ada-002', @@ -95,7 +94,7 @@ export async function getInitConfig() { global.systemEnv = res.SystemParams || defaultSystemEnv; global.feConfigs = res.FeConfig || defaultFeConfigs; global.chatModels = res.ChatModels || defaultChatModels; - global.qaModels = res.QAModels || defaultQAModels; + global.qaModel = res.QAModel || defaultQAModel; global.vectorModels = res.VectorModels || defaultVectorModels; } catch (error) { setDefaultData(); @@ -107,6 +106,6 @@ export function setDefaultData() { global.systemEnv = defaultSystemEnv; global.feConfigs = defaultFeConfigs; global.chatModels = defaultChatModels; - global.qaModels = defaultQAModels; + global.qaModel = defaultQAModel; global.vectorModels = defaultVectorModels; } diff --git a/client/src/pages/api/user/account/gitLogin.ts b/client/src/pages/api/user/account/gitLogin.ts index 5ec281764..e3f8f97f4 100644 --- a/client/src/pages/api/user/account/gitLogin.ts +++ b/client/src/pages/api/user/account/gitLogin.ts @@ -100,9 +100,8 @@ export async function registerUser({ username, avatar, password: nanoid(), - inviterId + inviterId: inviterId ? inviterId : undefined }); - console.log(response, '-=-=-='); // 根据 id 获取用户信息 const user = await User.findById(response._id); diff --git a/client/src/pages/app/list/component/CreateModal.tsx b/client/src/pages/app/list/component/CreateModal.tsx index f9b7122eb..b5bace511 100644 --- a/client/src/pages/app/list/component/CreateModal.tsx +++ b/client/src/pages/app/list/component/CreateModal.tsx @@ -101,8 +101,8 @@ const CreateModal = ({ onClose, onSuccess }: { onClose: () => void; onSuccess: ( { for (let i = 0; i < chunks.length; i += step) { const { insertLen } = await postKbDataFromList({ kbId, - model, data: chunks.slice(i, i + step), mode: TrainingModeEnum.index }); diff --git a/client/src/pages/kb/detail/components/Import/Csv.tsx b/client/src/pages/kb/detail/components/Import/Csv.tsx index cd47e86ab..db110b639 100644 --- a/client/src/pages/kb/detail/components/Import/Csv.tsx +++ b/client/src/pages/kb/detail/components/Import/Csv.tsx @@ -43,7 +43,6 @@ const CsvImport = ({ kbId }: { kbId: string }) => { for (let i = 0; i < chunks.length; i += step) { const { insertLen } = await postKbDataFromList({ kbId, - model, data: chunks.slice(i, i + step), mode: TrainingModeEnum.index }); diff --git a/client/src/pages/kb/detail/components/Import/Manual.tsx b/client/src/pages/kb/detail/components/Import/Manual.tsx index 3ca1f7eb9..ddb842930 100644 --- a/client/src/pages/kb/detail/components/Import/Manual.tsx +++ b/client/src/pages/kb/detail/components/Import/Manual.tsx @@ -1,11 +1,9 @@ -import React, { useCallback, useState } from 'react'; -import { Box, type BoxProps, Flex, Textarea, useTheme, Button } from '@chakra-ui/react'; -import MyRadio from '@/components/Radio/index'; +import React from 'react'; +import { Box, Textarea, Button } from '@chakra-ui/react'; import { useForm } from 'react-hook-form'; import { useToast } from '@/hooks/useToast'; import { useRequest } from '@/hooks/useRequest'; import { getErrText } from '@/utils/tools'; -import { vectorModelList } from '@/store/static'; import { postKbDataFromList } from '@/api/plugins/kb'; import { TrainingModeEnum } from '@/constants/plugin'; @@ -35,7 +33,6 @@ const ManualImport = ({ kbId }: { kbId: string }) => { }; const { insertLen } = await postKbDataFromList({ kbId, - model: vectorModelList[0].model, mode: TrainingModeEnum.index, data: [data] }); diff --git a/client/src/pages/kb/detail/components/Import/QA.tsx b/client/src/pages/kb/detail/components/Import/QA.tsx index 611fcd665..d3fc1c609 100644 --- a/client/src/pages/kb/detail/components/Import/QA.tsx +++ b/client/src/pages/kb/detail/components/Import/QA.tsx @@ -7,7 +7,7 @@ import { postKbDataFromList } from '@/api/plugins/kb'; import { splitText2Chunks } from '@/utils/file'; import { getErrText } from '@/utils/tools'; import { formatPrice } from '@/utils/user'; -import { qaModelList } from '@/store/static'; +import { qaModel } from '@/store/static'; import MyIcon from '@/components/Icon'; import CloseIcon from '@/components/Icon/close'; import DeleteIcon, { hoverDeleteStyles } from '@/components/Icon/delete'; @@ -20,9 +20,8 @@ import { useRouter } from 'next/router'; const fileExtension = '.txt, .doc, .docx, .pdf, .md'; const QAImport = ({ kbId }: { kbId: string }) => { - const model = qaModelList[0]?.model; - const unitPrice = qaModelList[0]?.price || 3; - const chunkLen = qaModelList[0].maxToken * 0.45; + const unitPrice = qaModel.price || 3; + const chunkLen = qaModel.maxToken * 0.45; const theme = useTheme(); const router = useRouter(); const { toast } = useToast(); @@ -58,7 +57,6 @@ const QAImport = ({ kbId }: { kbId: string }) => { for (let i = 0; i < chunks.length; i += step) { const { insertLen } = await postKbDataFromList({ kbId, - model, data: chunks.slice(i, i + step), mode: TrainingModeEnum.qa, prompt: prompt || '下面是一段长文本' diff --git a/client/src/pages/kb/detail/components/Info.tsx b/client/src/pages/kb/detail/components/Info.tsx index 998c744c2..34c316c02 100644 --- a/client/src/pages/kb/detail/components/Info.tsx +++ b/client/src/pages/kb/detail/components/Info.tsx @@ -7,7 +7,7 @@ import React, { ForwardedRef } from 'react'; import { useRouter } from 'next/router'; -import { Box, Flex, Button, FormControl, IconButton, Input, Card } from '@chakra-ui/react'; +import { Box, Flex, Button, FormControl, IconButton, Input } from '@chakra-ui/react'; import { QuestionOutlineIcon, DeleteIcon } from '@chakra-ui/icons'; import { delKbById, putKbById } from '@/api/plugins/kb'; import { useSelectFile } from '@/hooks/useSelectFile'; @@ -17,8 +17,6 @@ import { useConfirm } from '@/hooks/useConfirm'; import { UseFormReturn } from 'react-hook-form'; import { compressImg } from '@/utils/file'; import type { KbItemType } from '@/types/plugin'; -import { vectorModelList } from '@/store/static'; -import MySelect from '@/components/Select'; import Avatar from '@/components/Avatar'; import Tag from '@/components/Tag'; import MyTooltip from '@/components/MyTooltip'; @@ -138,7 +136,6 @@ const Info = ( useImperativeHandle(ref, () => ({ initInput: (tags: string) => { - console.log(tags); if (InputRef.current) { InputRef.current.value = tags; } @@ -153,20 +150,27 @@ const Info = ( {kbDetail._id} + + + 索引模型 + + {getValues('vectorModelName')} + 知识库头像 - + + + @@ -180,27 +184,9 @@ const Info = ( })} /> - - - 索引模型 - - - ({ - label: item.name, - value: item.model - }))} - onchange={(res) => { - setValue('model', res); - }} - /> - - - 分类标签 + 标签 @@ -208,6 +194,7 @@ const Info = ( { @@ -226,7 +213,6 @@ const Info = ( ))} - + + + + + + ); +}; + +export default CreateModal; diff --git a/client/src/pages/kb/list/index.tsx b/client/src/pages/kb/list/index.tsx index 506d1b9a9..fa1173730 100644 --- a/client/src/pages/kb/list/index.tsx +++ b/client/src/pages/kb/list/index.tsx @@ -1,5 +1,14 @@ import React, { useCallback } from 'react'; -import { Box, Card, Flex, Grid, useTheme, Button, IconButton } from '@chakra-ui/react'; +import { + Box, + Card, + Flex, + Grid, + useTheme, + Button, + IconButton, + useDisclosure +} from '@chakra-ui/react'; import { useRouter } from 'next/router'; import { useUserStore } from '@/store/user'; import PageContainer from '@/components/PageContainer'; @@ -7,12 +16,14 @@ import { useConfirm } from '@/hooks/useConfirm'; import { AddIcon } from '@chakra-ui/icons'; import { useQuery } from '@tanstack/react-query'; import { useToast } from '@/hooks/useToast'; -import { delKbById, postCreateKb } from '@/api/plugins/kb'; -import { useRequest } from '@/hooks/useRequest'; +import { delKbById } from '@/api/plugins/kb'; import Avatar from '@/components/Avatar'; import MyIcon from '@/components/Icon'; import Tag from '@/components/Tag'; import { serviceSideProps } from '@/utils/i18n'; +import dynamic from 'next/dynamic'; + +const CreateModal = dynamic(() => import('./component/CreateModal'), { ssr: false }); const Kb = () => { const theme = useTheme(); @@ -24,7 +35,13 @@ const Kb = () => { }); const { myKbList, loadKbList, setKbList } = useUserStore(); - useQuery(['loadKbList'], () => loadKbList()); + const { + isOpen: isOpenCreateModal, + onOpen: onOpenCreateModal, + onClose: onCloseCreateModal + } = useDisclosure(); + + const { refetch } = useQuery(['loadKbList'], () => loadKbList()); /* 点击删除 */ const onclickDelKb = useCallback( @@ -46,32 +63,13 @@ const Kb = () => { [toast, setKbList, myKbList] ); - /* create a new kb and router to it */ - const { mutate: onclickCreate, isLoading } = useRequest({ - mutationFn: async () => { - const name = `知识库${myKbList.length + 1}`; - const id = await postCreateKb({ name }); - return id; - }, - successToast: '创建成功', - errorToast: '创建知识库出现意外', - onSuccess(id) { - router.push(`/kb/detail?kbId=${id}`); - } - }); - return ( 我的知识库 - @@ -141,6 +139,10 @@ const Kb = () => { ))} + + + {kb.vectorModelName} + ))} @@ -153,6 +155,7 @@ const Kb = () => { )} + {isOpenCreateModal && } ); }; diff --git a/client/src/pages/login/components/RegisterForm.tsx b/client/src/pages/login/components/RegisterForm.tsx index 20f4a25b5..6caf616f7 100644 --- a/client/src/pages/login/components/RegisterForm.tsx +++ b/client/src/pages/login/components/RegisterForm.tsx @@ -57,7 +57,7 @@ const RegisterForm = ({ setPageType, loginSuccess }: Props) => { username, code, password, - inviterId: localStorage.getItem('inviterId') || '' + inviterId: localStorage.getItem('inviterId') || undefined }) ); toast({ diff --git a/client/src/pages/login/provider.tsx b/client/src/pages/login/provider.tsx index ef3a3e7b2..6f719d958 100644 --- a/client/src/pages/login/provider.tsx +++ b/client/src/pages/login/provider.tsx @@ -46,7 +46,7 @@ const provider = ({ code }: { code: string }) => { if (loginStore.provider === 'git') { return gitLogin({ code, - inviterId: localStorage.getItem('inviterId') || '' + inviterId: localStorage.getItem('inviterId') || undefined }); } return null; diff --git a/client/src/service/events/generateQA.ts b/client/src/service/events/generateQA.ts index 98adddfcc..2a98b1605 100644 --- a/client/src/service/events/generateQA.ts +++ b/client/src/service/events/generateQA.ts @@ -1,5 +1,5 @@ import { TrainingData } from '@/service/mongo'; -import { pushSplitDataBill } from '@/service/events/pushBill'; +import { pushQABill } from '@/service/events/pushBill'; import { pushDataToKb } from '@/pages/api/openapi/kb/pushData'; import { TrainingModeEnum } from '@/constants/plugin'; import { ERROR_ENUM } from '../errorCode'; @@ -60,14 +60,13 @@ export async function generateQA(): Promise { // 请求 chatgpt 获取回答 const response = await Promise.all( [data.q].map((text) => { - const modelTokenLimit = - chatModels.find((item) => item.model === data.model)?.contextMaxToken || 16000; + const modelTokenLimit = global.qaModel.maxToken || 16000; const messages: ChatCompletionRequestMessage[] = [ { role: 'system', - content: `你是出题人. -${data.prompt || '我会发送一段长文本'}. -从中提取出 25 个问题和答案. 答案详细完整. 按下面格式返回: + content: `你是出题人,${ + data.prompt || '我会发送一段长文本' + },请从中提取出 25 个问题和答案. 答案详细完整,并按下面格式返回: Q1: A1: Q2: @@ -88,7 +87,7 @@ A2: return chatAPI .createChatCompletion( { - model: data.model, + model: global.qaModel.model, temperature: 0.8, messages, stream: false, @@ -106,10 +105,9 @@ A2: const result = formatSplitText(answer || ''); // 格式化后的QA对 console.log(`split result length: `, result.length); // 计费 - pushSplitDataBill({ + pushQABill({ userId: data.userId, totalTokens, - model: data.model, appName: 'QA 拆分' }); return { @@ -135,7 +133,6 @@ A2: source: data.source })), userId, - model: global.vectorModels[0].model, mode: TrainingModeEnum.index }); diff --git a/client/src/service/events/generateVector.ts b/client/src/service/events/generateVector.ts index ec7b8ea42..a20c85424 100644 --- a/client/src/service/events/generateVector.ts +++ b/client/src/service/events/generateVector.ts @@ -38,7 +38,7 @@ export async function generateVector(): Promise { q: 1, a: 1, source: 1, - model: 1 + vectorModel: 1 }); // task preemption @@ -61,7 +61,7 @@ export async function generateVector(): Promise { // 生成词向量 const { vectors } = await getVector({ - model: data.model, + model: data.vectorModel, input: dataItems.map((item) => item.q), userId }); diff --git a/client/src/service/events/pushBill.ts b/client/src/service/events/pushBill.ts index 63bfdc8b1..83a2a53d6 100644 --- a/client/src/service/events/pushBill.ts +++ b/client/src/service/events/pushBill.ts @@ -76,13 +76,11 @@ export const updateShareChatBill = async ({ } }; -export const pushSplitDataBill = async ({ +export const pushQABill = async ({ userId, totalTokens, - model, appName }: { - model: string; userId: string; totalTokens: number; appName: string; @@ -95,7 +93,7 @@ export const pushSplitDataBill = async ({ await connectToDatabase(); // 获取模型单价格, 都是用 gpt35 拆分 - const unitPrice = global.chatModels.find((item) => item.model === model)?.price || 3; + const unitPrice = global.qaModel.price || 3; // 计算价格 const total = unitPrice * totalTokens; diff --git a/client/src/service/models/kb.ts b/client/src/service/models/kb.ts index f31562a58..9c1856300 100644 --- a/client/src/service/models/kb.ts +++ b/client/src/service/models/kb.ts @@ -19,7 +19,7 @@ const kbSchema = new Schema({ type: String, required: true }, - model: { + vectorModel: { type: String, required: true, default: 'text-embedding-ada-002' diff --git a/client/src/service/models/trainingData.ts b/client/src/service/models/trainingData.ts index 476ee1505..93a8582c2 100644 --- a/client/src/service/models/trainingData.ts +++ b/client/src/service/models/trainingData.ts @@ -28,9 +28,10 @@ const TrainingDataSchema = new Schema({ enum: Object.keys(TrainingTypeMap), required: true }, - model: { + vectorModel: { type: String, - required: true + required: true, + default: 'text-embedding-ada-002' }, prompt: { // qa split prompt diff --git a/client/src/service/moduleDispatch/chat/oneapi.ts b/client/src/service/moduleDispatch/chat/oneapi.ts index 4e4e7e9dd..6609c5b55 100644 --- a/client/src/service/moduleDispatch/chat/oneapi.ts +++ b/client/src/service/moduleDispatch/chat/oneapi.ts @@ -181,7 +181,7 @@ export const dispatchChatCompletion = async (props: Record): Promis tokens: totalTokens, question: userChatInput, answer: answerText, - maxToken, + maxToken: max_tokens, quoteList: filterQuoteQA, completeMessages }, @@ -237,7 +237,7 @@ function getChatMessages({ }) { const limitText = (() => { if (limitPrompt) - return `Use the provided content delimited by triple quotes to answer questions.${limitPrompt}`; + return `Use the provided content delimited by triple quotes to answer questions. ${limitPrompt}`; if (quotePrompt && !limitPrompt) { return `Use the provided content delimited by triple quotes to answer questions.Your task is to answer the question using only the provided content. If the content does not contain the information needed to answer this question then simply write: "你的问题没有在知识库中体现".`; } diff --git a/client/src/service/utils/data.ts b/client/src/service/utils/data.ts index b31992a23..6a4411484 100644 --- a/client/src/service/utils/data.ts +++ b/client/src/service/utils/data.ts @@ -4,11 +4,7 @@ export const getChatModel = (model?: string) => { export const getVectorModel = (model?: string) => { return global.vectorModels.find((item) => item.model === model); }; -export const getQAModel = (model?: string) => { - return global.qaModels.find((item) => item.model === model); -}; + export const getModel = (model?: string) => { - return [...global.chatModels, ...global.vectorModels, ...global.qaModels].find( - (item) => item.model === model - ); + return [...global.chatModels, ...global.vectorModels].find((item) => item.model === model); }; diff --git a/client/src/store/static.ts b/client/src/store/static.ts index 2b757c009..6e28194aa 100644 --- a/client/src/store/static.ts +++ b/client/src/store/static.ts @@ -9,7 +9,12 @@ import { delay } from '@/utils/tools'; import { FeConfigsType } from '@/types'; export let chatModelList: ChatModelItemType[] = []; -export let qaModelList: QAModelItemType[] = []; +export let qaModel: QAModelItemType = { + model: 'gpt-3.5-turbo-16k', + name: 'GPT35-16k', + maxToken: 16000, + price: 0 +}; export let vectorModelList: VectorModelItemType[] = []; export let feConfigs: FeConfigsType = {}; @@ -20,7 +25,7 @@ export const clientInitData = async (): Promise => { const res = await getInitData(); chatModelList = res.chatModels; - qaModelList = res.qaModels; + qaModel = res.qaModel; vectorModelList = res.vectorModels; feConfigs = res.feConfigs; diff --git a/client/src/types/index.d.ts b/client/src/types/index.d.ts index 51c302eda..bc870c1d7 100644 --- a/client/src/types/index.d.ts +++ b/client/src/types/index.d.ts @@ -51,7 +51,7 @@ declare global { var feConfigs: FeConfigsType; var systemEnv: SystemEnvType; var chatModels: ChatModelItemType[]; - var qaModels: QAModelItemType[]; + var qaModel: QAModelItemType; var vectorModels: VectorModelItemType[]; interface Window { diff --git a/client/src/types/mongoSchema.d.ts b/client/src/types/mongoSchema.d.ts index 028c3febe..723d9f6ff 100644 --- a/client/src/types/mongoSchema.d.ts +++ b/client/src/types/mongoSchema.d.ts @@ -72,7 +72,7 @@ export interface TrainingDataSchema { kbId: string; expireAt: Date; lockTime: Date; - model: string; + vectorModel: string; mode: `${TrainingModeEnum}`; prompt: string; q: string; @@ -164,7 +164,7 @@ export interface kbSchema { updateTime: Date; avatar: string; name: string; - model: string; + vectorModel: string; tags: string[]; } diff --git a/client/src/types/plugin.d.ts b/client/src/types/plugin.d.ts index f43fb795b..f1284c4cf 100644 --- a/client/src/types/plugin.d.ts +++ b/client/src/types/plugin.d.ts @@ -7,10 +7,15 @@ export type KbListItemType = { avatar: string; name: string; tags: string[]; + vectorModelName: string; }; /* kb type */ -export interface KbItemType extends kbSchema { - totalData: number; +export interface KbItemType { + _id: string; + avatar: string; + name: string; + userId: string; + vectorModelName: string; tags: string; } diff --git a/docSite/content/docs/installation/docker.md b/docSite/content/docs/installation/docker.md index 61c1e3c2c..9fc7afac9 100644 --- a/docSite/content/docs/installation/docker.md +++ b/docSite/content/docs/installation/docker.md @@ -213,14 +213,12 @@ docker-compose up -d "defaultSystem": "" } ], - "QAModels": [ - { - "model": "gpt-3.5-turbo-16k", - "name": "GPT35-16k", - "maxToken": 16000, - "price": 0 - } - ], + "QAModel": { + "model": "gpt-3.5-turbo-16k", + "name": "GPT35-16k", + "maxToken": 16000, + "price": 0 + }, "VectorModels": [ { "model": "text-embedding-ada-002", diff --git a/docSite/content/docs/installation/reference/configuration.md b/docSite/content/docs/installation/reference/configuration.md index fb7872669..3f96bea0e 100644 --- a/docSite/content/docs/installation/reference/configuration.md +++ b/docSite/content/docs/installation/reference/configuration.md @@ -96,14 +96,12 @@ weight: 751 "defaultSystem": "" } ], - "QAModels": [ - { - "model": "gpt-3.5-turbo-16k", - "name": "GPT35-16k", - "maxToken": 16000, - "price": 0 - } - ], + "QAModel": { + "model": "gpt-3.5-turbo-16k", + "name": "GPT35-16k", + "maxToken": 16000, + "price": 0 + }, "VectorModels": [ { "model": "text-embedding-ada-002", diff --git a/files/deploy/fastgpt/config.json b/files/deploy/fastgpt/config.json index 64ffebc40..b4d76a58d 100644 --- a/files/deploy/fastgpt/config.json +++ b/files/deploy/fastgpt/config.json @@ -46,19 +46,17 @@ "defaultSystem": "" } ], - "QAModels": [ - { - "model": "gpt-3.5-turbo-16k", - "name": "GPT35-16k", - "maxToken": 16000, - "price": 0 - } - ], "VectorModels": [ { "model": "text-embedding-ada-002", "name": "Embedding-2", "price": 0 } - ] + ], + "QAModel": { + "model": "gpt-3.5-turbo-16k", + "name": "GPT35-16k", + "maxToken": 16000, + "price": 0 + } }