feat: self vector search

This commit is contained in:
archer
2023-08-26 18:24:16 +08:00
parent 13439c5183
commit be33794a5f
22 changed files with 151 additions and 71 deletions

View File

@@ -218,7 +218,7 @@ export const KBSearchModule: FlowModuleTemplateType = {
key: 'similarity',
type: FlowInputItemTypeEnum.slider,
label: '相似度',
value: 0.8,
value: 0.4,
min: 0,
max: 1,
step: 0.01,
@@ -845,7 +845,7 @@ export const appTemplates: (AppItemType & { avatar: string; intro: string })[] =
key: 'similarity',
type: 'slider',
label: '相似度',
value: 0.8,
value: 0.4,
min: 0,
max: 1,
step: 0.01,

View File

@@ -6,5 +6,11 @@ export const defaultKbDetail: KbItemType = {
avatar: '/icon/logo.svg',
name: '',
tags: '',
vectorModelName: 'text-embedding-ada-002'
vectorModel: {
model: 'text-embedding-ada-002',
name: 'Embedding-2',
price: 0.2,
defaultToken: 500,
maxToken: 8000
}
};

View File

@@ -6,9 +6,9 @@ import { withNextCors } from '@/service/utils/tools';
import { getVector } from '../plugin/vector';
import type { KbTestItemType } from '@/types/plugin';
import { PgTrainingTableName } from '@/constants/plugin';
import { KB } from '@/service/mongo';
/**
 * Request body accepted by the knowledge-base search-test handler below,
 * which reads it via `req.body as Props`.
 *
 * NOTE(review): this text appears to be a diff rendered without +/- markers.
 * `model` is presumably the removed field: the newer destructuring in the
 * handler reads only `kbId` and `text`, and the embedding model is taken from
 * the knowledge-base record instead. Confirm against the real file.
 */
export type Props = {
model: string;
kbId: string;
text: string;
};
@@ -16,21 +16,24 @@ export type Response = KbTestItemType['results'];
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { kbId, text, model } = req.body as Props;
const { kbId, text } = req.body as Props;
if (!kbId || !text || !model) {
if (!kbId || !text) {
throw new Error('缺少参数');
}
// 凭证校验
const { userId } = await authUser({ req });
const [{ userId }, kb] = await Promise.all([
authUser({ req }),
KB.findById(kbId, 'vectorModel')
]);
if (!userId) {
if (!userId || !kb) {
throw new Error('缺少用户ID');
}
const { vectors } = await getVector({
model,
model: kb.vectorModel,
userId,
input: [text]
});

View File

@@ -24,11 +24,8 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
await connectToDatabase();
// 凭证校验
const { userId } = await authUser({ req });
// find model
const kb = await KB.findById(kbId, 'model');
// auth user and get kb
const [{ userId }, kb] = await Promise.all([authUser({ req }), KB.findById(kbId, 'model')]);
if (!kb) {
throw new Error("Can't find database");

View File

@@ -2,7 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, KB } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { getModel } from '@/service/utils/data';
import { getVectorModel } from '@/service/utils/data';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -34,7 +34,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
avatar: data.avatar,
name: data.name,
userId: data.userId,
vectorModelName: getModel(data.vectorModel)?.name || 'Unknown',
vectorModel: getVectorModel(data.vectorModel),
tags: data.tags.join(' ')
}
});

View File

@@ -3,7 +3,7 @@ import { jsonRes } from '@/service/response';
import { connectToDatabase, KB } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { KbListItemType } from '@/types/plugin';
import { getModel } from '@/service/utils/data';
import { getVectorModel } from '@/service/utils/data';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -25,7 +25,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
avatar: item.avatar,
name: item.name,
tags: item.tags,
vectorModelName: getModel(item.vectorModel)?.name || 'UnKnow'
vectorModel: getVectorModel(item.vectorModel)
}))
);

View File

