feat: config vector model and qa model

This commit is contained in:
archer
2023-08-25 15:00:51 +08:00
parent a9970dd694
commit 6d93059e25
35 changed files with 337 additions and 196 deletions

View File

@@ -45,19 +45,17 @@
"defaultSystem": "" "defaultSystem": ""
} }
], ],
"QAModels": [
{
"model": "gpt-3.5-turbo-16k",
"name": "GPT35-16k",
"maxToken": 16000,
"price": 0
}
],
"VectorModels": [ "VectorModels": [
{ {
"model": "text-embedding-ada-002", "model": "text-embedding-ada-002",
"name": "Embedding-2", "name": "Embedding-2",
"price": 0 "price": 0
} }
] ],
"QAModel": {
"model": "gpt-3.5-turbo-16k",
"name": "GPT35-16k",
"maxToken": 16000,
"price": 0
}
} }

View File

@@ -12,20 +12,14 @@ import {
} from '@/pages/api/openapi/kb/searchTest'; } from '@/pages/api/openapi/kb/searchTest';
import { Response as KbDataItemType } from '@/pages/api/plugins/kb/data/getDataById'; import { Response as KbDataItemType } from '@/pages/api/plugins/kb/data/getDataById';
import { Props as UpdateDataProps } from '@/pages/api/openapi/kb/updateData'; import { Props as UpdateDataProps } from '@/pages/api/openapi/kb/updateData';
import type { KbUpdateParams, CreateKbParams } from '../request/kb';
export type KbUpdateParams = {
id: string;
name: string;
tags: string;
avatar: string;
};
/* knowledge base */ /* knowledge base */
export const getKbList = () => GET<KbListItemType[]>(`/plugins/kb/list`); export const getKbList = () => GET<KbListItemType[]>(`/plugins/kb/list`);
export const getKbById = (id: string) => GET<KbItemType>(`/plugins/kb/detail?id=${id}`); export const getKbById = (id: string) => GET<KbItemType>(`/plugins/kb/detail?id=${id}`);
export const postCreateKb = (data: { name: string }) => POST<string>(`/plugins/kb/create`, data); export const postCreateKb = (data: CreateKbParams) => POST<string>(`/plugins/kb/create`, data);
export const putKbById = (data: KbUpdateParams) => PUT(`/plugins/kb/update`, data); export const putKbById = (data: KbUpdateParams) => PUT(`/plugins/kb/update`, data);

12
client/src/api/request/kb.d.ts vendored Normal file
View File

@@ -0,0 +1,12 @@
export type KbUpdateParams = {
id: string;
name: string;
tags: string;
avatar: string;
};
export type CreateKbParams = {
name: string;
tags: string[];
avatar: string;
vectorModel: string;
};

View File

@@ -25,7 +25,7 @@ export const postRegister = ({
username: string; username: string;
code: string; code: string;
password: string; password: string;
inviterId: string; inviterId?: string;
}) => }) =>
POST<ResLogin>(`/plusApi/user/account/register`, { POST<ResLogin>(`/plusApi/user/account/register`, {
username, username,

View File

@@ -1,6 +1,6 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, TrainingData } from '@/service/mongo'; import { connectToDatabase, TrainingData, KB } from '@/service/mongo';
import { authUser } from '@/service/utils/auth'; import { authUser } from '@/service/utils/auth';
import { authKb } from '@/service/utils/auth'; import { authKb } from '@/service/utils/auth';
import { withNextCors } from '@/service/utils/tools'; import { withNextCors } from '@/service/utils/tools';
@@ -14,7 +14,6 @@ export type DateItemType = { a: string; q: string; source?: string };
export type Props = { export type Props = {
kbId: string; kbId: string;
data: DateItemType[]; data: DateItemType[];
model: string;
mode: `${TrainingModeEnum}`; mode: `${TrainingModeEnum}`;
prompt?: string; prompt?: string;
}; };
@@ -30,23 +29,12 @@ const modeMaxToken = {
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
const { kbId, data, mode, prompt, model } = req.body as Props; const { kbId, data, mode, prompt } = req.body as Props;
if (!kbId || !Array.isArray(data) || !model) { if (!kbId || !Array.isArray(data)) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
// auth model
if (mode === TrainingModeEnum.qa && !global.qaModels.find((item) => item.model === model)) {
throw new Error('不支持的 QA 拆分模型');
}
if (
mode === TrainingModeEnum.index &&
!global.vectorModels.find((item) => item.model === model)
) {
throw new Error('不支持的向量生成模型');
}
await connectToDatabase(); await connectToDatabase();
// 凭证校验 // 凭证校验
@@ -58,8 +46,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
data, data,
userId, userId,
mode, mode,
prompt, prompt
model
}) })
}); });
} catch (err) { } catch (err) {
@@ -75,8 +62,7 @@ export async function pushDataToKb({
kbId, kbId,
data, data,
mode, mode,
prompt, prompt
model
}: { userId: string } & Props): Promise<Response> { }: { userId: string } & Props): Promise<Response> {
await authKb({ await authKb({
userId, userId,
@@ -152,17 +138,24 @@ export async function pushDataToKb({
.filter((item) => item.status === 'fulfilled') .filter((item) => item.status === 'fulfilled')
.map<DateItemType>((item: any) => item.value); .map<DateItemType>((item: any) => item.value);
const vectorModel = await (async () => {
if (mode === TrainingModeEnum.index) {
return (await KB.findById(kbId, 'vectorModel'))?.vectorModel || global.vectorModels[0].model;
}
return global.vectorModels[0].model;
})();
// 插入记录 // 插入记录
await TrainingData.insertMany( await TrainingData.insertMany(
insertData.map((item) => ({ insertData.map((item) => ({
q: item.q, q: item.q,
a: item.a, a: item.a,
model,
source: item.source, source: item.source,
userId, userId,
kbId, kbId,
mode, mode,
prompt prompt,
vectorModel
})) }))
); );

View File

@@ -2,15 +2,13 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, KB } from '@/service/mongo'; import { connectToDatabase, KB } from '@/service/mongo';
import { authUser } from '@/service/utils/auth'; import { authUser } from '@/service/utils/auth';
import type { CreateKbParams } from '@/api/request/kb';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
const { name, tags } = req.body as { const { name, tags, avatar, vectorModel } = req.body as CreateKbParams;
name: string;
tags: string[];
};
if (!name) { if (!name || !vectorModel) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
@@ -22,7 +20,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
const { _id } = await KB.create({ const { _id } = await KB.create({
name, name,
userId, userId,
tags tags,
vectorModel,
avatar
}); });
jsonRes(res, { data: _id }); jsonRes(res, { data: _id });

View File

@@ -2,6 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, KB } from '@/service/mongo'; import { connectToDatabase, KB } from '@/service/mongo';
import { authUser } from '@/service/utils/auth'; import { authUser } from '@/service/utils/auth';
import { getModel } from '@/service/utils/data';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
@@ -33,7 +34,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
avatar: data.avatar, avatar: data.avatar,
name: data.name, name: data.name,
userId: data.userId, userId: data.userId,
model: data.model, vectorModelName: getModel(data.vectorModel)?.name || 'Unknown',
tags: data.tags.join(' ') tags: data.tags.join(' ')
} }
}); });

