From dc1c1d135506b785256fc988610d9580e19885b1 Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Fri, 26 May 2023 23:08:25 +0800 Subject: [PATCH] training queue --- .env.template | 6 +- docs/deploy/fastgpt/docker-compose.yml | 6 +- docs/deploy/mac.md | 4 +- docs/dev/README.md | 6 +- src/api/plugins/kb.ts | 14 +- src/constants/model.ts | 12 +- src/constants/plugin.ts | 12 +- src/pages/api/openapi/kb/appKbSearch.ts | 11 +- src/pages/api/openapi/kb/pushData.ts | 100 ++++---- src/pages/api/openapi/kb/updateData.ts | 27 ++- .../api/openapi/plugin/openaiEmbedding.ts | 13 +- src/pages/api/openapi/startEvents.ts | 19 -- src/pages/api/openapi/text/sensitiveCheck.ts | 2 +- src/pages/api/openapi/text/splitText.ts | 36 ++- .../api/plugins/kb/data/getTrainingData.ts | 59 +++-- src/pages/kb/components/DataCard.tsx | 17 +- src/pages/kb/components/SelectFileModal.tsx | 48 ++-- src/service/api/request.ts | 26 ++- src/service/events/generateQA.ts | 172 +++++++------- src/service/events/generateVector.ts | 216 ++++++++++-------- src/service/models/model.ts | 4 +- src/service/models/splitData.ts | 32 --- src/service/models/trainingData.ts | 38 +++ src/service/mongo.ts | 34 ++- src/service/pg.ts | 5 +- src/service/utils/auth.ts | 20 +- src/types/index.d.ts | 2 - src/types/model.d.ts | 2 +- src/types/mongoSchema.d.ts | 12 +- src/types/plugin.d.ts | 15 -- src/utils/file.ts | 49 ++-- src/utils/plugin/google.ts | 2 +- 32 files changed, 528 insertions(+), 493 deletions(-) delete mode 100644 src/pages/api/openapi/startEvents.ts delete mode 100644 src/service/models/splitData.ts create mode 100644 src/service/models/trainingData.ts diff --git a/.env.template b/.env.template index c8ddc38c1..f6de6c46c 100644 --- a/.env.template +++ b/.env.template @@ -1,9 +1,6 @@ # proxy # AXIOS_PROXY_HOST=127.0.0.1 # AXIOS_PROXY_PORT=7890 -# 是否开启队列任务。 1-开启,0-关闭(请求parentUrl去执行任务,单机时直接填1) -queueTask=1 -parentUrl=https://hostname/api/openapi/startEvents # email MY_MAIL=xxx@qq.com MAILE_CODE=xxx @@ -21,7 +18,8 @@ SENSITIVE_CHECK=1 # openai # OPENAI_BASE_URL=https://api.openai.com/v1 # OPENAI_BASE_URL_AUTH=可选的安全凭证(不需要的时候,记得去掉) -OPENAIKEY=sk-xxx +OPENAIKEY=sk-xxx # 对话用的key +OPENAI_TRAINING_KEY=sk-xxx # 训练用的key GPT4KEY=sk-xxx # claude CLAUDE_BASE_URL=calude模型请求地址 diff --git a/docs/deploy/fastgpt/docker-compose.yml b/docs/deploy/fastgpt/docker-compose.yml index 6dc9902bc..ce297bacb 100644 --- a/docs/deploy/fastgpt/docker-compose.yml +++ b/docs/deploy/fastgpt/docker-compose.yml @@ -39,9 +39,6 @@ services: # proxy(可选) - AXIOS_PROXY_HOST=127.0.0.1 - AXIOS_PROXY_PORT=7890 - # 是否开启队列任务。 1-开启,0-关闭(请求 parentUrl 去执行任务,单机时直接填1) - - queueTask=1 - - parentUrl=https://hostname/api/openapi/startEvents # 发送邮箱验证码配置。用的是QQ邮箱。参考 nodeMail 获取MAILE_CODE,自行百度。 - MY_MAIL=xxxx@qq.com - MAILE_CODE=xxxx @@ -66,7 +63,8 @@ services: - PG_PASSWORD=1234 # POSTGRES_PASSWORD - PG_DB_NAME=fastgpt # POSTGRES_DB # openai - - OPENAIKEY=sk-xxxxx + - OPENAIKEY=sk-xxxxx # 对话用的key + - OPENAI_TRAINING_KEY=sk-xxx # 训练用的key - GPT4KEY=sk-xxx - OPENAI_BASE_URL=https://api.openai.com/v1 - OPENAI_BASE_URL_AUTH=可选的安全凭证 diff --git a/docs/deploy/mac.md b/docs/deploy/mac.md index 5592418ef..0a05a46e2 100644 --- a/docs/deploy/mac.md +++ b/docs/deploy/mac.md @@ -36,7 +36,6 @@ mongo pg AXIOS_PROXY_HOST=127.0.0.1 AXIOS_PROXY_PORT_FAST=7890 AXIOS_PROXY_PORT_NORMAL=7890 -queueTask=1 # email MY_MAIL= {Your Mail} MAILE_CODE={Yoir Mail code} @@ -48,7 +47,8 @@ aliTemplateCode=SMS_xxx # token TOKEN_KEY=sswada # openai -OPENAIKEY={Your openapi key} +OPENAIKEY=sk-xxx # 对话用的key +OPENAI_TRAINING_KEY=sk-xxx # 训练用的key # db MONGODB_URI=mongodb://username:password@0.0.0.0:27017/test?authSource=admin PG_HOST=0.0.0.0 diff --git a/docs/dev/README.md b/docs/dev/README.md index d86163068..61aa1d9c7 100644 --- a/docs/dev/README.md +++ b/docs/dev/README.md @@ -10,9 +10,6 @@ # proxy(可选) AXIOS_PROXY_HOST=127.0.0.1 AXIOS_PROXY_PORT=7890 -# 是否开启队列任务。 1-开启,0-关闭(请求parentUrl去执行任务,单机时直接填1) -queueTask=1 -parentUrl=https://hostname/api/openapi/startEvents # email MY_MAIL=xxx@qq.com MAILE_CODE=xxx @@ -30,7 +27,8 @@ SENSITIVE_CHECK=1 # openai # OPENAI_BASE_URL=https://api.openai.com/v1 # OPENAI_BASE_URL_AUTH=可选的安全凭证(不需要的时候,记得去掉) -OPENAIKEY=sk-xxx +OPENAIKEY=sk-xxx # 对话用的key +OPENAI_TRAINING_KEY=sk-xxx # 训练用的key GPT4KEY=sk-xxx # claude CLAUDE_BASE_URL=calude模型请求地址 diff --git a/src/api/plugins/kb.ts b/src/api/plugins/kb.ts index e7d7ac4d7..00681fb88 100644 --- a/src/api/plugins/kb.ts +++ b/src/api/plugins/kb.ts @@ -1,7 +1,7 @@ import { GET, POST, PUT, DELETE } from '../request'; import type { KbItemType } from '@/types/plugin'; import { RequestPaging } from '@/types/index'; -import { SplitTextTypEnum } from '@/constants/plugin'; +import { TrainingTypeEnum } from '@/constants/plugin'; import { KbDataItemType } from '@/types/plugin'; export type KbUpdateParams = { id: string; name: string; tags: string; avatar: string }; @@ -34,11 +34,11 @@ export const getExportDataList = (kbId: string) => /** * 获取模型正在拆分数据的数量 */ -export const getTrainingData = (kbId: string) => - GET<{ - splitDataQueue: number; - embeddingQueue: number; - }>(`/plugins/kb/data/getTrainingData?kbId=${kbId}`); +export const getTrainingData = (data: { kbId: string; init: boolean }) => + POST<{ + qaListLen: number; + vectorListLen: number; + }>(`/plugins/kb/data/getTrainingData`, data); export const getKbDataItemById = (dataId: string) => GET(`/plugins/kb/data/getDataById`, { dataId }); @@ -69,5 +69,5 @@ export const postSplitData = (data: { kbId: string; chunks: string[]; prompt: string; - mode: `${SplitTextTypEnum}`; + mode: `${TrainingTypeEnum}`; }) => POST(`/openapi/text/splitText`, data); diff --git a/src/constants/model.ts b/src/constants/model.ts index 0971f2fee..945f2262a 100644 --- a/src/constants/model.ts +++ b/src/constants/model.ts @@ -108,27 +108,27 @@ export const ModelDataStatusMap: Record<`${ModelDataStatusEnum}`, string> = { /* 知识库搜索时的配置 */ // 搜索方式 -export enum ModelVectorSearchModeEnum { +export enum appVectorSearchModeEnum { hightSimilarity = 'hightSimilarity', // 高相似度+禁止回复 lowSimilarity = 'lowSimilarity', // 低相似度 noContext = 'noContex' // 高相似度+无上下文回复 } export const ModelVectorSearchModeMap: Record< - `${ModelVectorSearchModeEnum}`, + `${appVectorSearchModeEnum}`, { text: string; similarity: number; } > = { - [ModelVectorSearchModeEnum.hightSimilarity]: { + [appVectorSearchModeEnum.hightSimilarity]: { text: '高相似度, 无匹配时拒绝回复', similarity: 0.18 }, - [ModelVectorSearchModeEnum.noContext]: { + [appVectorSearchModeEnum.noContext]: { text: '高相似度,无匹配时直接回复', similarity: 0.18 }, - [ModelVectorSearchModeEnum.lowSimilarity]: { + [appVectorSearchModeEnum.lowSimilarity]: { text: '低相似度匹配', similarity: 0.7 } @@ -143,7 +143,7 @@ export const defaultModel: ModelSchema = { updateTime: Date.now(), chat: { relatedKbs: [], - searchMode: ModelVectorSearchModeEnum.hightSimilarity, + searchMode: appVectorSearchModeEnum.hightSimilarity, systemPrompt: '', temperature: 0, chatModel: OpenAiChatEnum.GPT35 diff --git a/src/constants/plugin.ts b/src/constants/plugin.ts index 5bc9ac312..0368090b4 100644 --- a/src/constants/plugin.ts +++ b/src/constants/plugin.ts @@ -1,14 +1,4 @@ -export enum SplitTextTypEnum { +export enum TrainingTypeEnum { 'qa' = 'qa', 'subsection' = 'subsection' } - -export enum PluginTypeEnum { - LLM = 'LLM', - Text = 'Text', - Function = 'Function' -} - -export enum PluginParamsTypeEnum { - 'Text' = 'text' -} diff --git a/src/pages/api/openapi/kb/appKbSearch.ts b/src/pages/api/openapi/kb/appKbSearch.ts index 97f01e577..40f5e4f85 100644 --- a/src/pages/api/openapi/kb/appKbSearch.ts +++ b/src/pages/api/openapi/kb/appKbSearch.ts @@ -5,7 +5,7 @@ import { PgClient } from '@/service/pg'; import { withNextCors } from '@/service/utils/tools'; import type { ChatItemSimpleType } from '@/types/chat'; import type { ModelSchema } from '@/types/mongoSchema'; -import { ModelVectorSearchModeEnum } from '@/constants/model'; +import { appVectorSearchModeEnum } from '@/constants/model'; import { authModel } from '@/service/utils/auth'; import { ChatModelMap } from '@/constants/model'; import { ChatRoleEnum } from '@/constants/chat'; @@ -92,7 +92,8 @@ export async function appKbSearch({ // get vector const promptVectors = await openaiEmbedding({ userId, - input + input, + type: 'chat' }); // search kb @@ -138,7 +139,7 @@ export async function appKbSearch({ obj: ChatRoleEnum.System, value: model.chat.systemPrompt } - : model.chat.searchMode === ModelVectorSearchModeEnum.noContext + : model.chat.searchMode === appVectorSearchModeEnum.noContext ? { obj: ChatRoleEnum.System, value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.` @@ -176,7 +177,7 @@ export async function appKbSearch({ const systemPrompt = sliceResult.flat().join('\n').trim(); /* 高相似度+不回复 */ - if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity) { + if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.hightSimilarity) { return { code: 201, rawSearch: [], @@ -190,7 +191,7 @@ export async function appKbSearch({ }; } /* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */ - if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) { + if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.noContext) { return { code: 200, rawSearch: [], diff --git a/src/pages/api/openapi/kb/pushData.ts b/src/pages/api/openapi/kb/pushData.ts index bc8f2087e..0d34af332 100644 --- a/src/pages/api/openapi/kb/pushData.ts +++ b/src/pages/api/openapi/kb/pushData.ts @@ -1,84 +1,36 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import type { KbDataItemType } from '@/types/plugin'; import { jsonRes } from '@/service/response'; -import { connectToDatabase } from '@/service/mongo'; +import { connectToDatabase, TrainingData } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; import { generateVector } from '@/service/events/generateVector'; -import { PgClient, insertKbItem } from '@/service/pg'; +import { PgClient } from '@/service/pg'; import { authKb } from '@/service/utils/auth'; import { withNextCors } from '@/service/utils/tools'; +interface Props { + kbId: string; + data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[]; +} + export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { - kbId, - data, - formatLineBreak = true - } = req.body as { - kbId: string; - formatLineBreak?: boolean; - data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[]; - }; + const { kbId, data } = req.body as Props; if (!kbId || !Array.isArray(data)) { throw new Error('缺少参数'); } - await connectToDatabase(); // 凭证校验 const { userId } = await authUser({ req }); - await authKb({ - userId, - kbId - }); - - // 过滤重复的内容 - const searchRes = await Promise.allSettled( - data.map(async ({ q, a = '' }) => { - if (!q) { - return Promise.reject('q为空'); - } - - if (formatLineBreak) { - q = q.replace(/\\n/g, '\n'); - a = a.replace(/\\n/g, '\n'); - } - - // Exactly the same data, not push - try { - const count = await PgClient.count('modelData', { - where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]] - }); - if (count > 0) { - return Promise.reject('已经存在'); - } - } catch (error) { - error; - } - return Promise.resolve({ - q, - a - }); - }) - ); - const filterData = searchRes - .filter((item) => item.status === 'fulfilled') - .map<{ q: string; a: string }>((item: any) => item.value); - - // 插入记录 - const insertRes = await insertKbItem({ - userId, - kbId, - data: filterData - }); - - generateVector(); - jsonRes(res, { - message: `共插入 ${insertRes.rowCount} 条数据`, - data: insertRes.rowCount + data: await pushDataToKb({ + kbId, + data, + userId + }) }); } catch (err) { jsonRes(res, { @@ -88,6 +40,32 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex } }); +export async function pushDataToKb({ userId, kbId, data }: { userId: string } & Props) { + await authKb({ + userId, + kbId + }); + + if (data.length === 0) { + return { + trainingId: '' + }; + } + + // 插入记录 + const { _id } = await TrainingData.create({ + userId, + kbId, + vectorList: data + }); + + generateVector(_id); + + return { + trainingId: _id + }; +} + export const config = { api: { bodyParser: { diff --git a/src/pages/api/openapi/kb/updateData.ts b/src/pages/api/openapi/kb/updateData.ts index 07db9f4a0..5a7402846 100644 --- a/src/pages/api/openapi/kb/updateData.ts +++ b/src/pages/api/openapi/kb/updateData.ts @@ -5,10 +5,11 @@ import { ModelDataStatusEnum } from '@/constants/model'; import { generateVector } from '@/service/events/generateVector'; import { PgClient } from '@/service/pg'; import { withNextCors } from '@/service/utils/tools'; +import { openaiEmbedding } from '../plugin/openaiEmbedding'; export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { dataId, a, q } = req.body as { dataId: string; a: string; q?: string }; + const { dataId, a = '', q = '' } = req.body as { dataId: string; a?: string; q?: string }; if (!dataId) { throw new Error('缺少参数'); @@ -17,22 +18,24 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex // 凭证校验 const { userId } = await authUser({ req }); + // get vector + const vector = await (async () => { + if (q) { + return openaiEmbedding({ + userId, + input: [q], + type: 'chat' + }); + } + return []; + })(); + // 更新 pg 内容.仅修改a,不需要更新向量。 await PgClient.update('modelData', { where: [['id', dataId], 'AND', ['user_id', userId]], - values: [ - { key: 'a', value: a }, - ...(q - ? [ - { key: 'q', value: q }, - { key: 'status', value: ModelDataStatusEnum.waiting } - ] - : []) - ] + values: [{ key: 'a', value: a }, ...(q ? [{ key: 'q', value: `${vector[0]}` }] : [])] }); - q && generateVector(); - jsonRes(res); } catch (err) { jsonRes(res, { diff --git a/src/pages/api/openapi/plugin/openaiEmbedding.ts b/src/pages/api/openapi/plugin/openaiEmbedding.ts index d1bb776ba..6952ffd82 100644 --- a/src/pages/api/openapi/plugin/openaiEmbedding.ts +++ b/src/pages/api/openapi/plugin/openaiEmbedding.ts @@ -1,30 +1,31 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { authUser } from '@/service/utils/auth'; -import { PgClient } from '@/service/pg'; import { withNextCors } from '@/service/utils/tools'; import { getApiKey } from '@/service/utils/auth'; import { getOpenAIApi } from '@/service/utils/chat/openai'; import { embeddingModel } from '@/constants/model'; import { axiosConfig } from '@/service/utils/tools'; import { pushGenerateVectorBill } from '@/service/events/pushBill'; +import { ApiKeyType } from '@/service/utils/auth'; type Props = { input: string[]; + type?: ApiKeyType; }; type Response = number[][]; export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { try { const { userId } = await authUser({ req }); - let { input } = req.query as Props; + let { input, type } = req.query as Props; if (!Array.isArray(input)) { throw new Error('缺少参数'); } jsonRes(res, { - data: await openaiEmbedding({ userId, input, mustPay: true }) + data: await openaiEmbedding({ userId, input, mustPay: true, type }) }); } catch (err) { console.log(err); @@ -38,12 +39,14 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex export async function openaiEmbedding({ userId, input, - mustPay = false + mustPay = false, + type = 'chat' }: { userId: string; mustPay?: boolean } & Props) { const { userOpenAiKey, systemAuthKey } = await getApiKey({ model: 'gpt-3.5-turbo', userId, - mustPay + mustPay, + type }); // 获取 chatAPI diff --git a/src/pages/api/openapi/startEvents.ts b/src/pages/api/openapi/startEvents.ts deleted file mode 100644 index b0bed8a7c..000000000 --- a/src/pages/api/openapi/startEvents.ts +++ /dev/null @@ -1,19 +0,0 @@ -// Next.js API route support: https://nextjs.org/docs/api-routes/introduction -import type { NextApiRequest, NextApiResponse } from 'next'; -import { jsonRes } from '@/service/response'; -import { generateQA } from '@/service/events/generateQA'; -import { generateVector } from '@/service/events/generateVector'; - -export default async function handler(req: NextApiRequest, res: NextApiResponse) { - try { - generateQA(); - generateVector(); - - jsonRes(res); - } catch (err) { - jsonRes(res, { - code: 500, - error: err - }); - } -} diff --git a/src/pages/api/openapi/text/sensitiveCheck.ts b/src/pages/api/openapi/text/sensitiveCheck.ts index c9921040b..952900735 100644 --- a/src/pages/api/openapi/text/sensitiveCheck.ts +++ b/src/pages/api/openapi/text/sensitiveCheck.ts @@ -17,7 +17,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const { input } = req.body as TextPluginRequestParams; const response = await axios({ - ...axiosConfig(getSystemOpenAiKey()), + ...axiosConfig(getSystemOpenAiKey('chat')), method: 'POST', url: `/moderations`, data: { diff --git a/src/pages/api/openapi/text/splitText.ts b/src/pages/api/openapi/text/splitText.ts index 489cfad77..642730d28 100644 --- a/src/pages/api/openapi/text/splitText.ts +++ b/src/pages/api/openapi/text/splitText.ts @@ -1,12 +1,11 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { connectToDatabase, SplitData } from '@/service/mongo'; +import { connectToDatabase, TrainingData } from '@/service/mongo'; import { authKb, authUser } from '@/service/utils/auth'; -import { generateVector } from '@/service/events/generateVector'; import { generateQA } from '@/service/events/generateQA'; -import { insertKbItem } from '@/service/pg'; -import { SplitTextTypEnum } from '@/constants/plugin'; +import { TrainingTypeEnum } from '@/constants/plugin'; import { withNextCors } from '@/service/utils/tools'; +import { pushDataToKb } from '../kb/pushData'; /* split text */ export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -15,7 +14,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex kbId: string; chunks: string[]; prompt: string; - mode: `${SplitTextTypEnum}`; + mode: `${TrainingTypeEnum}`; }; if (!chunks || !kbId || !prompt) { throw new Error('参数错误'); @@ -30,29 +29,26 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex userId }); - if (mode === SplitTextTypEnum.qa) { + if (mode === TrainingTypeEnum.qa) { // 批量QA拆分插入数据 - await SplitData.create({ + const { _id } = await TrainingData.create({ userId, kbId, - textList: chunks, + qaList: chunks, prompt }); - - generateQA(); - } else if (mode === SplitTextTypEnum.subsection) { - // 待优化,直接调用另一个接口 - // 插入记录 - await insertKbItem({ - userId, + generateQA(_id); + } else if (mode === TrainingTypeEnum.subsection) { + // 分段导入,直接插入向量队列 + const response = await pushDataToKb({ kbId, - data: chunks.map((item) => ({ - q: item, - a: '' - })) + data: chunks.map((item) => ({ q: item, a: '' })), + userId }); - generateVector(); + return jsonRes(res, { + data: response + }); } jsonRes(res); diff --git a/src/pages/api/plugins/kb/data/getTrainingData.ts b/src/pages/api/plugins/kb/data/getTrainingData.ts index c1dc224e8..d7369f90b 100644 --- a/src/pages/api/plugins/kb/data/getTrainingData.ts +++ b/src/pages/api/plugins/kb/data/getTrainingData.ts @@ -1,14 +1,15 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { connectToDatabase, SplitData, Model } from '@/service/mongo'; +import { connectToDatabase, TrainingData } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; -import { ModelDataStatusEnum } from '@/constants/model'; -import { PgClient } from '@/service/pg'; +import { Types } from 'mongoose'; +import { generateQA } from '@/service/events/generateQA'; +import { generateVector } from '@/service/events/generateVector'; /* 拆分数据成QA */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { kbId } = req.query as { kbId: string }; + const { kbId, init = false } = req.body as { kbId: string; init: boolean }; if (!kbId) { throw new Error('参数错误'); } @@ -17,29 +18,43 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const { userId } = await authUser({ req, authToken: true }); // split queue data - const data = await SplitData.find({ - userId, - kbId, - textList: { $exists: true, $not: { $size: 0 } } - }); - - // embedding queue data - const embeddingData = await PgClient.count('modelData', { - where: [ - ['user_id', userId], - 'AND', - ['kb_id', kbId], - 'AND', - ['status', ModelDataStatusEnum.waiting] - ] - }); + const result = await TrainingData.aggregate([ + { $match: { userId: new Types.ObjectId(userId), kbId: new Types.ObjectId(kbId) } }, + { + $project: { + qaListLength: { $size: { $ifNull: ['$qaList', []] } }, + vectorListLength: { $size: { $ifNull: ['$vectorList', []] } } + } + }, + { + $group: { + _id: null, + totalQaListLength: { $sum: '$qaListLength' }, + totalVectorListLength: { $sum: '$vectorListLength' } + } + } + ]); jsonRes(res, { data: { - splitDataQueue: data.map((item) => item.textList).flat().length, - embeddingQueue: embeddingData + qaListLen: result[0]?.totalQaListLength || 0, + vectorListLen: result[0]?.totalVectorListLength || 0 } }); + + if (init) { + const list = await TrainingData.find( + { + userId, + kbId + }, + '_id' + ); + list.forEach((item) => { + generateQA(item._id); + generateVector(item._id); + }); + } } catch (err) { jsonRes(res, { code: 500, diff --git a/src/pages/kb/components/DataCard.tsx b/src/pages/kb/components/DataCard.tsx index 6747c3f16..b4ce8343b 100644 --- a/src/pages/kb/components/DataCard.tsx +++ b/src/pages/kb/components/DataCard.tsx @@ -91,9 +91,9 @@ const DataCard = ({ kbId }: { kbId: string }) => { onClose: onCloseSelectCsvModal } = useDisclosure(); - const { data: { splitDataQueue = 0, embeddingQueue = 0 } = {}, refetch } = useQuery( + const { data: { qaListLen = 0, vectorListLen = 0 } = {}, refetch } = useQuery( ['getModelSplitDataList'], - () => getTrainingData(kbId), + () => getTrainingData({ kbId, init: false }), { onError(err) { console.log(err); @@ -113,7 +113,7 @@ const DataCard = ({ kbId }: { kbId: string }) => { // interval get data useQuery(['refetchData'], () => refetchData(pageNum), { refetchInterval: 5000, - enabled: splitDataQueue > 0 || embeddingQueue > 0 + enabled: qaListLen > 0 || vectorListLen > 0 }); // get al data and export csv @@ -161,7 +161,10 @@ const DataCard = ({ kbId }: { kbId: string }) => { variant={'outline'} mr={[2, 4]} size={'sm'} - onClick={() => refetchData(pageNum)} + onClick={() => { + refetchData(pageNum); + getTrainingData({ kbId, init: true }); + }} />