From f52f514f5f1eed1667f50b3a2212b2f82eec3cfb Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Tue, 9 May 2023 13:20:52 +0800 Subject: [PATCH] feat: callback training data --- src/api/model.ts | 5 ++++- .../{getSplitData.ts => getTrainingData.ts} | 20 +++++++++++++++++-- src/pages/api/model/data/splitData.ts | 4 ++-- .../detail/components/ModelDataCard.tsx | 12 +++++++---- 4 files changed, 32 insertions(+), 9 deletions(-) rename src/pages/api/model/data/{getSplitData.ts => getTrainingData.ts} (61%) diff --git a/src/api/model.ts b/src/api/model.ts index eac5469c4..b0bcfc0e8 100644 --- a/src/api/model.ts +++ b/src/api/model.ts @@ -52,7 +52,10 @@ export const getExportDataList = (modelId: string) => * 获取模型正在拆分数据的数量 */ export const getModelSplitDataListLen = (modelId: string) => - GET(`/model/data/getSplitData?modelId=${modelId}`); + GET<{ + splitDataQueue: number; + embeddingQueue: number; + }>(`/model/data/getTrainingData?modelId=${modelId}`); /** * 获取 web 页面内容 diff --git a/src/pages/api/model/data/getSplitData.ts b/src/pages/api/model/data/getTrainingData.ts similarity index 61% rename from src/pages/api/model/data/getSplitData.ts rename to src/pages/api/model/data/getTrainingData.ts index 3b4ba6355..e95aa5c86 100644 --- a/src/pages/api/model/data/getSplitData.ts +++ b/src/pages/api/model/data/getTrainingData.ts @@ -2,6 +2,8 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { connectToDatabase, SplitData, Model } from '@/service/mongo'; import { authToken } from '@/service/utils/auth'; +import { ModelDataStatusEnum } from '@/constants/model'; +import { PgClient } from '@/service/pg'; /* 拆分数据成QA */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -14,15 +16,29 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const userId = await authToken(req); - // 找到长度大于0的数据 + // split queue data const data = await SplitData.find({ userId, modelId, textList: { $exists: true, $not: { $size: 0 } } }); + // embedding queue data + const where: any = [ + ['user_id', userId], + 'AND', + ['model_id', modelId], + 'AND', + ['status', ModelDataStatusEnum.waiting] + ]; + jsonRes(res, { - data: data.map((item) => item.textList).flat().length + data: { + splitDataQueue: data.map((item) => item.textList).flat().length, + embeddingQueue: await PgClient.count('modelData', { + where + }) + } }); } catch (err) { jsonRes(res, { diff --git a/src/pages/api/model/data/splitData.ts b/src/pages/api/model/data/splitData.ts index 9836adb7a..4ac3bde10 100644 --- a/src/pages/api/model/data/splitData.ts +++ b/src/pages/api/model/data/splitData.ts @@ -1,6 +1,6 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { connectToDatabase, SplitData, Model } from '@/service/mongo'; +import { connectToDatabase, SplitData } from '@/service/mongo'; import { authModel, authToken } from '@/service/utils/auth'; import { generateVector } from '@/service/events/generateVector'; import { generateQA } from '@/service/events/generateQA'; @@ -65,7 +65,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) export const config = { api: { bodyParser: { - sizeLimit: '10mb' + sizeLimit: '100mb' } } }; diff --git a/src/pages/model/components/detail/components/ModelDataCard.tsx b/src/pages/model/components/detail/components/ModelDataCard.tsx index 92d704d9a..8e6d20d1a 100644 --- a/src/pages/model/components/detail/components/ModelDataCard.tsx +++ b/src/pages/model/components/detail/components/ModelDataCard.tsx @@ -88,7 +88,7 @@ const ModelDataCard = ({ modelId, isOwner }: { modelId: string; isOwner: boolean onClose: onCloseSelectCsvModal } = useDisclosure(); - const { data: splitDataLen = 0, refetch } = useQuery( + const { data: { splitDataQueue = 0, embeddingQueue = 0 } = {}, refetch } = useQuery( ['getModelSplitDataList'], () => getModelSplitDataListLen(modelId), { @@ -109,7 +109,7 @@ const ModelDataCard = ({ modelId, isOwner }: { modelId: string; isOwner: boolean useQuery(['refetchData'], () => refetchData(pageNum), { refetchInterval: 5000, - enabled: splitDataLen > 0 + enabled: splitDataQueue > 0 || embeddingQueue > 0 }); // 获取所有的数据,并导出 json @@ -186,8 +186,12 @@ const ModelDataCard = ({ modelId, isOwner }: { modelId: string; isOwner: boolean )} - {isOwner && splitDataLen > 0 && ( - {splitDataLen}条数据正在拆分,请耐心等待... + {isOwner && (splitDataQueue > 0 || embeddingQueue > 0) && ( + + {splitDataQueue > 0 ? `${splitDataQueue}条数据正在拆分,` : ''} + {embeddingQueue > 0 ? `${embeddingQueue}条数据正在生成索引,` : ''} + 请耐心等待... + )}