View File

@@ -3,6 +3,7 @@ import { jsonRes } from '@/service/response';
import { connectToDatabase, KB } from '@/service/mongo'; import { connectToDatabase, KB } from '@/service/mongo';
import { authUser } from '@/service/utils/auth'; import { authUser } from '@/service/utils/auth';
import { KbListItemType } from '@/types/plugin'; import { KbListItemType } from '@/types/plugin';
import { getModel } from '@/service/utils/data';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
@@ -15,7 +16,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
{ {
userId userId
}, },
'_id avatar name tags' '_id avatar name tags vectorModel'
).sort({ updateTime: -1 }); ).sort({ updateTime: -1 });
const data = await Promise.all( const data = await Promise.all(
@@ -23,7 +24,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
_id: item._id, _id: item._id,
avatar: item.avatar, avatar: item.avatar,
name: item.name, name: item.name,
tags: item.tags tags: item.tags,
vectorModelName: getModel(item.vectorModel)?.name || 'UnKnow'
})) }))
); );

View File

@@ -10,7 +10,7 @@ import {
export type InitDateResponse = { export type InitDateResponse = {
chatModels: ChatModelItemType[]; chatModels: ChatModelItemType[];
qaModels: QAModelItemType[]; qaModel: QAModelItemType;
vectorModels: VectorModelItemType[]; vectorModels: VectorModelItemType[];
feConfigs: FeConfigsType; feConfigs: FeConfigsType;
}; };
@@ -23,7 +23,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
data: { data: {
feConfigs: global.feConfigs, feConfigs: global.feConfigs,
chatModels: global.chatModels, chatModels: global.chatModels,
qaModels: global.qaModels, qaModel: global.qaModel,
vectorModels: global.vectorModels vectorModels: global.vectorModels
} }
}); });
@@ -69,14 +69,13 @@ const defaultChatModels = [
price: 0 price: 0
} }
]; ];
const defaultQAModels = [ const defaultQAModel = {
{ model: 'gpt-3.5-turbo-16k',
model: 'gpt-3.5-turbo-16k', name: 'GPT35-16k',
name: 'GPT35-16k', maxToken: 16000,
maxToken: 16000, price: 0
price: 0 };
}
];
const defaultVectorModels = [ const defaultVectorModels = [
{ {
model: 'text-embedding-ada-002', model: 'text-embedding-ada-002',
@@ -95,7 +94,7 @@ export async function getInitConfig() {
global.systemEnv = res.SystemParams || defaultSystemEnv; global.systemEnv = res.SystemParams || defaultSystemEnv;
global.feConfigs = res.FeConfig || defaultFeConfigs; global.feConfigs = res.FeConfig || defaultFeConfigs;
global.chatModels = res.ChatModels || defaultChatModels; global.chatModels = res.ChatModels || defaultChatModels;
global.qaModels = res.QAModels || defaultQAModels; global.qaModel = res.QAModel || defaultQAModel;
global.vectorModels = res.VectorModels || defaultVectorModels; global.vectorModels = res.VectorModels || defaultVectorModels;
} catch (error) { } catch (error) {
setDefaultData(); setDefaultData();
@@ -107,6 +106,6 @@ export function setDefaultData() {
global.systemEnv = defaultSystemEnv; global.systemEnv = defaultSystemEnv;
global.feConfigs = defaultFeConfigs; global.feConfigs = defaultFeConfigs;
global.chatModels = defaultChatModels; global.chatModels = defaultChatModels;
global.qaModels = defaultQAModels; global.qaModel = defaultQAModel;
global.vectorModels = defaultVectorModels; global.vectorModels = defaultVectorModels;
} }

View File

@@ -100,9 +100,8 @@ export async function registerUser({
username, username,
avatar, avatar,
password: nanoid(), password: nanoid(),
inviterId inviterId: inviterId ? inviterId : undefined
}); });
console.log(response, '-=-=-=');
// 根据 id 获取用户信息 // 根据 id 获取用户信息
const user = await User.findById(response._id); const user = await User.findById(response._id);

View File

@@ -101,8 +101,8 @@ const CreateModal = ({ onClose, onSuccess }: { onClose: () => void; onSuccess: (
<Avatar <Avatar
flexShrink={0} flexShrink={0}
src={getValues('avatar')} src={getValues('avatar')}
w={['32px', '36px']} w={['28px', '32px']}
h={['32px', '36px']} h={['28px', '32px']}
cursor={'pointer'} cursor={'pointer'}
borderRadius={'md'} borderRadius={'md'}
onClick={onOpenSelectFile} onClick={onOpenSelectFile}

View File

@@ -68,7 +68,6 @@ const ChunkImport = ({ kbId }: { kbId: string }) => {
for (let i = 0; i < chunks.length; i += step) { for (let i = 0; i < chunks.length; i += step) {
const { insertLen } = await postKbDataFromList({ const { insertLen } = await postKbDataFromList({
kbId, kbId,
model,
data: chunks.slice(i, i + step), data: chunks.slice(i, i + step),
mode: TrainingModeEnum.index mode: TrainingModeEnum.index
}); });

View File

@@ -43,7 +43,6 @@ const CsvImport = ({ kbId }: { kbId: string }) => {
for (let i = 0; i < chunks.length; i += step) { for (let i = 0; i < chunks.length; i += step) {
const { insertLen } = await postKbDataFromList({ const { insertLen } = await postKbDataFromList({
kbId, kbId,
model,
data: chunks.slice(i, i + step), data: chunks.slice(i, i + step),
mode: TrainingModeEnum.index mode: TrainingModeEnum.index
}); });

View File

@@ -1,11 +1,9 @@
import React, { useCallback, useState } from 'react'; import React from 'react';
import { Box, type BoxProps, Flex, Textarea, useTheme, Button } from '@chakra-ui/react'; import { Box, Textarea, Button } from '@chakra-ui/react';
import MyRadio from '@/components/Radio/index';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { useToast } from '@/hooks/useToast'; import { useToast } from '@/hooks/useToast';
import { useRequest } from '@/hooks/useRequest'; import { useRequest } from '@/hooks/useRequest';
import { getErrText } from '@/utils/tools'; import { getErrText } from '@/utils/tools';
import { vectorModelList } from '@/store/static';
import { postKbDataFromList } from '@/api/plugins/kb'; import { postKbDataFromList } from '@/api/plugins/kb';
import { TrainingModeEnum } from '@/constants/plugin'; import { TrainingModeEnum } from '@/constants/plugin';
@@ -35,7 +33,6 @@ const ManualImport = ({ kbId }: { kbId: string }) => {
}; };
const { insertLen } = await postKbDataFromList({ const { insertLen } = await postKbDataFromList({
kbId, kbId,
model: vectorModelList[0].model,
mode: TrainingModeEnum.index, mode: TrainingModeEnum.index,
data: [data] data: [data]
}); });

View File

@@ -7,7 +7,7 @@ import { postKbDataFromList } from '@/api/plugins/kb';
import { splitText2Chunks } from '@/utils/file'; import { splitText2Chunks } from '@/utils/file';
import { getErrText } from '@/utils/tools'; import { getErrText } from '@/utils/tools';
import { formatPrice } from '@/utils/user'; import { formatPrice } from '@/utils/user';
import { qaModelList } from '@/store/static'; import { qaModel } from '@/store/static';
import MyIcon from '@/components/Icon'; import MyIcon from '@/components/Icon';
import CloseIcon from '@/components/Icon/close'; import CloseIcon from '@/components/Icon/close';
import DeleteIcon, { hoverDeleteStyles } from '@/components/Icon/delete'; import DeleteIcon, { hoverDeleteStyles } from '@/components/Icon/delete';
@@ -20,9 +20,8 @@ import { useRouter } from 'next/router';
const fileExtension = '.txt, .doc, .docx, .pdf, .md'; const fileExtension = '.txt, .doc, .docx, .pdf, .md';
const QAImport = ({ kbId }: { kbId: string }) => { const QAImport = ({ kbId }: { kbId: string }) => {
const model = qaModelList[0]?.model; const unitPrice = qaModel.price || 3;
const unitPrice = qaModelList[0]?.price || 3; const chunkLen = qaModel.maxToken * 0.45;
const chunkLen = qaModelList[0].maxToken * 0.45;
const theme = useTheme(); const theme = useTheme();
const router = useRouter(); const router = useRouter();
const { toast } = useToast(); const { toast } = useToast();
@@ -58,7 +57,6 @@ const QAImport = ({ kbId }: { kbId: string }) => {
for (let i = 0; i < chunks.length; i += step) { for (let i = 0; i < chunks.length; i += step) {
const { insertLen } = await postKbDataFromList({ const { insertLen } = await postKbDataFromList({
kbId, kbId,
model,
data: chunks.slice(i, i + step), data: chunks.slice(i, i + step),
mode: TrainingModeEnum.qa, mode: TrainingModeEnum.qa,
prompt: prompt || '下面是一段长文本' prompt: prompt || '下面是一段长文本'

View File

@@ -7,7 +7,7 @@ import React, {
ForwardedRef ForwardedRef
} from 'react'; } from 'react';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
import { Box, Flex, Button, FormControl, IconButton, Input, Card } from '@chakra-ui/react'; import { Box, Flex, Button, FormControl, IconButton, Input } from '@chakra-ui/react';
import { QuestionOutlineIcon, DeleteIcon } from '@chakra-ui/icons'; import { QuestionOutlineIcon, DeleteIcon } from '@chakra-ui/icons';
import { delKbById, putKbById } from '@/api/plugins/kb'; import { delKbById, putKbById } from '@/api/plugins/kb';
import { useSelectFile } from '@/hooks/useSelectFile'; import { useSelectFile } from '@/hooks/useSelectFile';
@@ -17,8 +17,6 @@ import { useConfirm } from '@/hooks/useConfirm';
import { UseFormReturn } from 'react-hook-form'; import { UseFormReturn } from 'react-hook-form';
import { compressImg } from '@/utils/file'; import { compressImg } from '@/utils/file';
import type { KbItemType } from '@/types/plugin'; import type { KbItemType } from '@/types/plugin';
import { vectorModelList } from '@/store/static';
import MySelect from '@/components/Select';
import Avatar from '@/components/Avatar'; import Avatar from '@/components/Avatar';
import Tag from '@/components/Tag'; import Tag from '@/components/Tag';
import MyTooltip from '@/components/MyTooltip'; import MyTooltip from '@/components/MyTooltip';
@@ -138,7 +136,6 @@ const Info = (
useImperativeHandle(ref, () => ({ useImperativeHandle(ref, () => ({
initInput: (tags: string) => { initInput: (tags: string) => {
console.log(tags);
if (InputRef.current) { if (InputRef.current) {
InputRef.current.value = tags; InputRef.current.value = tags;
} }
@@ -153,20 +150,27 @@ const Info = (
</Box> </Box>
<Box flex={1}>{kbDetail._id}</Box> <Box flex={1}>{kbDetail._id}</Box>
</Flex> </Flex>
<Flex mt={8} w={'100%'} alignItems={'center'}>
<Box flex={['0 0 90px', '0 0 160px']} w={0}>
</Box>
<Box flex={[1, '0 0 300px']}>{getValues('vectorModelName')}</Box>
</Flex>
<Flex mt={5} w={'100%'} alignItems={'center'}> <Flex mt={5} w={'100%'} alignItems={'center'}>
<Box flex={['0 0 90px', '0 0 160px']} w={0}> <Box flex={['0 0 90px', '0 0 160px']} w={0}>
</Box> </Box>
<Box flex={[1, '0 0 300px']}> <Box flex={[1, '0 0 300px']}>
<Avatar <MyTooltip label={'点击切换头像'}>
m={'auto'} <Avatar
src={getValues('avatar')} m={'auto'}
w={['32px', '40px']} src={getValues('avatar')}
h={['32px', '40px']} w={['32px', '40px']}
cursor={'pointer'} h={['32px', '40px']}
title={'点击切换头像'} cursor={'pointer'}
onClick={onOpenSelectFile} onClick={onOpenSelectFile}
/> />
</MyTooltip>
</Box> </Box>
</Flex> </Flex>
<FormControl mt={8} w={'100%'} display={'flex'} alignItems={'center'}> <FormControl mt={8} w={'100%'} display={'flex'} alignItems={'center'}>
@@ -180,27 +184,9 @@ const Info = (
})} })}
/> />
</FormControl> </FormControl>
<Flex mt={8} w={'100%'} alignItems={'center'}>
<Box flex={['0 0 90px', '0 0 160px']} w={0}>
</Box>
<Box flex={[1, '0 0 300px']}>
<MySelect
w={'100%'}
value={getValues('model')}
list={vectorModelList.map((item) => ({
label: item.name,
value: item.model
}))}
onchange={(res) => {
setValue('model', res);
}}
/>
</Box>
</Flex>
<Flex mt={8} alignItems={'center'} w={'100%'} flexWrap={'wrap'}> <Flex mt={8} alignItems={'center'} w={'100%'} flexWrap={'wrap'}>
<Box flex={['0 0 90px', '0 0 160px']} w={0}> <Box flex={['0 0 90px', '0 0 160px']} w={0}>
<MyTooltip label={'用空格隔开多个标签,便于搜索'} forceShow> <MyTooltip label={'用空格隔开多个标签,便于搜索'} forceShow>
<QuestionOutlineIcon ml={1} /> <QuestionOutlineIcon ml={1} />
</MyTooltip> </MyTooltip>
@@ -208,6 +194,7 @@ const Info = (
<Input <Input
flex={[1, '0 0 300px']} flex={[1, '0 0 300px']}
ref={InputRef} ref={InputRef}
defaultValue={getValues('tags')}
placeholder={'标签,使用空格分割。'} placeholder={'标签,使用空格分割。'}
maxLength={30} maxLength={30}
onChange={(e) => { onChange={(e) => {
@@ -226,7 +213,6 @@ const Info = (
))} ))}
</Flex> </Flex>
</Flex> </Flex>
<Flex mt={5} w={'100%'} alignItems={'flex-end'}> <Flex mt={5} w={'100%'} alignItems={'flex-end'}>
<Box flex={['0 0 90px', '0 0 160px']} w={0}></Box> <Box flex={['0 0 90px', '0 0 160px']} w={0}></Box>
<Button <Button

View File

@@ -1,4 +1,4 @@
import React, { useCallback, useMemo, useRef } from 'react'; import React, { useCallback, useRef } from 'react';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
import { Box, Flex, IconButton, useTheme } from '@chakra-ui/react'; import { Box, Flex, IconButton, useTheme } from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast'; import { useToast } from '@/hooks/useToast';
@@ -71,8 +71,8 @@ const Detail = ({ kbId, currentTab }: { kbId: string; currentTab: `${TabEnum}` }
useQuery([kbId], () => getKbDetail(kbId), { useQuery([kbId], () => getKbDetail(kbId), {
onSuccess(res) { onSuccess(res) {
InfoRef.current?.initInput(res.tags);
form.reset(res); form.reset(res);
InfoRef.current?.initInput(res.tags);
}, },
onError(err: any) { onError(err: any) {
router.replace(`/kb/list`); router.replace(`/kb/list`);

View File

@@ -0,0 +1,165 @@
import React, { useCallback, useState, useRef } from 'react';
import { Box, Flex, Button, ModalHeader, ModalFooter, ModalBody, Input } from '@chakra-ui/react';
import { useSelectFile } from '@/hooks/useSelectFile';
import { useForm } from 'react-hook-form';
import { compressImg } from '@/utils/file';
import { getErrText } from '@/utils/tools';
import { useToast } from '@/hooks/useToast';
import { useRouter } from 'next/router';
import { useGlobalStore } from '@/store/global';
import { useRequest } from '@/hooks/useRequest';
import Avatar from '@/components/Avatar';
import MyTooltip from '@/components/MyTooltip';
import MyModal from '@/components/MyModal';
import { postCreateKb } from '@/api/plugins/kb';
import type { CreateKbParams } from '@/api/request/kb';
import { vectorModelList } from '@/store/static';
import MySelect from '@/components/Select';
import { QuestionOutlineIcon } from '@chakra-ui/icons';
import Tag from '@/components/Tag';
const CreateModal = ({ onClose }: { onClose: () => void }) => {
const [refresh, setRefresh] = useState(false);
const { toast } = useToast();
const router = useRouter();
const { isPc } = useGlobalStore();
const { register, setValue, getValues, handleSubmit } = useForm<CreateKbParams>({
defaultValues: {
avatar: '/icon/logo.svg',
name: '',
tags: [],
vectorModel: vectorModelList[0].model
}
});
const InputRef = useRef<HTMLInputElement>(null);
const { File, onOpen: onOpenSelectFile } = useSelectFile({
fileType: '.jpg,.png',
multiple: false
});
const onSelectFile = useCallback(
async (e: File[]) => {
const file = e[0];
if (!file) return;
try {
const src = await compressImg({
file,
maxW: 100,
maxH: 100
});
setValue('avatar', src);
setRefresh((state) => !state);
} catch (err: any) {
toast({
title: getErrText(err, '头像选择异常'),
status: 'warning'
});
}
},
[setValue, toast]
);
/* create a new kb and router to it */
const { mutate: onclickCreate, isLoading: creating } = useRequest({
mutationFn: async (data: CreateKbParams) => {
const id = await postCreateKb(data);
return id;
},
successToast: '创建成功',
errorToast: '创建知识库出现意外',
onSuccess(id) {
router.push(`/kb/detail?kbId=${id}`);
}
});
return (
<MyModal isOpen onClose={onClose} isCentered={!isPc} w={'400px'}>
<ModalHeader fontSize={'2xl'}></ModalHeader>
<ModalBody>
<Box color={'myGray.800'} fontWeight={'bold'}>
</Box>
<Flex mt={3} alignItems={'center'}>
<MyTooltip label={'点击设置头像'}>
<Avatar
flexShrink={0}
src={getValues('avatar')}
w={['28px', '32px']}
h={['28px', '32px']}
cursor={'pointer'}
borderRadius={'md'}
onClick={onOpenSelectFile}
/>
</MyTooltip>
<Input
ml={3}
flex={1}
autoFocus
bg={'myWhite.600'}
{...register('name', {
required: '知识库名称不能为空~'
})}
/>
</Flex>
<Flex mt={6} alignItems={'center'}>
<Box flex={'0 0 80px'}></Box>
<Box flex={1}>
<MySelect
w={'100%'}
value={getValues('vectorModel')}
list={vectorModelList.map((item) => ({
label: item.name,
value: item.model
}))}
onchange={(e) => {
setValue('vectorModel', e);
setRefresh((state) => !state);
}}
/>
</Box>
</Flex>
<Flex mt={6} alignItems={'center'} w={'100%'}>
<Box flex={'0 0 80px'}>
<MyTooltip label={'用空格隔开多个标签,便于搜索'} forceShow>
<QuestionOutlineIcon ml={1} />
</MyTooltip>
</Box>
<Input
flex={1}
ref={InputRef}
placeholder={'标签,使用空格分割。'}
maxLength={30}
onChange={(e) => {
setValue('tags', e.target.value.split(' '));
setRefresh(!refresh);
}}
/>
</Flex>
<Flex mt={2} flexWrap={'wrap'}>
{getValues('tags')
.filter((item) => item)
.map((item, i) => (
<Tag mr={2} mb={2} key={i} whiteSpace={'nowrap'}>
{item}
</Tag>
))}
</Flex>
</ModalBody>
<ModalFooter>
<Button variant={'base'} mr={3} onClick={onClose}>
</Button>
<Button isLoading={creating} onClick={handleSubmit((data) => onclickCreate(data))}>
</Button>
</ModalFooter>
<File onSelect={onSelectFile} />
</MyModal>
);
};
export default CreateModal;

View File

@@ -1,5 +1,14 @@
import React, { useCallback } from 'react'; import React, { useCallback } from 'react';
import { Box, Card, Flex, Grid, useTheme, Button, IconButton } from '@chakra-ui/react'; import {
Box,
Card,
Flex,
Grid,
useTheme,
Button,
IconButton,
useDisclosure
} from '@chakra-ui/react';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
import { useUserStore } from '@/store/user'; import { useUserStore } from '@/store/user';
import PageContainer from '@/components/PageContainer'; import PageContainer from '@/components/PageContainer';
@@ -7,12 +16,14 @@ import { useConfirm } from '@/hooks/useConfirm';
import { AddIcon } from '@chakra-ui/icons'; import { AddIcon } from '@chakra-ui/icons';
import { useQuery } from '@tanstack/react-query'; import { useQuery } from '@tanstack/react-query';
import { useToast } from '@/hooks/useToast'; import { useToast } from '@/hooks/useToast';
import { delKbById, postCreateKb } from '@/api/plugins/kb'; import { delKbById } from '@/api/plugins/kb';
import { useRequest } from '@/hooks/useRequest';
import Avatar from '@/components/Avatar'; import Avatar from '@/components/Avatar';
import MyIcon from '@/components/Icon'; import MyIcon from '@/components/Icon';
import Tag from '@/components/Tag'; import Tag from '@/components/Tag';
import { serviceSideProps } from '@/utils/i18n'; import { serviceSideProps } from '@/utils/i18n';
import dynamic from 'next/dynamic';
const CreateModal = dynamic(() => import('./component/CreateModal'), { ssr: false });
const Kb = () => { const Kb = () => {
const theme = useTheme(); const theme = useTheme();
@@ -24,7 +35,13 @@ const Kb = () => {
}); });
const { myKbList, loadKbList, setKbList } = useUserStore(); const { myKbList, loadKbList, setKbList } = useUserStore();
useQuery(['loadKbList'], () => loadKbList()); const {
isOpen: isOpenCreateModal,
onOpen: onOpenCreateModal,
onClose: onCloseCreateModal
} = useDisclosure();
const { refetch } = useQuery(['loadKbList'], () => loadKbList());
/* 点击删除 */ /* 点击删除 */
const onclickDelKb = useCallback( const onclickDelKb = useCallback(
@@ -46,32 +63,13 @@ const Kb = () => {
[toast, setKbList, myKbList] [toast, setKbList, myKbList]
); );
/* create a new kb and router to it */
const { mutate: onclickCreate, isLoading } = useRequest({
mutationFn: async () => {
const name = `知识库${myKbList.length + 1}`;
const id = await postCreateKb({ name });
return id;
},
successToast: '创建成功',
errorToast: '创建知识库出现意外',
onSuccess(id) {
router.push(`/kb/detail?kbId=${id}`);
}
});
return ( return (
<PageContainer> <PageContainer>
<Flex pt={3} px={5} alignItems={'center'}> <Flex pt={3} px={5} alignItems={'center'}>
<Box flex={1} className="textlg" letterSpacing={1} fontSize={'24px'} fontWeight={'bold'}> <Box flex={1} className="textlg" letterSpacing={1} fontSize={'24px'} fontWeight={'bold'}>
</Box> </Box>
<Button <Button leftIcon={<AddIcon />} variant={'base'} onClick={onOpenCreateModal}>
isLoading={isLoading}
leftIcon={<AddIcon />}
variant={'base'}
onClick={onclickCreate}
>
</Button> </Button>
</Flex> </Flex>
@@ -141,6 +139,10 @@ const Kb = () => {
))} ))}
</Flex> </Flex>
</Box> </Box>
<Flex justifyContent={'flex-end'} alignItems={'center'} fontSize={'sm'}>
<MyIcon mr={1} name="kbTest" w={'12px'} />
<Box color={'myGray.500'}>{kb.vectorModelName}</Box>
</Flex>
</Card> </Card>
))} ))}
</Grid> </Grid>
@@ -153,6 +155,7 @@ const Kb = () => {
</Flex> </Flex>
)} )}
<ConfirmModal /> <ConfirmModal />
{isOpenCreateModal && <CreateModal onClose={onCloseCreateModal} />}
</PageContainer> </PageContainer>
); );
}; };

View File

@@ -57,7 +57,7 @@ const RegisterForm = ({ setPageType, loginSuccess }: Props) => {
username, username,
code, code,
password, password,
inviterId: localStorage.getItem('inviterId') || '' inviterId: localStorage.getItem('inviterId') || undefined
}) })
); );
toast({ toast({

View File

@@ -46,7 +46,7 @@ const provider = ({ code }: { code: string }) => {
if (loginStore.provider === 'git') { if (loginStore.provider === 'git') {
return gitLogin({ return gitLogin({
code, code,
inviterId: localStorage.getItem('inviterId') || '' inviterId: localStorage.getItem('inviterId') || undefined
}); });
} }
return null; return null;

View File

@@ -1,5 +1,5 @@
import { TrainingData } from '@/service/mongo'; import { TrainingData } from '@/service/mongo';
import { pushSplitDataBill } from '@/service/events/pushBill'; import { pushQABill } from '@/service/events/pushBill';
import { pushDataToKb } from '@/pages/api/openapi/kb/pushData'; import { pushDataToKb } from '@/pages/api/openapi/kb/pushData';
import { TrainingModeEnum } from '@/constants/plugin'; import { TrainingModeEnum } from '@/constants/plugin';
import { ERROR_ENUM } from '../errorCode'; import { ERROR_ENUM } from '../errorCode';
@@ -60,14 +60,13 @@ export async function generateQA(): Promise<any> {
// 请求 chatgpt 获取回答 // 请求 chatgpt 获取回答
const response = await Promise.all( const response = await Promise.all(
[data.q].map((text) => { [data.q].map((text) => {
const modelTokenLimit = const modelTokenLimit = global.qaModel.maxToken || 16000;
chatModels.find((item) => item.model === data.model)?.contextMaxToken || 16000;
const messages: ChatCompletionRequestMessage[] = [ const messages: ChatCompletionRequestMessage[] = [
{ {
role: 'system', role: 'system',
content: `你是出题人. content: `你是出题人${
${data.prompt || '我会发送一段长文本'}. data.prompt || '我会发送一段长文本'
从中提取出 25 个问题和答案. 答案详细完整. 按下面格式返回: },请从中提取出 25 个问题和答案. 答案详细完整,并按下面格式返回:
Q1: Q1:
A1: A1:
Q2: Q2:
@@ -88,7 +87,7 @@ A2:
return chatAPI return chatAPI
.createChatCompletion( .createChatCompletion(
{ {
model: data.model, model: global.qaModel.model,
temperature: 0.8, temperature: 0.8,
messages, messages,
stream: false, stream: false,
@@ -106,10 +105,9 @@ A2:
const result = formatSplitText(answer || ''); // 格式化后的QA对 const result = formatSplitText(answer || ''); // 格式化后的QA对
console.log(`split result length: `, result.length); console.log(`split result length: `, result.length);
// 计费 // 计费
pushSplitDataBill({ pushQABill({
userId: data.userId, userId: data.userId,
totalTokens, totalTokens,
model: data.model,
appName: 'QA 拆分' appName: 'QA 拆分'
}); });
return { return {
@@ -135,7 +133,6 @@ A2:
source: data.source source: data.source
})), })),
userId, userId,
model: global.vectorModels[0].model,
mode: TrainingModeEnum.index mode: TrainingModeEnum.index
}); });

View File

@@ -38,7 +38,7 @@ export async function generateVector(): Promise<any> {
q: 1, q: 1,
a: 1, a: 1,
source: 1, source: 1,
model: 1 vectorModel: 1
}); });
// task preemption // task preemption
@@ -61,7 +61,7 @@ export async function generateVector(): Promise<any> {
// 生成词向量 // 生成词向量
const { vectors } = await getVector({ const { vectors } = await getVector({
model: data.model, model: data.vectorModel,
input: dataItems.map((item) => item.q), input: dataItems.map((item) => item.q),
userId userId
}); });

View File

@@ -76,13 +76,11 @@ export const updateShareChatBill = async ({
} }
}; };
export const pushSplitDataBill = async ({ export const pushQABill = async ({
userId, userId,
totalTokens, totalTokens,
model,
appName appName
}: { }: {
model: string;
userId: string; userId: string;
totalTokens: number; totalTokens: number;
appName: string; appName: string;
@@ -95,7 +93,7 @@ export const pushSplitDataBill = async ({
await connectToDatabase(); await connectToDatabase();
// 获取模型单价格, 都是用 gpt35 拆分 // 获取模型单价格, 都是用 gpt35 拆分
const unitPrice = global.chatModels.find((item) => item.model === model)?.price || 3; const unitPrice = global.qaModel.price || 3;
// 计算价格 // 计算价格
const total = unitPrice * totalTokens; const total = unitPrice * totalTokens;

View File

@@ -19,7 +19,7 @@ const kbSchema = new Schema({
type: String, type: String,
required: true required: true
}, },
model: { vectorModel: {
type: String, type: String,
required: true, required: true,
default: 'text-embedding-ada-002' default: 'text-embedding-ada-002'

View File

@@ -28,9 +28,10 @@ const TrainingDataSchema = new Schema({
enum: Object.keys(TrainingTypeMap), enum: Object.keys(TrainingTypeMap),
required: true required: true
}, },
model: { vectorModel: {
type: String, type: String,
required: true required: true,
default: 'text-embedding-ada-002'
}, },
prompt: { prompt: {
// qa split prompt // qa split prompt

View File

@@ -181,7 +181,7 @@ export const dispatchChatCompletion = async (props: Record<string, any>): Promis
tokens: totalTokens, tokens: totalTokens,
question: userChatInput, question: userChatInput,
answer: answerText, answer: answerText,
maxToken, maxToken: max_tokens,
quoteList: filterQuoteQA, quoteList: filterQuoteQA,
completeMessages completeMessages
}, },
@@ -237,7 +237,7 @@ function getChatMessages({
}) { }) {
const limitText = (() => { const limitText = (() => {
if (limitPrompt) if (limitPrompt)
return `Use the provided content delimited by triple quotes to answer questions.${limitPrompt}`; return `Use the provided content delimited by triple quotes to answer questions. ${limitPrompt}`;
if (quotePrompt && !limitPrompt) { if (quotePrompt && !limitPrompt) {
return `Use the provided content delimited by triple quotes to answer questions.Your task is to answer the question using only the provided content. If the content does not contain the information needed to answer this question then simply write: "你的问题没有在知识库中体现".`; return `Use the provided content delimited by triple quotes to answer questions.Your task is to answer the question using only the provided content. If the content does not contain the information needed to answer this question then simply write: "你的问题没有在知识库中体现".`;
} }

View File

@@ -4,11 +4,7 @@ export const getChatModel = (model?: string) => {
export const getVectorModel = (model?: string) => { export const getVectorModel = (model?: string) => {
return global.vectorModels.find((item) => item.model === model); return global.vectorModels.find((item) => item.model === model);
}; };
export const getQAModel = (model?: string) => {
return global.qaModels.find((item) => item.model === model);
};
export const getModel = (model?: string) => { export const getModel = (model?: string) => {
return [...global.chatModels, ...global.vectorModels, ...global.qaModels].find( return [...global.chatModels, ...global.vectorModels].find((item) => item.model === model);
(item) => item.model === model
);
}; };

View File

@@ -9,7 +9,12 @@ import { delay } from '@/utils/tools';
import { FeConfigsType } from '@/types'; import { FeConfigsType } from '@/types';
export let chatModelList: ChatModelItemType[] = []; export let chatModelList: ChatModelItemType[] = [];
export let qaModelList: QAModelItemType[] = []; export let qaModel: QAModelItemType = {
model: 'gpt-3.5-turbo-16k',
name: 'GPT35-16k',
maxToken: 16000,
price: 0
};
export let vectorModelList: VectorModelItemType[] = []; export let vectorModelList: VectorModelItemType[] = [];
export let feConfigs: FeConfigsType = {}; export let feConfigs: FeConfigsType = {};
@@ -20,7 +25,7 @@ export const clientInitData = async (): Promise<InitDateResponse> => {
const res = await getInitData(); const res = await getInitData();
chatModelList = res.chatModels; chatModelList = res.chatModels;
qaModelList = res.qaModels; qaModel = res.qaModel;
vectorModelList = res.vectorModels; vectorModelList = res.vectorModels;
feConfigs = res.feConfigs; feConfigs = res.feConfigs;

View File

@@ -51,7 +51,7 @@ declare global {
var feConfigs: FeConfigsType; var feConfigs: FeConfigsType;
var systemEnv: SystemEnvType; var systemEnv: SystemEnvType;
var chatModels: ChatModelItemType[]; var chatModels: ChatModelItemType[];
var qaModels: QAModelItemType[]; var qaModel: QAModelItemType;
var vectorModels: VectorModelItemType[]; var vectorModels: VectorModelItemType[];
interface Window { interface Window {

View File

@@ -72,7 +72,7 @@ export interface TrainingDataSchema {
kbId: string; kbId: string;
expireAt: Date; expireAt: Date;
lockTime: Date; lockTime: Date;
model: string; vectorModel: string;
mode: `${TrainingModeEnum}`; mode: `${TrainingModeEnum}`;
prompt: string; prompt: string;
q: string; q: string;
@@ -164,7 +164,7 @@ export interface kbSchema {
updateTime: Date; updateTime: Date;
avatar: string; avatar: string;
name: string; name: string;
model: string; vectorModel: string;
tags: string[]; tags: string[];
} }

View File

@@ -7,10 +7,15 @@ export type KbListItemType = {
avatar: string; avatar: string;
name: string; name: string;
tags: string[]; tags: string[];
vectorModelName: string;
}; };
/* kb type */ /* kb type */
export interface KbItemType extends kbSchema { export interface KbItemType {
totalData: number; _id: string;
avatar: string;
name: string;
userId: string;
vectorModelName: string;
tags: string; tags: string;
} }

View File

@@ -213,14 +213,12 @@ docker-compose up -d
"defaultSystem": "" "defaultSystem": ""
} }
], ],
"QAModels": [ "QAModel": {
{ "model": "gpt-3.5-turbo-16k",
"model": "gpt-3.5-turbo-16k", "name": "GPT35-16k",
"name": "GPT35-16k", "maxToken": 16000,
"maxToken": 16000, "price": 0
"price": 0 },
}
],
"VectorModels": [ "VectorModels": [
{ {
"model": "text-embedding-ada-002", "model": "text-embedding-ada-002",

View File

@@ -96,14 +96,12 @@ weight: 751
"defaultSystem": "" "defaultSystem": ""
} }
], ],
"QAModels": [ "QAModel": {
{ "model": "gpt-3.5-turbo-16k",
"model": "gpt-3.5-turbo-16k", "name": "GPT35-16k",
"name": "GPT35-16k", "maxToken": 16000,
"maxToken": 16000, "price": 0
"price": 0 },
}
],
"VectorModels": [ "VectorModels": [
{ {
"model": "text-embedding-ada-002", "model": "text-embedding-ada-002",

View File

@@ -46,19 +46,17 @@
"defaultSystem": "" "defaultSystem": ""
} }
], ],
"QAModels": [
{
"model": "gpt-3.5-turbo-16k",
"name": "GPT35-16k",
"maxToken": 16000,
"price": 0
}
],
"VectorModels": [ "VectorModels": [
{ {
"model": "text-embedding-ada-002", "model": "text-embedding-ada-002",
"name": "Embedding-2", "name": "Embedding-2",
"price": 0 "price": 0
} }
] ],
"QAModel": {
"model": "gpt-3.5-turbo-16k",
"name": "GPT35-16k",
"maxToken": 16000,
"price": 0
}
} }