From be33794a5f38518ccb781fe38a377d835b123460 Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Sat, 26 Aug 2023 18:24:16 +0800 Subject: [PATCH] feat: self vector search --- client/src/constants/flow/ModuleTemplate.ts | 4 +- client/src/constants/kb.ts | 8 ++- client/src/pages/api/openapi/kb/searchTest.ts | 15 +++--- client/src/pages/api/openapi/kb/updateData.ts | 7 +-- client/src/pages/api/plugins/kb/detail.ts | 4 +- client/src/pages/api/plugins/kb/list.ts | 4 +- .../app/detail/components/BasicEdit/index.tsx | 5 +- .../app/detail/components/KBSelectModal.tsx | 34 +++++++++++-- .../kb/detail/components/Import/Chunk.tsx | 50 ++++++++++++------- .../pages/kb/detail/components/Import/Csv.tsx | 17 +++++-- .../kb/detail/components/Import/Manual.tsx | 20 ++++++-- .../src/pages/kb/detail/components/Info.tsx | 2 +- .../src/pages/kb/detail/components/Test.tsx | 14 ++++-- client/src/pages/kb/detail/index.tsx | 14 +++--- client/src/pages/kb/list/index.tsx | 2 +- client/src/service/events/generateVector.ts | 1 - client/src/service/models/app.ts | 2 +- .../src/service/moduleDispatch/kb/search.ts | 6 +-- client/src/service/utils/data.ts | 2 +- client/src/types/model.d.ts | 2 + client/src/types/plugin.d.ts | 7 +-- client/src/utils/app.ts | 2 +- 22 files changed, 151 insertions(+), 71 deletions(-) diff --git a/client/src/constants/flow/ModuleTemplate.ts b/client/src/constants/flow/ModuleTemplate.ts index 6d7e9042b..527faaf8f 100644 --- a/client/src/constants/flow/ModuleTemplate.ts +++ b/client/src/constants/flow/ModuleTemplate.ts @@ -218,7 +218,7 @@ export const KBSearchModule: FlowModuleTemplateType = { key: 'similarity', type: FlowInputItemTypeEnum.slider, label: '相似度', - value: 0.8, + value: 0.4, min: 0, max: 1, step: 0.01, @@ -845,7 +845,7 @@ export const appTemplates: (AppItemType & { avatar: string; intro: string })[] = key: 'similarity', type: 'slider', label: '相似度', - value: 0.8, + value: 0.4, min: 0, max: 1, step: 0.01, diff --git a/client/src/constants/kb.ts b/client/src/constants/kb.ts index e60650fca..2abe977e9 100644 --- a/client/src/constants/kb.ts +++ b/client/src/constants/kb.ts @@ -6,5 +6,11 @@ export const defaultKbDetail: KbItemType = { avatar: '/icon/logo.svg', name: '', tags: '', - vectorModelName: 'text-embedding-ada-002' + vectorModel: { + model: 'text-embedding-ada-002', + name: 'Embedding-2', + price: 0.2, + defaultToken: 500, + maxToken: 8000 + } }; diff --git a/client/src/pages/api/openapi/kb/searchTest.ts b/client/src/pages/api/openapi/kb/searchTest.ts index fc43960d7..9283735b2 100644 --- a/client/src/pages/api/openapi/kb/searchTest.ts +++ b/client/src/pages/api/openapi/kb/searchTest.ts @@ -6,9 +6,9 @@ import { withNextCors } from '@/service/utils/tools'; import { getVector } from '../plugin/vector'; import type { KbTestItemType } from '@/types/plugin'; import { PgTrainingTableName } from '@/constants/plugin'; +import { KB } from '@/service/mongo'; export type Props = { - model: string; kbId: string; text: string; }; @@ -16,21 +16,24 @@ export type Response = KbTestItemType['results']; export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { kbId, text, model } = req.body as Props; + const { kbId, text } = req.body as Props; - if (!kbId || !text || !model) { + if (!kbId || !text) { throw new Error('缺少参数'); } // 凭证校验 - const { userId } = await authUser({ req }); + const [{ userId }, kb] = await Promise.all([ + authUser({ req }), + KB.findById(kbId, 'vectorModel') + ]); - if (!userId) { + if (!userId || !kb) { throw new Error('缺少用户ID'); } const { vectors } = await getVector({ - model, + model: kb.vectorModel, userId, input: [text] }); diff --git a/client/src/pages/api/openapi/kb/updateData.ts b/client/src/pages/api/openapi/kb/updateData.ts index fec5683c1..3c66b2976 100644 --- a/client/src/pages/api/openapi/kb/updateData.ts +++ b/client/src/pages/api/openapi/kb/updateData.ts @@ -24,11 +24,8 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex await connectToDatabase(); - // 凭证校验 - const { userId } = await authUser({ req }); - - // find model - const kb = await KB.findById(kbId, 'model'); + // auth user and get kb + const [{ userId }, kb] = await Promise.all([authUser({ req }), KB.findById(kbId, 'model')]); if (!kb) { throw new Error("Can't find database"); diff --git a/client/src/pages/api/plugins/kb/detail.ts b/client/src/pages/api/plugins/kb/detail.ts index fc5920450..5bb3cb3ec 100644 --- a/client/src/pages/api/plugins/kb/detail.ts +++ b/client/src/pages/api/plugins/kb/detail.ts @@ -2,7 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { connectToDatabase, KB } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; -import { getModel } from '@/service/utils/data'; +import { getVectorModel } from '@/service/utils/data'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -34,7 +34,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< avatar: data.avatar, name: data.name, userId: data.userId, - vectorModelName: getModel(data.vectorModel)?.name || 'Unknown', + vectorModel: getVectorModel(data.vectorModel), tags: data.tags.join(' ') } }); diff --git a/client/src/pages/api/plugins/kb/list.ts b/client/src/pages/api/plugins/kb/list.ts index 58cbfd3b4..a105265a1 100644 --- a/client/src/pages/api/plugins/kb/list.ts +++ b/client/src/pages/api/plugins/kb/list.ts @@ -3,7 +3,7 @@ import { jsonRes } from '@/service/response'; import { connectToDatabase, KB } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; import { KbListItemType } from '@/types/plugin'; -import { getModel } from '@/service/utils/data'; +import { getVectorModel } from '@/service/utils/data'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -25,7 +25,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< avatar: item.avatar, name: item.name, tags: item.tags, - vectorModelName: getModel(item.vectorModel)?.name || 'UnKnow' + vectorModel: getVectorModel(item.vectorModel) })) ); diff --git a/client/src/pages/app/detail/components/BasicEdit/index.tsx b/client/src/pages/app/detail/components/BasicEdit/index.tsx index 6ebeef621..ceb9e895d 100644 --- a/client/src/pages/app/detail/components/BasicEdit/index.tsx +++ b/client/src/pages/app/detail/components/BasicEdit/index.tsx @@ -542,7 +542,10 @@ const Settings = ({ appId }: { appId: string }) => { {isOpenKbSelect && ( ({ kbId: item._id }))} + activeKbs={selectedKbList.map((item) => ({ + kbId: item._id, + vectorModel: item.vectorModel + }))} onClose={onCloseKbSelect} onChange={replaceKbList} /> diff --git a/client/src/pages/app/detail/components/KBSelectModal.tsx b/client/src/pages/app/detail/components/KBSelectModal.tsx index 4d3f40d47..a9b953190 100644 --- a/client/src/pages/app/detail/components/KBSelectModal.tsx +++ b/client/src/pages/app/detail/components/KBSelectModal.tsx @@ -16,9 +16,11 @@ import { useForm } from 'react-hook-form'; import { QuestionOutlineIcon } from '@chakra-ui/icons'; import type { SelectedKbType } from '@/types/plugin'; import { useGlobalStore } from '@/store/global'; +import { useToast } from '@/hooks/useToast'; import MySlider from '@/components/Slider'; import MyTooltip from '@/components/MyTooltip'; import MyModal from '@/components/MyModal'; +import MyIcon from '@/components/Icon'; export type KbParamsType = { searchSimilarity: number; @@ -40,6 +42,7 @@ export const KBSelectModal = ({ const theme = useTheme(); const [selectedKbList, setSelectedKbList] = useState(activeKbs); const { isPc } = useGlobalStore(); + const { toast } = useToast(); return ( - 关联的知识库({selectedKbList.length}) + + 关联的知识库({selectedKbList.length}) + + 仅能选择同一个索引模型的知识库 + + + {kbList.map((item) => (() => { @@ -84,7 +94,18 @@ export const KBSelectModal = ({ if (selected) { setSelectedKbList((state) => state.filter((kb) => kb.kbId !== item._id)); } else { - setSelectedKbList((state) => [...state, { kbId: item._id }]); + const vectorModel = selectedKbList[0]?.vectorModel?.model; + + if (vectorModel && vectorModel !== item.vectorModel.model) { + return toast({ + status: 'warning', + title: '仅能选择同一个索引模型的知识库' + }); + } + setSelectedKbList((state) => [ + ...state, + { kbId: item._id, vectorModel: item.vectorModel } + ]); } }} > @@ -94,6 +115,10 @@ export const KBSelectModal = ({ {item.name} + + + {item.vectorModel.name} + ); })() @@ -138,7 +163,10 @@ export const KbParamsModal = ({ 相似度 - + diff --git a/client/src/pages/kb/detail/components/Import/Chunk.tsx b/client/src/pages/kb/detail/components/Import/Chunk.tsx index 36e9f1993..b93f6a45f 100644 --- a/client/src/pages/kb/detail/components/Import/Chunk.tsx +++ b/client/src/pages/kb/detail/components/Import/Chunk.tsx @@ -19,7 +19,6 @@ import { postKbDataFromList } from '@/api/plugins/kb'; import { splitText2Chunks } from '@/utils/file'; import { getErrText } from '@/utils/tools'; import { formatPrice } from '@/utils/user'; -import { vectorModelList } from '@/store/static'; import MyIcon from '@/components/Icon'; import CloseIcon from '@/components/Icon/close'; import DeleteIcon, { hoverDeleteStyles } from '@/components/Icon/delete'; @@ -27,17 +26,20 @@ import MyTooltip from '@/components/MyTooltip'; import { QuestionOutlineIcon } from '@chakra-ui/icons'; import { TrainingModeEnum } from '@/constants/plugin'; import FileSelect, { type FileItemType } from './FileSelect'; +import { useUserStore } from '@/store/user'; const fileExtension = '.txt, .doc, .docx, .pdf, .md'; const ChunkImport = ({ kbId }: { kbId: string }) => { - const model = vectorModelList[0]?.model || 'text-embedding-ada-002'; - const unitPrice = vectorModelList[0]?.price || 0.2; + const { kbDetail } = useUserStore(); + + const vectorModel = kbDetail.vectorModel; + const unitPrice = vectorModel?.price || 0.2; const theme = useTheme(); const router = useRouter(); const { toast } = useToast(); - const [chunkLen, setChunkLen] = useState(500); + const [chunkLen, setChunkLen] = useState(vectorModel?.defaultToken || 300); const [showRePreview, setShowRePreview] = useState(false); const [files, setFiles] = useState([]); const [previewFile, setPreviewFile] = useState(); @@ -205,24 +207,34 @@ const ChunkImport = ({ kbId }: { kbId: string }) => { - { - setChunkLen(+e); - setShowRePreview(true); + css={{ + '& > span': { + display: 'block' + } }} > - - - - - - + + { + setChunkLen(+e); + setShowRePreview(true); + }} + > + + + + + + + + {/* price */} diff --git a/client/src/pages/kb/detail/components/Import/Csv.tsx b/client/src/pages/kb/detail/components/Import/Csv.tsx index db110b639..ff418cd9c 100644 --- a/client/src/pages/kb/detail/components/Import/Csv.tsx +++ b/client/src/pages/kb/detail/components/Import/Csv.tsx @@ -11,11 +11,13 @@ import DeleteIcon, { hoverDeleteStyles } from '@/components/Icon/delete'; import { TrainingModeEnum } from '@/constants/plugin'; import FileSelect, { type FileItemType } from './FileSelect'; import { useRouter } from 'next/router'; +import { useUserStore } from '@/store/user'; const fileExtension = '.csv'; const CsvImport = ({ kbId }: { kbId: string }) => { - const model = vectorModelList[0]?.model; + const { kbDetail } = useUserStore(); + const theme = useTheme(); const router = useRouter(); const { toast } = useToast(); @@ -37,13 +39,22 @@ const CsvImport = ({ kbId }: { kbId: string }) => { mutationFn: async () => { const chunks = files.map((file) => file.chunks).flat(); + const filterChunks = chunks.filter((item) => item.q.length < kbDetail.vectorModel.maxToken); + + if (filterChunks.length !== chunks.length) { + toast({ + title: `${chunks.length - filterChunks.length}条数据超出长度,已被过滤`, + status: 'info' + }); + } + // subsection import let success = 0; const step = 500; - for (let i = 0; i < chunks.length; i += step) { + for (let i = 0; i < filterChunks.length; i += step) { const { insertLen } = await postKbDataFromList({ kbId, - data: chunks.slice(i, i + step), + data: filterChunks.slice(i, i + step), mode: TrainingModeEnum.index }); diff --git a/client/src/pages/kb/detail/components/Import/Manual.tsx b/client/src/pages/kb/detail/components/Import/Manual.tsx index ddb842930..e91eff379 100644 --- a/client/src/pages/kb/detail/components/Import/Manual.tsx +++ b/client/src/pages/kb/detail/components/Import/Manual.tsx @@ -1,4 +1,4 @@ -import React from 'react'; +import React, { useState } from 'react'; import { Box, Textarea, Button } from '@chakra-ui/react'; import { useForm } from 'react-hook-form'; import { useToast } from '@/hooks/useToast'; @@ -6,14 +6,18 @@ import { useRequest } from '@/hooks/useRequest'; import { getErrText } from '@/utils/tools'; import { postKbDataFromList } from '@/api/plugins/kb'; import { TrainingModeEnum } from '@/constants/plugin'; +import { useUserStore } from '@/store/user'; type ManualFormType = { q: string; a: string }; const ManualImport = ({ kbId }: { kbId: string }) => { + const { kbDetail } = useUserStore(); + const { register, handleSubmit, reset } = useForm({ defaultValues: { q: '', a: '' } }); const { toast } = useToast(); + const [qLen, setQLen] = useState(0); const { mutate: onImportData, isLoading } = useRequest({ mutationFn: async (e: ManualFormType) => { @@ -64,16 +68,22 @@ const ManualImport = ({ kbId }: { kbId: string }) => { return ( - + {'匹配的知识点'}