@@ -542,7 +542,10 @@ const Settings = ({ appId }: { appId: string }) => {
{isOpenKbSelect && (
<KBSelectModal
kbList={myKbList}
activeKbs={selectedKbList.map((item) => ({ kbId: item._id }))}
activeKbs={selectedKbList.map((item) => ({
kbId: item._id,
vectorModel: item.vectorModel
}))}
onClose={onCloseKbSelect}
onChange={replaceKbList}
/>

View File

@@ -16,9 +16,11 @@ import { useForm } from 'react-hook-form';
import { QuestionOutlineIcon } from '@chakra-ui/icons';
import type { SelectedKbType } from '@/types/plugin';
import { useGlobalStore } from '@/store/global';
import { useToast } from '@/hooks/useToast';
import MySlider from '@/components/Slider';
import MyTooltip from '@/components/MyTooltip';
import MyModal from '@/components/MyModal';
import MyIcon from '@/components/Icon';
export type KbParamsType = {
searchSimilarity: number;
@@ -40,6 +42,7 @@ export const KBSelectModal = ({
const theme = useTheme();
const [selectedKbList, setSelectedKbList] = useState<SelectedKbType>(activeKbs);
const { isPc } = useGlobalStore();
const { toast } = useToast();
return (
<MyModal
@@ -50,7 +53,13 @@ export const KBSelectModal = ({
onClose={onClose}
>
<Flex flexDirection={'column'} h={['90vh', 'auto']}>
<ModalHeader>({selectedKbList.length})</ModalHeader>
<ModalHeader>
<Box>({selectedKbList.length})</Box>
<Box fontSize={'sm'} color={'myGray.500'} fontWeight={'normal'}>
</Box>
</ModalHeader>
<ModalBody
flex={['1 0 0', '0 0 auto']}
maxH={'80vh'}
@@ -58,6 +67,7 @@ export const KBSelectModal = ({
display={'grid'}
gridTemplateColumns={['repeat(1,1fr)', 'repeat(2,1fr)', 'repeat(3,1fr)']}
gridGap={3}
userSelect={'none'}
>
{kbList.map((item) =>
(() => {
@@ -84,7 +94,18 @@ export const KBSelectModal = ({
if (selected) {
setSelectedKbList((state) => state.filter((kb) => kb.kbId !== item._id));
} else {
setSelectedKbList((state) => [...state, { kbId: item._id }]);
const vectorModel = selectedKbList[0]?.vectorModel?.model;
if (vectorModel && vectorModel !== item.vectorModel.model) {
return toast({
status: 'warning',
title: '仅能选择同一个索引模型的知识库'
});
}
setSelectedKbList((state) => [
...state,
{ kbId: item._id, vectorModel: item.vectorModel }
]);
}
}}
>
@@ -94,6 +115,10 @@ export const KBSelectModal = ({
{item.name}
</Box>
</Flex>
<Flex justifyContent={'flex-end'} alignItems={'center'} fontSize={'sm'}>
<MyIcon mr={1} name="kbTest" w={'12px'} />
<Box color={'myGray.500'}>{item.vectorModel.name}</Box>
</Flex>
</Card>
);
})()
@@ -138,7 +163,10 @@ export const KbParamsModal = ({
<Box display={['block', 'flex']} py={5} pt={[0, 5]}>
<Box flex={'0 0 100px'} mb={[8, 0]}>
<MyTooltip label={'高相似度推荐0.8及以上。'} forceShow>
<MyTooltip
label={'不同索引模型的相似度有区别,请通过搜索测试来选择合适的数值'}
forceShow
>
<QuestionOutlineIcon ml={1} />
</MyTooltip>
</Box>

View File

@@ -19,7 +19,6 @@ import { postKbDataFromList } from '@/api/plugins/kb';
import { splitText2Chunks } from '@/utils/file';
import { getErrText } from '@/utils/tools';
import { formatPrice } from '@/utils/user';
import { vectorModelList } from '@/store/static';
import MyIcon from '@/components/Icon';
import CloseIcon from '@/components/Icon/close';
import DeleteIcon, { hoverDeleteStyles } from '@/components/Icon/delete';
@@ -27,17 +26,20 @@ import MyTooltip from '@/components/MyTooltip';
import { QuestionOutlineIcon } from '@chakra-ui/icons';
import { TrainingModeEnum } from '@/constants/plugin';
import FileSelect, { type FileItemType } from './FileSelect';
import { useUserStore } from '@/store/user';
const fileExtension = '.txt, .doc, .docx, .pdf, .md';
const ChunkImport = ({ kbId }: { kbId: string }) => {
const model = vectorModelList[0]?.model || 'text-embedding-ada-002';
const unitPrice = vectorModelList[0]?.price || 0.2;
const { kbDetail } = useUserStore();
const vectorModel = kbDetail.vectorModel;
const unitPrice = vectorModel?.price || 0.2;
const theme = useTheme();
const router = useRouter();
const { toast } = useToast();
const [chunkLen, setChunkLen] = useState(500);
const [chunkLen, setChunkLen] = useState(vectorModel?.defaultToken || 300);
const [showRePreview, setShowRePreview] = useState(false);
const [files, setFiles] = useState<FileItemType[]>([]);
const [previewFile, setPreviewFile] = useState<FileItemType>();
@@ -205,24 +207,34 @@ const ChunkImport = ({ kbId }: { kbId: string }) => {
<QuestionOutlineIcon ml={1} />
</MyTooltip>
</Box>
<NumberInput
ml={4}
<Box
flex={1}
defaultValue={chunkLen}
min={300}
max={2000}
step={10}
onChange={(e) => {
setChunkLen(+e);
setShowRePreview(true);
css={{
'& > span': {
display: 'block'
}
}}
>
<NumberInputField />
<NumberInputStepper>
<NumberIncrementStepper />
<NumberDecrementStepper />
</NumberInputStepper>
</NumberInput>
<MyTooltip label={`范围: 100~${kbDetail.vectorModel.maxToken}`}>
<NumberInput
ml={4}
defaultValue={chunkLen}
min={100}
max={kbDetail.vectorModel.maxToken}
step={10}
onChange={(e) => {
setChunkLen(+e);
setShowRePreview(true);
}}
>
<NumberInputField />
<NumberInputStepper>
<NumberIncrementStepper />
<NumberDecrementStepper />
</NumberInputStepper>
</NumberInput>
</MyTooltip>
</Box>
</Flex>
{/* price */}
<Flex py={5} alignItems={'center'}>

View File

@@ -11,11 +11,13 @@ import DeleteIcon, { hoverDeleteStyles } from '@/components/Icon/delete';
import { TrainingModeEnum } from '@/constants/plugin';
import FileSelect, { type FileItemType } from './FileSelect';
import { useRouter } from 'next/router';
import { useUserStore } from '@/store/user';
const fileExtension = '.csv';
const CsvImport = ({ kbId }: { kbId: string }) => {
const model = vectorModelList[0]?.model;
const { kbDetail } = useUserStore();
const theme = useTheme();
const router = useRouter();
const { toast } = useToast();
@@ -37,13 +39,22 @@ const CsvImport = ({ kbId }: { kbId: string }) => {
mutationFn: async () => {
const chunks = files.map((file) => file.chunks).flat();
const filterChunks = chunks.filter((item) => item.q.length < kbDetail.vectorModel.maxToken);
if (filterChunks.length !== chunks.length) {
toast({
title: `${chunks.length - filterChunks.length}条数据超出长度,已被过滤`,
status: 'info'
});
}
// subsection import
let success = 0;
const step = 500;
for (let i = 0; i < chunks.length; i += step) {
for (let i = 0; i < filterChunks.length; i += step) {
const { insertLen } = await postKbDataFromList({
kbId,
data: chunks.slice(i, i + step),
data: filterChunks.slice(i, i + step),
mode: TrainingModeEnum.index
});

View File

@@ -1,4 +1,4 @@
import React from 'react';
import React, { useState } from 'react';
import { Box, Textarea, Button } from '@chakra-ui/react';
import { useForm } from 'react-hook-form';
import { useToast } from '@/hooks/useToast';
@@ -6,14 +6,18 @@ import { useRequest } from '@/hooks/useRequest';
import { getErrText } from '@/utils/tools';
import { postKbDataFromList } from '@/api/plugins/kb';
import { TrainingModeEnum } from '@/constants/plugin';
import { useUserStore } from '@/store/user';
/**
 * Values of the manual-import form managed by `useForm` in `ManualImport`.
 *
 * `q` is the field registered on the textarea whose placeholder describes the
 * content that will be searched. `a` is presumably the accompanying
 * answer/supplementary text; its textarea is not visible here, so verify.
 */
type ManualFormType = { q: string; a: string };
const ManualImport = ({ kbId }: { kbId: string }) => {
const { kbDetail } = useUserStore();
const { register, handleSubmit, reset } = useForm({
defaultValues: { q: '', a: '' }
});
const { toast } = useToast();
const [qLen, setQLen] = useState(0);
const { mutate: onImportData, isLoading } = useRequest({
mutationFn: async (e: ManualFormType) => {
@@ -64,16 +68,22 @@ const ManualImport = ({ kbId }: { kbId: string }) => {
return (
<Box p={[4, 8]} h={'100%'} overflow={'overlay'}>
<Box display={'flex'} flexDirection={['column', 'row']}>
<Box flex={1} mr={[0, 4]} mb={[4, 0]} h={['50%', '100%']}>
<Box flex={1} mr={[0, 4]} mb={[4, 0]} h={['50%', '100%']} position={'relative'}>
<Box h={'30px'}>{'匹配的知识点'}</Box>
<Textarea
placeholder={'匹配的知识点。这部分内容会被搜索,请把控内容的质量。总和最多 3000 字。'}
maxLength={3000}
placeholder={`匹配的知识点。这部分内容会被搜索,请把控内容的质量。最多 ${kbDetail.vectorModel.maxToken} 字。`}
maxLength={kbDetail.vectorModel.maxToken}
h={['250px', '500px']}
{...register(`q`, {
required: true
required: true,
onChange(e) {
setQLen(e.target.value.length);
}
})}
/>
<Box position={'absolute'} color={'myGray.500'} right={5} bottom={3} zIndex={99}>
{qLen}
</Box>
</Box>
<Box flex={1} h={['50%', '100%']}>
<Box h={'30px'}></Box>

View File

@@ -154,7 +154,7 @@ const Info = (
<Box flex={['0 0 90px', '0 0 160px']} w={0}>
</Box>
<Box flex={[1, '0 0 300px']}>{getValues('vectorModelName')}</Box>
<Box flex={[1, '0 0 300px']}>{getValues('vectorModel').name}</Box>
</Flex>
<Flex mt={5} w={'100%'} alignItems={'center'}>
<Box flex={['0 0 90px', '0 0 160px']} w={0}>

View File

@@ -10,13 +10,15 @@ import InputDataModal, { type FormData } from './InputDataModal';
import { useGlobalStore } from '@/store/global';
import { getErrText } from '@/utils/tools';
import { useToast } from '@/hooks/useToast';
import { vectorModelList } from '@/store/static';
import { customAlphabet } from 'nanoid';
import MyTooltip from '@/components/MyTooltip';
import { QuestionOutlineIcon } from '@chakra-ui/icons';
import { useUserStore } from '@/store/user';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
const Test = ({ kbId }: { kbId: string }) => {
const { kbDetail } = useUserStore();
const theme = useTheme();
const { toast } = useToast();
const { setLoading } = useGlobalStore();
@@ -31,7 +33,7 @@ const Test = ({ kbId }: { kbId: string }) => {
);
const { mutate, isLoading } = useRequest({
mutationFn: () => searchText({ model: vectorModelList[0].model, kbId, text: inputText.trim() }),
mutationFn: () => searchText({ kbId, text: inputText.trim() }),
onSuccess(res) {
const testItem = {
id: nanoid(),
@@ -75,12 +77,15 @@ const Test = ({ kbId }: { kbId: string }) => {
rows={6}
resize={'none'}
variant={'unstyled'}
maxLength={1000}
maxLength={kbDetail.vectorModel.maxToken}
placeholder="输入需要测试的文本"
value={inputText}
onChange={(e) => setInputText(e.target.value)}
/>
<Flex justifyContent={'flex-end'}>
<Flex alignItems={'center'} justifyContent={'flex-end'}>
<Box mr={3} color={'myGray.500'}>
{inputText.length}
</Box>
<Button isDisabled={inputText === ''} isLoading={isLoading} onClick={mutate}>
</Button>
@@ -177,6 +182,7 @@ const Test = ({ kbId }: { kbId: string }) => {
'repeat(1,1fr)',
'repeat(1,1fr)',
'repeat(1,1fr)',
'repeat(1,1fr)',
'repeat(2,1fr)'
]}
gridGap={4}

View File

@@ -165,12 +165,14 @@ const Detail = ({ kbId, currentTab }: { kbId: string; currentTab: `${TabEnum}` }
</Box>
)}
<Box flex={'1 0 0'} h={'100%'} pb={[4, 0]}>
{currentTab === TabEnum.data && <DataCard kbId={kbId} />}
{currentTab === TabEnum.import && <ImportData kbId={kbId} />}
{currentTab === TabEnum.test && <Test kbId={kbId} />}
{currentTab === TabEnum.info && <Info ref={InfoRef} kbId={kbId} form={form} />}
</Box>
{!!kbDetail._id && (
<Box flex={'1 0 0'} h={'100%'} pb={[4, 0]}>
{currentTab === TabEnum.data && <DataCard kbId={kbId} />}
{currentTab === TabEnum.import && <ImportData kbId={kbId} />}
{currentTab === TabEnum.test && <Test kbId={kbId} />}
{currentTab === TabEnum.info && <Info ref={InfoRef} kbId={kbId} form={form} />}
</Box>
)}
</Box>
</PageContainer>
);

View File

@@ -141,7 +141,7 @@ const Kb = () => {
</Box>
<Flex justifyContent={'flex-end'} alignItems={'center'} fontSize={'sm'}>
<MyIcon mr={1} name="kbTest" w={'12px'} />
<Box color={'myGray.500'}>{kb.vectorModelName}</Box>
<Box color={'myGray.500'}>{kb.vectorModel.name}</Box>
</Flex>
</Card>
))}

View File

@@ -1,4 +1,3 @@
import { openaiAccountError } from '../errorCode';
import { insertKbItem } from '@/service/pg';
import { getVector } from '@/pages/api/openapi/plugin/vector';
import { TrainingData } from '../models/trainingData';

View File

@@ -65,7 +65,7 @@ const AppSchema = new Schema({
},
searchSimilarity: {
type: Number,
default: 0.8
default: 0.4
},
searchLimit: {
type: Number,

View File

@@ -1,5 +1,5 @@
import { PgClient } from '@/service/pg';
import type { ChatHistoryItemResType, ChatItemType } from '@/types/chat';
import type { ChatHistoryItemResType } from '@/types/chat';
import { ChatModuleEnum, TaskResponseKeyEnum } from '@/constants/chat';
import { getVector } from '@/pages/api/openapi/plugin/vector';
import { countModelPrice } from '@/service/events/pushBill';
@@ -21,7 +21,7 @@ export type KBSearchResponse = {
};
export async function dispatchKBSearch(props: Record<string, any>): Promise<KBSearchResponse> {
const { kbList = [], similarity = 0.8, limit = 5, userChatInput } = props as KBSearchProps;
const { kbList = [], similarity = 0.4, limit = 5, userChatInput } = props as KBSearchProps;
if (kbList.length === 0) {
return Promise.reject("You didn't choose the knowledge base");
@@ -32,7 +32,7 @@ export async function dispatchKBSearch(props: Record<string, any>): Promise<KBSe
}
// get vector
const vectorModel = global.vectorModels[0];
const vectorModel = kbList[0]?.vectorModel;
const { vectors, tokenLen } = await getVector({
model: vectorModel.model,
input: [userChatInput]

View File

@@ -2,7 +2,7 @@ export const getChatModel = (model?: string) => {
return global.chatModels.find((item) => item.model === model);
};
/**
 * Looks up the entry of `global.vectorModels` whose `model` field equals the
 * given identifier.
 *
 * NOTE(review): two consecutive `return` statements appear below. This text
 * looks like a diff rendered without +/- markers, so the first is presumably
 * the pre-change line and the second the post-change line, which additionally
 * falls back to `global.vectorModels[0]` when nothing matches. Read literally
 * as code, the second `return` is unreachable.
 *
 * @param model - Model identifier to match; optional.
 */
export const getVectorModel = (model?: string) => {
return global.vectorModels.find((item) => item.model === model);
return global.vectorModels.find((item) => item.model === model) || global.vectorModels[0];
};
export const getModel = (model?: string) => {

View File

@@ -17,5 +17,7 @@ export type QAModelItemType = {
/**
 * Configuration record describing one embedding (vector) model.
 *
 * NOTE(review): callers in this change compare `maxToken` against string
 * `.length` (character count) and use it as a textarea `maxLength`, so it is
 * effectively treated as a character limit there; confirm whether tokens or
 * characters are intended.
 */
export type VectorModelItemType = {
// Identifier matched by `getVectorModel` and passed as `model` to `getVector`.
model: string;
// Human-readable label rendered in the knowledge-base cards and info page.
name: string;
// Initial chunk length offered by the chunk-import screen.
defaultToken: number;
// Read as the unit price by the chunk-import screen (unit not shown here).
price: number;
// Upper bound used for chunk length and for input-length limits in the UI.
maxToken: number;
};

View File

@@ -1,13 +1,14 @@
import { VectorModelItemType } from './model';
import type { kbSchema } from './mongoSchema';
/**
 * List of knowledge bases selected for an app.
 *
 * NOTE(review): the alias is declared twice below. This text looks like a diff
 * rendered without +/- markers, so the first declaration (id only) is
 * presumably the removed line and the second (id plus the knowledge base's
 * vector model) the added one.
 */
export type SelectedKbType = { kbId: string }[];
export type SelectedKbType = { kbId: string; vectorModel: VectorModelItemType }[];
/**
 * One entry of the knowledge-base list as returned by the list API handler.
 *
 * NOTE(review): both `vectorModelName` and `vectorModel` appear below. This
 * text looks like a diff rendered without +/- markers, so `vectorModelName` is
 * presumably the removed field and `vectorModel` its replacement; the list
 * handler in this change populates only `vectorModel`.
 */
export type KbListItemType = {
_id: string;
avatar: string;
name: string;
tags: string[];
vectorModelName: string;
// Full vector-model record, resolved through `getVectorModel` in the handler.
vectorModel: VectorModelItemType;
};
/* kb type */
export interface KbItemType {
@@ -15,7 +16,7 @@ export interface KbItemType {
avatar: string;
name: string;
userId: string;
vectorModelName: string;
vectorModel: VectorModelItemType;
tags: string;
}

View File

@@ -49,7 +49,7 @@ export const getDefaultAppForm = (): EditFormType => {
},
kb: {
list: [],
searchSimilarity: 0.8,
searchSimilarity: 0.4,
searchLimit: 5,
searchEmptyText: ''
},