From 516618b0cd2a44b295c8219ed68d27300b9bafbe Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Sun, 28 May 2023 20:13:19 +0800 Subject: [PATCH] feat: insert data de-weight;perf: input queue --- src/api/plugins/kb.ts | 8 +- src/pages/api/openapi/kb/pushData.ts | 87 ++++++++++++--------- src/pages/api/openapi/kb/updateData.ts | 4 +- src/pages/kb/components/InputDataModal.tsx | 26 +++--- src/pages/kb/components/SelectCsvModal.tsx | 47 ++++++----- src/pages/kb/components/SelectFileModal.tsx | 57 ++++++++++---- src/service/errorCode.ts | 8 +- src/service/events/generateQA.ts | 18 +++-- src/service/events/generateVector.ts | 15 +++- src/service/events/pushBill.ts | 6 +- src/service/pg.ts | 4 +- src/service/response.ts | 12 ++- 12 files changed, 187 insertions(+), 105 deletions(-) diff --git a/src/api/plugins/kb.ts b/src/api/plugins/kb.ts index 9cae025d8..e6ccb8c93 100644 --- a/src/api/plugins/kb.ts +++ b/src/api/plugins/kb.ts @@ -2,7 +2,10 @@ import { GET, POST, PUT, DELETE } from '../request'; import type { KbItemType } from '@/types/plugin'; import { RequestPaging } from '@/types/index'; import { TrainingModeEnum } from '@/constants/plugin'; -import { Props as PushDataProps } from '@/pages/api/openapi/kb/pushData'; +import { + Props as PushDataProps, + Response as PushDateResponse +} from '@/pages/api/openapi/kb/pushData'; export type KbUpdateParams = { id: string; name: string; tags: string; avatar: string }; @@ -46,7 +49,8 @@ export const getKbDataItemById = (dataId: string) => /** * 直接push数据 */ -export const postKbDataFromList = (data: PushDataProps) => POST(`/openapi/kb/pushData`, data); +export const postKbDataFromList = (data: PushDataProps) => + POST(`/openapi/kb/pushData`, data); /** * 更新一条数据 diff --git a/src/pages/api/openapi/kb/pushData.ts b/src/pages/api/openapi/kb/pushData.ts index 9dd950a7d..617c14a2b 100644 --- a/src/pages/api/openapi/kb/pushData.ts +++ b/src/pages/api/openapi/kb/pushData.ts @@ -16,6 +16,10 @@ export type Props = { prompt?: string; }; +export type Response = { + insertLen: number; +}; + export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { try { const { kbId, data, mode, prompt } = req.body as Props; @@ -28,7 +32,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex // 凭证校验 const { userId } = await authUser({ req }); - jsonRes(res, { + jsonRes(res, { data: await pushDataToKb({ kbId, data, @@ -51,16 +55,12 @@ export async function pushDataToKb({ data, mode, prompt -}: { userId: string } & Props) { +}: { userId: string } & Props): Promise { await authKb({ userId, kbId }); - if (data.length === 0) { - return {}; - } - // 过滤重复的 qa 内容 const set = new Set(); const filterData: { @@ -75,41 +75,54 @@ export async function pushDataToKb({ set.add(text); } }); + // 数据库去重 - // const searchRes = await Promise.allSettled( - // data.map(async ({ q, a = '' }) => { - // if (!q) { - // return Promise.reject('q为空'); - // } + const insertData = ( + await Promise.allSettled( + filterData.map(async ({ q, a = '' }) => { + if (mode !== TrainingModeEnum.index) { + return Promise.resolve({ + q, + a + }); + } - // q = q.replace(/\\n/g, '\n'); - // a = a.replace(/\\n/g, '\n'); + if (!q) { + return Promise.reject('q为空'); + } - // // Exactly the same data, not push - // try { - // const count = await PgClient.count('modelData', { - // where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]] - // }); + q = q.replace(/\\n/g, '\n').trim().replace(/'/g, '"'); + a = a.replace(/\\n/g, '\n').trim().replace(/'/g, '"'); - // if (count > 0) { - // return Promise.reject('已经存在'); - // } - // } catch (error) { - // error; - // } - // return Promise.resolve({ - // q, - // a - // }); - // }) - // ); - // const filterData = searchRes - // .filter((item) => item.status === 'fulfilled') - // .map<{ q: string; a: string }>((item: any) => item.value); + // Exactly the same data, not push + try { + const { rows } = await PgClient.query(` + SELECT COUNT(*) > 0 AS exists + FROM modelData + WHERE md5(q)=md5('${q}') AND md5(a)=md5('${a}') AND user_id='${userId}' AND kb_id='${kbId}' + `); + const exists = rows[0]?.exists || false; + + if (exists) { + return Promise.reject('已经存在'); + } + } catch (error) { + console.log(error); + error; + } + return Promise.resolve({ + q, + a + }); + }) + ) + ) + .filter((item) => item.status === 'fulfilled') + .map<{ q: string; a: string }>((item: any) => item.value); // 插入记录 await TrainingData.insertMany( - data.map((item) => ({ + insertData.map((item) => ({ q: item.q, a: item.a, userId, @@ -119,9 +132,11 @@ export async function pushDataToKb({ })) ); - startQueue(); + insertData.length > 0 && startQueue(); - return {}; + return { + insertLen: insertData.length + }; } export const config = { diff --git a/src/pages/api/openapi/kb/updateData.ts b/src/pages/api/openapi/kb/updateData.ts index ba53b8060..65a5724b5 100644 --- a/src/pages/api/openapi/kb/updateData.ts +++ b/src/pages/api/openapi/kb/updateData.ts @@ -32,10 +32,10 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex await PgClient.update('modelData', { where: [['id', dataId], 'AND', ['user_id', userId]], values: [ - { key: 'a', value: a }, + { key: 'a', value: a.replace(/'/g, '"') }, ...(q ? [ - { key: 'q', value: q }, + { key: 'q', value: q.replace(/'/g, '"') }, { key: 'vector', value: `[${vector[0]}]` } ] : []) diff --git a/src/pages/kb/components/InputDataModal.tsx b/src/pages/kb/components/InputDataModal.tsx index c74b2ff4f..59b35781a 100644 --- a/src/pages/kb/components/InputDataModal.tsx +++ b/src/pages/kb/components/InputDataModal.tsx @@ -54,7 +54,7 @@ const InputDataModal = ({ setLoading(true); try { - const res = await postKbDataFromList({ + const { insertLen } = await postKbDataFromList({ kbId, data: [ { @@ -65,14 +65,22 @@ const InputDataModal = ({ mode: TrainingModeEnum.index }); - toast({ - title: res === 0 ? '可能已存在完全一致的数据' : '导入数据成功,需要一段时间训练', - status: 'success' - }); - reset({ - a: '', - q: '' - }); + if (insertLen === 0) { + toast({ + title: '已存在完全一致的数据', + status: 'warning' + }); + } else { + toast({ + title: '导入数据成功,需要一段时间训练', + status: 'success' + }); + reset({ + a: '', + q: '' + }); + } + onSuccess(); } catch (err: any) { toast({ diff --git a/src/pages/kb/components/SelectCsvModal.tsx b/src/pages/kb/components/SelectCsvModal.tsx index 447a84ac3..a5075d98d 100644 --- a/src/pages/kb/components/SelectCsvModal.tsx +++ b/src/pages/kb/components/SelectCsvModal.tsx @@ -37,6 +37,7 @@ const SelectJsonModal = ({ const { toast } = useToast(); const { File, onOpen } = useSelectFile({ fileType: '.csv', multiple: false }); const [fileData, setFileData] = useState<{ q: string; a: string }[]>([]); + const [successData, setSuccessData] = useState(0); const { openConfirm, ConfirmChild } = useConfirm({ content: '确认导入该数据集?' }); @@ -67,27 +68,35 @@ const SelectJsonModal = ({ [setSelecting, toast] ); - const { mutate, isLoading } = useMutation({ + const { mutate, isLoading: uploading } = useMutation({ mutationFn: async () => { if (!fileData || fileData.length === 0) return; - const res = await postKbDataFromList({ - kbId, - data: fileData, - mode: TrainingModeEnum.index - }); + let success = 0; + + // subsection import + const step = 50; + for (let i = 0; i < fileData.length; i += step) { + const { insertLen } = await postKbDataFromList({ + kbId, + data: fileData.slice(i, i + step), + mode: TrainingModeEnum.index + }); + success += insertLen || 0; + setSuccessData((state) => state + step); + } toast({ - title: `导入数据成功,最终导入: ${res || 0} 条数据。需要一段时间训练`, + title: `导入数据成功,最终导入: ${success} 条数据。需要一段时间训练`, status: 'success', duration: 4000 }); onClose(); onSuccess(); }, - onError() { + onError(err) { toast({ - title: '导入文件失败', + title: getErrText(err, '导入文件失败'), status: 'error' }); } @@ -121,15 +130,15 @@ const SelectJsonModal = ({ 点击下载csv模板 - - 一共 {fileData.length} 组数据 + 一共 {fileData.length} 组数据(下面最多展示100组) - {fileData.map((item, index) => ( + {fileData.slice(0, 100).map((item, index) => ( Q{index + 1}. {item.q} @@ -144,15 +153,15 @@ const SelectJsonModal = ({ - - diff --git a/src/pages/kb/components/SelectFileModal.tsx b/src/pages/kb/components/SelectFileModal.tsx index 6706dfda9..fb103be73 100644 --- a/src/pages/kb/components/SelectFileModal.tsx +++ b/src/pages/kb/components/SelectFileModal.tsx @@ -55,9 +55,14 @@ const SelectFileModal = ({ const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true }); const [mode, setMode] = useState<`${TrainingModeEnum}`>(TrainingModeEnum.index); const [fileTextArr, setFileTextArr] = useState(['']); - const [splitRes, setSplitRes] = useState<{ tokens: number; chunks: string[] }>({ + const [splitRes, setSplitRes] = useState<{ + tokens: number; + chunks: string[]; + successChunks: number; + }>({ tokens: 0, - chunks: [] + chunks: [], + successChunks: 0 }); const { openConfirm, ConfirmChild } = useConfirm({ content: `确认导入该文件,需要一定时间进行拆解,该任务无法终止!如果余额不足,未完成的任务会被直接清除。一共 ${ @@ -104,19 +109,30 @@ const SelectFileModal = ({ [toast] ); - const { mutate, isLoading } = useMutation({ + const { mutate, isLoading: uploading } = useMutation({ mutationFn: async () => { if (splitRes.chunks.length === 0) return; - await postKbDataFromList({ - kbId, - data: splitRes.chunks.map((text) => ({ q: text, a: '' })), - prompt: `下面是"${prompt || '一段长文本'}"`, - mode - }); + // subsection import + let success = 0; + const step = 50; + for (let i = 0; i < splitRes.chunks.length; i += step) { + const { insertLen } = await postKbDataFromList({ + kbId, + data: splitRes.chunks.slice(i, i + step).map((text) => ({ q: text, a: '' })), + prompt: `下面是"${prompt || '一段长文本'}"`, + mode + }); + + success += insertLen; + setSplitRes((state) => ({ + ...state, + successChunks: state.successChunks + step + })); + } toast({ - title: '导入数据成功,需要一段拆解和训练. 重复数据会自动删除', + title: `去重后共导入 ${success} 条数据,需要一段拆解和训练.`, status: 'success' }); onClose(); @@ -148,7 +164,8 @@ const SelectFileModal = ({ setSplitRes({ tokens: splitRes.reduce((sum, item) => sum + item.tokens, 0), - chunks: splitRes.map((item) => item.chunks).flat() + chunks: splitRes.map((item) => item.chunks).flat(), + successChunks: 0 }); await promise; @@ -235,6 +252,11 @@ const SelectFileModal = ({ ...fileTextArr.slice(i + 1) ]); }} + onBlur={(e) => { + if (fileTextArr.length > 1 && e.target.value === '') { + setFileTextArr((state) => [...state.slice(0, i), ...state.slice(i + 1)]); + } + }} /> ))} @@ -242,19 +264,22 @@ const SelectFileModal = ({ - - diff --git a/src/service/errorCode.ts b/src/service/errorCode.ts index 2585108d7..7236c691b 100644 --- a/src/service/errorCode.ts +++ b/src/service/errorCode.ts @@ -24,10 +24,10 @@ export const openaiError: Record = { 'Bad Request': 'Bad Request~ 可能内容太多了', 'Bad Gateway': '网关异常,请重试' }; -export const openaiError2: Record = { - insufficient_quota: 'API 余额不足', - billing_not_active: 'openai 账号异常', - invalid_request_error: '无效的 openai 请求' +export const openaiAccountError: Record = { + // insufficient_quota: 'API 余额不足', + invalid_api_key: 'openai 账号异常' + // invalid_request_error: '无效的 openai 请求' }; export const proxyError: Record = { ECONNABORTED: true, diff --git a/src/service/events/generateQA.ts b/src/service/events/generateQA.ts index c239403bb..3c4ef752f 100644 --- a/src/service/events/generateQA.ts +++ b/src/service/events/generateQA.ts @@ -2,7 +2,7 @@ import { TrainingData } from '@/service/mongo'; import { getApiKey } from '../utils/auth'; import { OpenAiChatEnum } from '@/constants/model'; import { pushSplitDataBill } from '@/service/events/pushBill'; -import { openaiError2 } from '../errorCode'; +import { openaiAccountError } from '../errorCode'; import { modelServiceToolMap } from '../utils/chat'; import { ChatRoleEnum } from '@/constants/chat'; import { BillTypeEnum } from '@/constants/user'; @@ -81,8 +81,6 @@ export async function generateQA(): Promise { type: 'training' }); - console.log(`正在生成一组QA。ID: ${trainingId}`); - const startTime = Date.now(); // 请求 chatgpt 获取回答 @@ -137,7 +135,7 @@ A2: const responseList = response.map((item) => item.result).flat(); // 创建 向量生成 队列 - pushDataToKb({ + await pushDataToKb({ kbId, data: responseList, userId, @@ -161,8 +159,16 @@ A2: console.log('生成QA错误:', err); } - // openai 账号异常或者账号余额不足,删除任务 - if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) { + // message error or openai account error + if ( + err?.message === 'invalid message format' || + openaiAccountError[err?.response?.data?.error?.code] + ) { + await TrainingData.findByIdAndRemove(trainingId); + } + + // 账号余额不足,删除任务 + if (err === ERROR_ENUM.insufficientQuota) { console.log('余额不足,删除向量生成任务'); await TrainingData.deleteMany({ userId diff --git a/src/service/events/generateVector.ts b/src/service/events/generateVector.ts index 9c77581c0..ed390027d 100644 --- a/src/service/events/generateVector.ts +++ b/src/service/events/generateVector.ts @@ -1,4 +1,4 @@ -import { openaiError2 } from '../errorCode'; +import { openaiAccountError } from '../errorCode'; import { insertKbItem } from '@/service/pg'; import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding'; import { TrainingData } from '../models/trainingData'; @@ -111,8 +111,17 @@ export async function generateVector(): Promise { console.log('生成向量错误:', err); } - // openai 账号异常或者账号余额不足,删除任务 - if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) { + // message error or openai account error + if ( + err?.message === 'invalid message format' || + openaiAccountError[err?.response?.data?.error?.code] + ) { + console.log('删除一个任务'); + await TrainingData.findByIdAndRemove(trainingId); + } + + // 账号余额不足,删除任务 + if (err === ERROR_ENUM.insufficientQuota) { console.log('余额不足,删除向量生成任务'); await TrainingData.deleteMany({ userId diff --git a/src/service/events/pushBill.ts b/src/service/events/pushBill.ts index b63fbb589..d6b070077 100644 --- a/src/service/events/pushBill.ts +++ b/src/service/events/pushBill.ts @@ -134,9 +134,9 @@ export const pushGenerateVectorBill = async ({ text: string; tokenLen: number; }) => { - console.log( - `vector generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}` - ); + // console.log( + // `vector generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}` + // ); if (!isPay) return; let billId; diff --git a/src/service/pg.ts b/src/service/pg.ts index 5f0b42f2f..2ac25a193 100644 --- a/src/service/pg.ts +++ b/src/service/pg.ts @@ -177,8 +177,8 @@ export const insertKbItem = ({ values: data.map((item) => [ { key: 'user_id', value: userId }, { key: 'kb_id', value: kbId }, - { key: 'q', value: item.q }, - { key: 'a', value: item.a }, + { key: 'q', value: item.q.replace(/'/g, '"') }, + { key: 'a', value: item.a.replace(/'/g, '"') }, { key: 'vector', value: `[${item.vector}]` } ]) }); diff --git a/src/service/response.ts b/src/service/response.ts index 2dce57220..e57ab5dd1 100644 --- a/src/service/response.ts +++ b/src/service/response.ts @@ -1,5 +1,11 @@ import { NextApiResponse } from 'next'; -import { openaiError, openaiError2, proxyError, ERROR_RESPONSE, ERROR_ENUM } from './errorCode'; +import { + openaiError, + openaiAccountError, + proxyError, + ERROR_RESPONSE, + ERROR_ENUM +} from './errorCode'; import { clearCookie } from './utils/tools'; export interface ResponseType { @@ -40,8 +46,8 @@ export const jsonRes = ( msg = '接口连接异常'; } else if (error?.response?.data?.error?.message) { msg = error?.response?.data?.error?.message; - } else if (openaiError2[error?.response?.data?.error?.type]) { - msg = openaiError2[error?.response?.data?.error?.type]; + } else if (openaiAccountError[error?.response?.data?.error?.code]) { + msg = openaiAccountError[error?.response?.data?.error?.code]; } else if (openaiError[error?.response?.statusText]) { msg = openaiError[error.response.statusText]; }