perf: generate queue

archer
2023-05-27 04:38:00 +08:00
parent f05b12975c
commit 741381ecb0
19 changed files with 288 additions and 265 deletions

View File

@@ -2,7 +2,7 @@ import { GET, POST, PUT, DELETE } from '../request';
 import type { KbItemType } from '@/types/plugin';
 import { RequestPaging } from '@/types/index';
 import { TrainingTypeEnum } from '@/constants/plugin';
-import { KbDataItemType } from '@/types/plugin';
+import { Props as PushDataProps } from '@/pages/api/openapi/kb/pushData';
 
 export type KbUpdateParams = { id: string; name: string; tags: string; avatar: string };
@@ -46,10 +46,7 @@ export const getKbDataItemById = (dataId: string) =>
 /**
  * Push data directly
  */
-export const postKbDataFromList = (data: {
-  kbId: string;
-  data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[];
-}) => POST(`/openapi/kb/pushData`, data);
+export const postKbDataFromList = (data: PushDataProps) => POST(`/openapi/kb/pushData`, data);
 
 /**
  * Update a single record
  */
@@ -70,4 +67,4 @@ export const postSplitData = (data: {
   chunks: string[];
   prompt: string;
   mode: `${TrainingTypeEnum}`;
-}) => POST(`/openapi/text/splitText`, data);
+}) => POST(`/openapi/text/pushData`, data);

View File

@@ -1,4 +1,8 @@
 export enum TrainingTypeEnum {
   'qa' = 'qa',
-  'subsection' = 'subsection'
+  'index' = 'index'
 }
+export const TrainingTypeMap = {
+  [TrainingTypeEnum.qa]: 'qa',
+  [TrainingTypeEnum.index]: 'index'
+};
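
For context, the template-literal type `${TrainingTypeEnum}` used throughout this commit accepts either an enum member or its raw string value; a minimal TypeScript sketch (illustrative variable names only):

import { TrainingTypeEnum } from '@/constants/plugin';

type TrainingMode = `${TrainingTypeEnum}`; // resolves to 'qa' | 'index'

const fromEnum: TrainingMode = TrainingTypeEnum.index; // enum member is accepted
const fromString: TrainingMode = 'qa'; // the plain string literal is accepted too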

View File

@@ -0,0 +1,37 @@
+// Next.js API route support: https://nextjs.org/docs/api-routes/introduction
+import type { NextApiRequest, NextApiResponse } from 'next';
+import { jsonRes } from '@/service/response';
+import { authUser } from '@/service/utils/auth';
+import { connectToDatabase, TrainingData } from '@/service/mongo';
+import { TrainingTypeEnum } from '@/constants/plugin';
+
+export default async function handler(req: NextApiRequest, res: NextApiResponse) {
+  try {
+    await authUser({ req, authRoot: true });
+    await connectToDatabase();
+
+    // split queue data
+    const result = await TrainingData.aggregate([
+      {
+        $group: {
+          _id: '$mode',
+          count: { $sum: 1 }
+        }
+      }
+    ]);
+
+    jsonRes(res, {
+      data: {
+        qaListLen: result.find((item) => item._id === TrainingTypeEnum.qa)?.count || 0,
+        vectorListLen: result.find((item) => item._id === TrainingTypeEnum.index)?.count || 0
+      }
+    });
+  } catch (error) {
+    console.log(error);
+    jsonRes(res, {
+      code: 500,
+      error
+    });
+  }
+}
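
For context, a minimal client-side sketch of reading the counts this handler reports. The route path is an assumption (the page does not show the new file's name); only the qaListLen / vectorListLen fields and the data envelope come from the jsonRes call above.

type QueueLenResponse = { data?: { qaListLen: number; vectorListLen: number } };

// hypothetical path: adjust to wherever this handler file actually lives under pages/api
export async function fetchTrainingQueueLen(): Promise<QueueLenResponse> {
  const res = await fetch('/api/admin/trainingQueueLen');
  return res.json();
}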

View File

@@ -3,19 +3,21 @@ import type { KbDataItemType } from '@/types/plugin';
 import { jsonRes } from '@/service/response';
 import { connectToDatabase, TrainingData } from '@/service/mongo';
 import { authUser } from '@/service/utils/auth';
-import { generateVector } from '@/service/events/generateVector';
-import { PgClient } from '@/service/pg';
 import { authKb } from '@/service/utils/auth';
 import { withNextCors } from '@/service/utils/tools';
+import { TrainingTypeEnum } from '@/constants/plugin';
+import { startQueue } from '@/service/utils/tools';
 
-interface Props {
+export type Props = {
   kbId: string;
   data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[];
-}
+  mode: `${TrainingTypeEnum}`;
+  prompt?: string;
+};
 
 export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
   try {
-    const { kbId, data } = req.body as Props;
+    const { kbId, data, mode, prompt } = req.body as Props;
 
     if (!kbId || !Array.isArray(data)) {
       throw new Error('缺少参数');
@@ -29,7 +31,9 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
       data: await pushDataToKb({
         kbId,
         data,
-        userId
+        userId,
+        mode,
+        prompt
       })
     });
   } catch (err) {
@@ -40,36 +44,43 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
   }
 });
 
-export async function pushDataToKb({ userId, kbId, data }: { userId: string } & Props) {
+export async function pushDataToKb({
+  userId,
+  kbId,
+  data,
+  mode,
+  prompt
+}: { userId: string } & Props) {
   await authKb({
     userId,
     kbId
   });
 
   if (data.length === 0) {
-    return {
-      trainingId: ''
-    };
+    return {};
   }
 
   // insert records
-  const { _id } = await TrainingData.create({
-    userId,
-    kbId,
-    vectorList: data
-  });
+  await TrainingData.insertMany(
+    data.map((item) => ({
+      q: item.q,
+      a: item.a,
+      userId,
+      kbId,
+      mode,
+      prompt
+    }))
+  );
 
-  generateVector(_id);
+  startQueue();
 
-  return {
-    trainingId: _id
-  };
+  return {};
 }
 
 export const config = {
   api: {
     bodyParser: {
-      sizeLimit: '100mb'
+      sizeLimit: '20mb'
     }
   }
 };
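
With Props exported above, a call through the updated client wrapper (postKbDataFromList from the first file in this commit) looks roughly like this; the helper name and sample values are illustrative only.

import { postKbDataFromList } from '@/api/plugins/kb';
import { TrainingTypeEnum } from '@/constants/plugin';

// push raw chunks into the QA queue: one TrainingData row per chunk; mode decides which worker picks it up
export async function pushChunksForQA(kbId: string, chunks: string[]) {
  return postKbDataFromList({
    kbId,
    mode: TrainingTypeEnum.qa,
    prompt: '下面是"一段长文本"',
    data: chunks.map((q) => ({ q, a: '' }))
  });
}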

View File

@@ -33,7 +33,15 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
     // update pg content; modifying only "a" does not require regenerating the vector
     await PgClient.update('modelData', {
       where: [['id', dataId], 'AND', ['user_id', userId]],
-      values: [{ key: 'a', value: a }, ...(q ? [{ key: 'q', value: `${vector[0]}` }] : [])]
+      values: [
+        { key: 'a', value: a },
+        ...(q
+          ? [
+              { key: 'q', value: q },
+              { key: 'vector', value: `[${vector[0]}]` }
+            ]
+          : [])
+      ]
     });
 
     jsonRes(res);

View File

@@ -1,69 +0,0 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, TrainingData } from '@/service/mongo';
import { authKb, authUser } from '@/service/utils/auth';
import { generateQA } from '@/service/events/generateQA';
import { TrainingTypeEnum } from '@/constants/plugin';
import { withNextCors } from '@/service/utils/tools';
import { pushDataToKb } from '../kb/pushData';
/* split text */
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
const { chunks, kbId, prompt, mode } = req.body as {
kbId: string;
chunks: string[];
prompt: string;
mode: `${TrainingTypeEnum}`;
};
if (!chunks || !kbId || !prompt) {
throw new Error('参数错误');
}
await connectToDatabase();
const { userId } = await authUser({ req });
// 验证是否是该用户的 model
await authKb({
kbId,
userId
});
if (mode === TrainingTypeEnum.qa) {
// 批量QA拆分插入数据
const { _id } = await TrainingData.create({
userId,
kbId,
qaList: chunks,
prompt
});
generateQA(_id);
} else if (mode === TrainingTypeEnum.subsection) {
// 分段导入,直接插入向量队列
const response = await pushDataToKb({
kbId,
data: chunks.map((item) => ({ q: item, a: '' })),
userId
});
return jsonRes(res, {
data: response
});
}
jsonRes(res);
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
});
export const config = {
api: {
bodyParser: {
sizeLimit: '100mb'
}
}
};

View File

@@ -2,9 +2,10 @@ import type { NextApiRequest, NextApiResponse } from 'next';
 import { jsonRes } from '@/service/response';
 import { connectToDatabase, TrainingData } from '@/service/mongo';
 import { authUser } from '@/service/utils/auth';
-import { Types } from 'mongoose';
 import { generateQA } from '@/service/events/generateQA';
 import { generateVector } from '@/service/events/generateVector';
+import { TrainingTypeEnum } from '@/constants/plugin';
+import { Types } from 'mongoose';
 
 /* split data into QA */
 export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -19,26 +20,24 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     // split queue data
     const result = await TrainingData.aggregate([
-      { $match: { userId: new Types.ObjectId(userId), kbId: new Types.ObjectId(kbId) } },
       {
-        $project: {
-          qaListLength: { $size: { $ifNull: ['$qaList', []] } },
-          vectorListLength: { $size: { $ifNull: ['$vectorList', []] } }
+        $match: {
+          userId: new Types.ObjectId(userId),
+          kbId: new Types.ObjectId(kbId)
         }
       },
       {
         $group: {
-          _id: null,
-          totalQaListLength: { $sum: '$qaListLength' },
-          totalVectorListLength: { $sum: '$vectorListLength' }
+          _id: '$mode',
+          count: { $sum: 1 }
         }
       }
     ]);
 
     jsonRes(res, {
       data: {
-        qaListLen: result[0]?.totalQaListLength || 0,
-        vectorListLen: result[0]?.totalVectorListLength || 0
+        qaListLen: result.find((item) => item._id === TrainingTypeEnum.qa)?.count || 0,
+        vectorListLen: result.find((item) => item._id === TrainingTypeEnum.index)?.count || 0
       }
     });
@@ -49,10 +48,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
         kbId
       },
       '_id'
-    );
+    ).limit(10);
 
     list.forEach((item) => {
-      generateQA(item._id);
-      generateVector(item._id);
+      generateQA();
+      generateVector();
     });
   }
 } catch (err) {

View File

@@ -13,6 +13,7 @@ import {
 import { useForm } from 'react-hook-form';
 import { postKbDataFromList, putKbDataById } from '@/api/plugins/kb';
 import { useToast } from '@/hooks/useToast';
+import { TrainingTypeEnum } from '@/constants/plugin';
 
 export type FormData = { dataId?: string; a: string; q: string };
@@ -59,7 +60,8 @@ const InputDataModal = ({
           a: e.a,
           q: e.q
         }
-      ]
+      ],
+      mode: TrainingTypeEnum.index
     });
 
     toast({

View File

@@ -19,6 +19,7 @@ import { postKbDataFromList } from '@/api/plugins/kb';
 import Markdown from '@/components/Markdown';
 import { useMarkdown } from '@/hooks/useMarkdown';
 import { fileDownload } from '@/utils/file';
+import { TrainingTypeEnum } from '@/constants/plugin';
 
 const csvTemplate = `question,answer\n"什么是 laf","laf 是一个云函数开发平台……"\n"什么是 sealos","Sealos 是以 kubernetes 为内核的云操作系统发行版,可以……"`;
@@ -72,7 +73,8 @@ const SelectJsonModal = ({
       const res = await postKbDataFromList({
         kbId,
-        data: fileData
+        data: fileData,
+        mode: TrainingTypeEnum.index
       });
 
       toast({

View File

@@ -17,7 +17,7 @@ import { useSelectFile } from '@/hooks/useSelectFile';
 import { useConfirm } from '@/hooks/useConfirm';
 import { readTxtContent, readPdfContent, readDocContent } from '@/utils/file';
 import { useMutation } from '@tanstack/react-query';
-import { postSplitData } from '@/api/plugins/kb';
+import { postKbDataFromList } from '@/api/plugins/kb';
 import Radio from '@/components/Radio';
 import { splitText_token } from '@/utils/file';
 import { TrainingTypeEnum } from '@/constants/plugin';
@@ -32,7 +32,7 @@ const modeMap = {
     price: 4,
     isPrompt: true
   },
-  subsection: {
+  index: {
     maxLen: 800,
     slideLen: 300,
     price: 0.4,
@@ -53,7 +53,7 @@ const SelectFileModal = ({
   const { toast } = useToast();
   const [prompt, setPrompt] = useState('');
   const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
-  const [mode, setMode] = useState<`${TrainingTypeEnum}`>(TrainingTypeEnum.subsection);
+  const [mode, setMode] = useState<`${TrainingTypeEnum}`>(TrainingTypeEnum.index);
   const [fileTextArr, setFileTextArr] = useState<string[]>(['']);
   const [splitRes, setSplitRes] = useState<{ tokens: number; chunks: string[] }>({
     tokens: 0,
@@ -108,9 +108,9 @@ const SelectFileModal = ({
     mutationFn: async () => {
       if (splitRes.chunks.length === 0) return;
-      await postSplitData({
+      await postKbDataFromList({
         kbId,
-        chunks: splitRes.chunks,
+        data: splitRes.chunks.map((text) => ({ q: text, a: '' })),
         prompt: `下面是"${prompt || '一段长文本'}"`,
         mode
       });
@@ -195,11 +195,11 @@ const SelectFileModal = ({
           <Radio
             ml={3}
             list={[
-              { label: '直接分段', value: 'subsection' },
+              { label: '直接分段', value: 'index' },
               { label: 'QA拆分', value: 'qa' }
             ]}
             value={mode}
-            onChange={(e) => setMode(e as 'subsection' | 'qa')}
+            onChange={(e) => setMode(e as 'index' | 'qa')}
           />
         </Flex>
         {/* content introduction */}

View File

@@ -7,49 +7,61 @@ import { modelServiceToolMap } from '../utils/chat';
 import { ChatRoleEnum } from '@/constants/chat';
 import { BillTypeEnum } from '@/constants/user';
 import { pushDataToKb } from '@/pages/api/openapi/kb/pushData';
+import { TrainingTypeEnum } from '@/constants/plugin';
 import { ERROR_ENUM } from '../errorCode';
 
-// pick at most 1 group per run
-const listLen = 1;
-
-export async function generateQA(trainingId: string): Promise<any> {
+export async function generateQA(): Promise<any> {
+  const maxProcess = Number(process.env.QA_MAX_PROCESS || 10);
+  if (global.qaQueueLen >= maxProcess) return;
+  global.qaQueueLen++;
+  let trainingId = '';
+  let userId = '';
+
   try {
     // find one dataItem that needs generating (4-minute lock)
     const data = await TrainingData.findOneAndUpdate(
       {
-        _id: trainingId,
-        lockTime: { $lte: Date.now() - 4 * 60 * 1000 }
+        mode: TrainingTypeEnum.qa,
+        lockTime: { $lte: new Date(Date.now() - 2 * 60 * 1000) }
       },
       {
         lockTime: new Date()
       }
-    );
+    ).select({
+      _id: 1,
+      userId: 1,
+      kbId: 1,
+      prompt: 1,
+      q: 1
+    });
 
-    if (!data || data.qaList.length === 0) {
-      await TrainingData.findOneAndDelete({
-        _id: trainingId,
-        qaList: [],
-        vectorList: []
-      });
+    /* no pending task */
+    if (!data) {
+      global.qaQueueLen--;
+      !global.qaQueueLen && console.log(`没有需要【QA】的数据`);
       return;
     }
 
-    const qaList: string[] = data.qaList.slice(-listLen);
+    trainingId = data._id;
+    userId = String(data.userId);
+    const kbId = String(data.kbId);
 
     // check balance and fetch the openapi key
     const { userOpenAiKey, systemAuthKey } = await getApiKey({
       model: OpenAiChatEnum.GPT35,
-      userId: data.userId,
+      userId,
       type: 'training'
     });
 
-    console.log(`正在生成一组QA, 包含 ${qaList.length} 组文本。ID: ${data._id}`);
+    console.log(`正在生成一组QA。ID: ${trainingId}`);
 
     const startTime = Date.now();
 
     // request chatgpt for the answers
     const response = await Promise.all(
-      qaList.map((text) =>
+      [data.q].map((text) =>
         modelServiceToolMap[OpenAiChatEnum.GPT35]
           .chatCompletion({
             apiKey: userOpenAiKey || systemAuthKey,
@@ -100,24 +112,19 @@ A2:
     // create the vector-generation tasks
     pushDataToKb({
-      kbId: data.kbId,
+      kbId,
       data: responseList,
-      userId: data.userId
+      userId,
+      mode: TrainingTypeEnum.index
     });
 
-    // delete the QA queue: if no more than n items remain, delete the whole document; otherwise drop only the last n
-    if (data.vectorList.length <= listLen) {
-      await TrainingData.findByIdAndDelete(data._id);
-    } else {
-      await TrainingData.findByIdAndUpdate(data._id, {
-        qaList: data.qaList.slice(0, -listLen),
-        lockTime: new Date('2000/1/1')
-      });
-    }
+    // delete data from training
+    await TrainingData.findByIdAndDelete(data._id);
 
     console.log('生成QA成功time:', `${(Date.now() - startTime) / 1000}s`);
 
-    generateQA(trainingId);
+    global.qaQueueLen--;
+    generateQA();
   } catch (err: any) {
     // log
     if (err?.response) {
@@ -130,25 +137,28 @@ A2:
     // openai account error or insufficient balance: delete the task
     if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) {
       console.log('余额不足,删除向量生成任务');
-      await TrainingData.findByIdAndDelete(trainingId);
-      return;
+      await TrainingData.deleteMany({
+        userId
+      });
+      return generateQA();
     }
 
     // unlock
+    global.qaQueueLen--;
     await TrainingData.findByIdAndUpdate(trainingId, {
       lockTime: new Date('2000/1/1')
     });
 
     // rate limit
     if (err?.response?.statusText === 'Too Many Requests') {
-      console.log('生成向量次数限制,30s后尝试');
+      console.log('生成向量次数限制,20s后尝试');
       return setTimeout(() => {
-        generateQA(trainingId);
-      }, 30000);
+        generateQA();
+      }, 20000);
     }
 
     setTimeout(() => {
-      generateQA(trainingId);
+      generateQA();
     }, 1000);
   }
 }
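
Both workers introduced in this commit share one pattern: a global in-flight counter capped by an env-driven limit, one document claimed per pass through an atomic findOneAndUpdate lock, and a recursive call to claim the next one when the current pass finishes. A stripped-down sketch of that pattern (generic names, not the project code):

let inFlight = 0; // stands in for global.qaQueueLen / global.vectorQueueLen
const MAX = Number(process.env.QA_MAX_PROCESS || 10);

async function worker(claimAndProcessOne: () => Promise<boolean>): Promise<void> {
  if (inFlight >= MAX) return; // respect the concurrency cap
  inFlight++;
  try {
    const hadWork = await claimAndProcessOne(); // lock one row, process it, delete it
    inFlight--;
    if (hadWork) void worker(claimAndProcessOne); // immediately try the next row
  } catch (err) {
    inFlight--;
    setTimeout(() => void worker(claimAndProcessOne), 1000); // back off briefly, as the code above does
  }
}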

View File

@@ -3,104 +3,109 @@ import { insertKbItem, PgClient } from '@/service/pg';
 import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding';
 import { TrainingData } from '../models/trainingData';
 import { ERROR_ENUM } from '../errorCode';
+import { TrainingTypeEnum } from '@/constants/plugin';
 
-// pick at most 5 groups per run
-const listLen = 5;
-
 /* index generation queue; each import runs as its own worker */
-export async function generateVector(trainingId: string): Promise<any> {
+export async function generateVector(): Promise<any> {
+  const maxProcess = Number(process.env.VECTOR_MAX_PROCESS || 10);
+  if (global.vectorQueueLen >= maxProcess) return;
+  global.vectorQueueLen++;
+  let trainingId = '';
+  let userId = '';
+
   try {
+    // find one dataItem that needs generating (2-minute lock)
     const data = await TrainingData.findOneAndUpdate(
       {
-        _id: trainingId,
-        lockTime: { $lte: Date.now() - 2 * 60 * 1000 }
+        mode: TrainingTypeEnum.index,
+        lockTime: { $lte: new Date(Date.now() - 2 * 60 * 1000) }
       },
       {
         lockTime: new Date()
       }
-    );
+    ).select({
+      _id: 1,
+      userId: 1,
+      kbId: 1,
+      q: 1,
+      a: 1
+    });
 
+    /* no pending task */
     if (!data) {
-      await TrainingData.findOneAndDelete({
-        _id: trainingId,
-        qaList: [],
-        vectorList: []
-      });
+      global.vectorQueueLen--;
+      !global.vectorQueueLen && console.log(`没有需要【索引】的数据`);
       return;
     }
 
-    const userId = String(data.userId);
+    trainingId = data._id;
+    userId = String(data.userId);
     const kbId = String(data.kbId);
 
-    const dataItems: { q: string; a: string }[] = data.vectorList.slice(-listLen).map((item) => ({
-      q: item.q,
-      a: item.a
-    }));
+    const dataItems = [
+      {
+        q: data.q,
+        a: data.a
+      }
+    ];
 
     // filter duplicate qa content
-    const searchRes = await Promise.allSettled(
-      dataItems.map(async ({ q, a = '' }) => {
-        if (!q) {
-          return Promise.reject('q为空');
-        }
-        q = q.replace(/\\n/g, '\n');
-        a = a.replace(/\\n/g, '\n');
-        // Exactly the same data, not push
-        try {
-          const count = await PgClient.count('modelData', {
-            where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]]
-          });
-          if (count > 0) {
-            return Promise.reject('已经存在');
-          }
-        } catch (error) {
-          error;
-        }
-        return Promise.resolve({
-          q,
-          a
-        });
-      })
-    );
-    const filterData = searchRes
-      .filter((item) => item.status === 'fulfilled')
-      .map<{ q: string; a: string }>((item: any) => item.value);
-
-    if (filterData.length > 0) {
-      // generate embeddings
-      const vectors = await openaiEmbedding({
-        input: filterData.map((item) => item.q),
-        userId,
-        type: 'training'
-      });
-
-      // insert the results into pg
-      await insertKbItem({
-        userId,
-        kbId,
-        data: vectors.map((vector, i) => ({
-          q: filterData[i].q,
-          a: filterData[i].a,
-          vector
-        }))
-      });
-    }
-
-    // remove from the mongo training queue: if no more than n items remain, delete the whole document; otherwise drop only the last n
-    if (data.vectorList.length <= listLen) {
-      await TrainingData.findByIdAndDelete(trainingId);
-      console.log(`全部向量生成完毕: ${trainingId}`);
-    } else {
-      await TrainingData.findByIdAndUpdate(trainingId, {
-        vectorList: data.vectorList.slice(0, -listLen),
-        lockTime: new Date('2000/1/1')
-      });
-      console.log(`生成向量成功: ${trainingId}`);
-      generateVector(trainingId);
-    }
+    // const searchRes = await Promise.allSettled(
+    //   dataItems.map(async ({ q, a = '' }) => {
+    //     if (!q) {
+    //       return Promise.reject('q为空');
+    //     }
+    //     q = q.replace(/\\n/g, '\n');
+    //     a = a.replace(/\\n/g, '\n');
+    //     // Exactly the same data, not push
+    //     try {
+    //       const count = await PgClient.count('modelData', {
+    //         where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]]
+    //       });
+    //       if (count > 0) {
+    //         return Promise.reject('已经存在');
+    //       }
+    //     } catch (error) {
+    //       error;
+    //     }
+    //     return Promise.resolve({
+    //       q,
+    //       a
+    //     });
+    //   })
+    // );
+    // const filterData = searchRes
+    //   .filter((item) => item.status === 'fulfilled')
+    //   .map<{ q: string; a: string }>((item: any) => item.value);
+
+    // generate embeddings
+    const vectors = await openaiEmbedding({
+      input: dataItems.map((item) => item.q),
+      userId,
+      type: 'training'
+    });
+
+    // insert the results into pg
+    await insertKbItem({
+      userId,
+      kbId,
+      data: vectors.map((vector, i) => ({
+        q: dataItems[i].q,
+        a: dataItems[i].a,
+        vector
+      }))
+    });
+
+    // delete data from training
+    await TrainingData.findByIdAndDelete(data._id);
+    console.log(`生成向量成功: ${data._id}`);
+    global.vectorQueueLen--;
+    generateVector();
   } catch (err: any) {
     // log
     if (err?.response) {
@@ -113,25 +118,28 @@ export async function generateVector(trainingId: string): Promise<any> {
     // openai account error or insufficient balance: delete the task
     if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) {
       console.log('余额不足,删除向量生成任务');
-      await TrainingData.findByIdAndDelete(trainingId);
-      return;
+      await TrainingData.deleteMany({
+        userId
+      });
+      return generateVector();
     }
 
     // unlock
+    global.vectorQueueLen--;
     await TrainingData.findByIdAndUpdate(trainingId, {
       lockTime: new Date('2000/1/1')
    });
 
     // rate limit
     if (err?.response?.statusText === 'Too Many Requests') {
-      console.log('生成向量次数限制,30s后尝试');
+      console.log('生成向量次数限制,20s后尝试');
       return setTimeout(() => {
-        generateVector(trainingId);
-      }, 30000);
+        generateVector();
+      }, 20000);
     }
 
     setTimeout(() => {
-      generateVector(trainingId);
+      generateVector();
     }, 1000);
   }
 }

View File

@@ -1,9 +1,9 @@
 /* the model's knowledge base */
 import { Schema, model, models, Model as MongoModel } from 'mongoose';
 import { TrainingDataSchema as TrainingDateType } from '@/types/mongoSchema';
+import { TrainingTypeMap } from '@/constants/plugin';
 
-// pgList and vectorList, Only one of them will work
 const TrainingDataSchema = new Schema({
   userId: {
     type: Schema.Types.ObjectId,
@@ -19,18 +19,27 @@ const TrainingDataSchema = new Schema({
     type: Date,
     default: () => new Date('2000/1/1')
   },
-  vectorList: {
-    type: [{ q: String, a: String }],
-    default: []
+  mode: {
+    type: String,
+    enum: Object.keys(TrainingTypeMap),
+    required: true
   },
   prompt: {
     // prompt used when splitting
     type: String,
     default: ''
   },
-  qaList: {
-    type: [String],
-    default: []
+  q: {
+    // 如果是
+    type: String,
+    default: ''
+  },
+  a: {
+    type: String,
+    default: ''
+  },
+  vectorList: {
+    type: Object
   }
 });
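
Each queued q/a pair is now its own TrainingData document rather than an element of a per-import array. An illustrative document, typed loosely against the updated schema (all values made up):

import type { TrainingDataSchema } from '@/types/mongoSchema';

const example: Partial<TrainingDataSchema> = {
  userId: '64717f0b8a9e8f0f7e6b1a2b', // hypothetical ids
  kbId: '64717f0b8a9e8f0f7e6b1a2c',
  lockTime: new Date('2000/1/1'), // unlocked and ready to be claimed
  mode: 'qa', // or 'index'
  prompt: '下面是"一段长文本"',
  q: 'What is laf?',
  a: ''
};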

View File

@@ -1,8 +1,7 @@
 import mongoose from 'mongoose';
-import { generateQA } from './events/generateQA';
-import { generateVector } from './events/generateVector';
 import tunnel from 'tunnel';
 import { TrainingData } from './mongo';
+import { startQueue } from './utils/tools';
 
 /**
  * Connect to the MongoDB database
@@ -38,7 +37,10 @@ export async function connectToDatabase(): Promise<void> {
     });
   }
 
-  startTrain();
+  global.qaQueueLen = 0;
+  global.vectorQueueLen = 0;
+  startQueue();
 
   // after 5 minutes, unlock stuck records and restart training
   setTimeout(async () => {
     await TrainingData.updateMany(
@@ -49,24 +51,10 @@ export async function connectToDatabase(): Promise<void> {
         lockTime: new Date('2000/1/1')
       }
     );
-    startTrain();
+    startQueue();
   }, 5 * 60 * 1000);
 }
 
-async function startTrain() {
-  const qa = await TrainingData.find({
-    qaList: { $exists: true, $ne: [] }
-  });
-  qa.map((item) => generateQA(String(item._id)));
-
-  const vector = await TrainingData.find({
-    vectorList: { $exists: true, $ne: [] }
-  });
-  vector.map((item) => generateVector(String(item._id)));
-}
-
 export * from './models/authCode';
 export * from './models/chat';
 export * from './models/model';

View File

@@ -14,8 +14,8 @@ export const connectPg = async () => {
       password: process.env.PG_PASSWORD,
       database: process.env.PG_DB_NAME,
       max: 20,
-      idleTimeoutMillis: 30000,
-      connectionTimeoutMillis: 2000
+      idleTimeoutMillis: 60000,
+      connectionTimeoutMillis: 20000
     });
 
     global.pgClient.on('error', (err) => {

View File

@@ -45,7 +45,7 @@ export const jsonRes = <T = any>(
     } else if (openaiError[error?.response?.statusText]) {
       msg = openaiError[error.response.statusText];
     }
-    console.log(error);
+    console.log(error?.message || error);
   }
 
   res.json({

View File

@@ -2,6 +2,8 @@ import type { NextApiResponse, NextApiHandler, NextApiRequest } from 'next';
 import NextCors from 'nextjs-cors';
 import crypto from 'crypto';
 import jwt from 'jsonwebtoken';
+import { generateQA } from '../events/generateQA';
+import { generateVector } from '../events/generateVector';
 
 /* password hashing */
 export const hashPassword = (psw: string) => {
@@ -45,7 +47,7 @@ export function withNextCors(handler: NextApiHandler): NextApiHandler {
     req: NextApiRequest,
     res: NextApiResponse
   ) {
-    const methods = ['GET', 'HEAD', 'PUT', 'PATCH', 'POST', 'DELETE'];
+    const methods = ['GET', 'eHEAD', 'PUT', 'PATCH', 'POST', 'DELETE'];
     const origin = req.headers.origin;
 
     await NextCors(req, res, {
       methods,
@@ -56,3 +58,15 @@ export function withNextCors(handler: NextApiHandler): NextApiHandler {
     return handler(req, res);
   };
 }
+
+export const startQueue = () => {
+  const qaMax = Number(process.env.QA_MAX_PROCESS || 10);
+  const vectorMax = Number(process.env.VECTOR_MAX_PROCESS || 10);
+
+  for (let i = 0; i < qaMax; i++) {
+    generateQA();
+  }
+  for (let i = 0; i < vectorMax; i++) {
+    generateVector();
+  }
+};
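
startQueue fans out up to QA_MAX_PROCESS + VECTOR_MAX_PROCESS worker calls; anything beyond the cap returns immediately because each worker checks its global counter first, so calling it repeatedly (as pushData now does) stays cheap. Both knobs are plain environment variables and default to 10 when unset, e.g.:

QA_MAX_PROCESS=10
VECTOR_MAX_PROCESS=10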

View File

@@ -9,6 +9,8 @@ declare global {
   var particlesJS: any;
   var grecaptcha: any;
   var QRCode: any;
+  var qaQueueLen: number;
+  var vectorQueueLen: number;
 
   interface Window {
     ['pdfjs-dist/build/pdf']: any;

View File

@@ -74,9 +74,10 @@ export interface TrainingDataSchema {
   userId: string;
   kbId: string;
   lockTime: Date;
-  vectorList: { q: string; a: string }[];
+  mode: `${TrainingTypeEnum}`;
   prompt: string;
-  qaList: string[];
+  q: string;
+  a: string;
 }
 
 export interface ChatSchema {