Files
FastGPT/projects/app/src/service/events/generateQA.ts
2024-01-10 23:35:04 +08:00

259 lines
7.0 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { MongoDatasetTraining } from '@fastgpt/service/core/dataset/training/schema';
import { pushQABill } from '@/service/support/wallet/bill/push';
import { DatasetDataIndexTypeEnum, TrainingModeEnum } from '@fastgpt/global/core/dataset/constant';
import { sendOneInform } from '../support/user/inform/api';
import { getAIApi } from '@fastgpt/service/core/ai/config';
import type { ChatMessageItemType } from '@fastgpt/global/core/ai/type.d';
import { addLog } from '@fastgpt/service/common/system/log';
import { splitText2Chunks } from '@fastgpt/global/common/string/textSplitter';
import { replaceVariable } from '@fastgpt/global/common/string/tools';
import { Prompt_AgentQA } from '@/global/core/prompt/agent';
import { getErrText } from '@fastgpt/global/common/error/utils';
import { authTeamBalance } from '../support/permission/auth/bill';
import type { PushDatasetDataChunkProps } from '@fastgpt/global/core/dataset/api.d';
import { UserErrEnum } from '@fastgpt/global/common/error/code/user';
import { lockTrainingDataByTeamId } from '@fastgpt/service/core/dataset/training/controller';
import { pushDataToDatasetCollection } from '@/service/core/dataset/data/controller';
/**
 * Release one slot of the QA worker counter.
 *
 * Decrements `global.qaQueueLen`, clamping it at zero so it can never go negative.
 *
 * @returns `true` when the QA counter is zero after the decrement, i.e. no QA
 *          worker slot is still accounted for.
 */
const reduceQueue = () => {
  global.qaQueueLen = global.qaQueueLen > 0 ? global.qaQueueLen - 1 : 0;

  // Report on the counter this function maintains. The previous code read
  // `global.vectorQueueLen`, a different counter, so the result did not
  // reflect the state of the QA queue.
  return global.qaQueueLen === 0;
};
/**
 * Run one QA-generation worker step, then chain to the next one.
 *
 * Steps, in order:
 *  1. Respect the concurrency cap (`global.systemEnv.qaMaxProcess`) and take a slot.
 *  2. Claim one QA-mode training record whose `lockTime` is at least six minutes
 *     old by stamping it with the current time.
 *  3. Check the team balance; when it is insufficient, notify the member, lock the
 *     team's training data and move on to the next record.
 *  4. Ask the chat model for QA pairs, parse them, push them to the dataset
 *     collection, delete the training record and push the bill.
 *  5. On failure, log it; park records that fail with a request/format error by
 *     moving their `lockTime` far into the future, otherwise retry after one second.
 *
 * Every exit path releases the slot through `reduceQueue()`.
 *
 * @returns A promise whose resolved value carries no meaning.
 */
export async function generateQA(): Promise<any> {
  // Concurrency cap: do nothing when enough workers are already running.
  if (global.qaQueueLen >= global.systemEnv.qaMaxProcess) return;
  global.qaQueueLen++;

  // get training data
  const {
    data,
    text,
    done = false,
    error = false
  } = await (async () => {
    try {
      // Atomically claim a record: only records whose lock is older than six
      // minutes match, and the update refreshes the lock to "now".
      const data = await MongoDatasetTraining.findOneAndUpdate(
        {
          mode: TrainingModeEnum.qa,
          lockTime: { $lte: new Date(Date.now() - 6 * 60 * 1000) }
        },
        {
          lockTime: new Date()
        }
      )
        .select({
          _id: 1,
          userId: 1,
          teamId: 1,
          tmbId: 1,
          datasetId: 1,
          collectionId: 1,
          q: 1,
          model: 1,
          chunkIndex: 1,
          billId: 1,
          prompt: 1
        })
        .lean();

      // task preemption: nothing left to claim
      if (!data) {
        return {
          done: true
        };
      }
      return {
        data,
        text: data.q
      };
    } catch (error) {
      console.log(`Get Training Data error`, error);
      return {
        error: true
      };
    }
  })();

  // A failed lookup also leaves `data` undefined, so it must be handled before
  // the `!data` check below; otherwise the retry would never be reached.
  // Retry after a short pause (same delay as the retry at the end of this
  // function) instead of immediately, so a persistent failure does not spin.
  if (error) {
    reduceQueue();
    setTimeout(() => {
      generateQA();
    }, 1000);
    return;
  }

  if (done || !data) {
    if (reduceQueue()) {
      console.log(`【QA】Task Done`);
    }
    return;
  }

  // auth balance
  try {
    await authTeamBalance(data.teamId);
  } catch (error: any) {
    if (error?.statusText === UserErrEnum.balanceNotEnough) {
      // send inform and lock data
      try {
        sendOneInform({
          type: 'system',
          title: '文本训练任务中止',
          content:
            '该团队账号余额不足,文本训练任务中止,重新充值后将会继续。暂停的任务将在 7 天后被删除。',
          tmbId: data.tmbId
        });
        console.log('余额不足暂停【QA】生成任务');
        lockTrainingDataByTeamId(data.teamId);
      } catch (error) {
        // Best effort: a failure to notify or lock must not stop the queue.
      }
    }

    reduceQueue();
    return generateQA();
  }

  try {
    const startTime = Date.now();
    const model = data.model ?? global.qaModels[0].model;
    // Custom (or default) instructions, a line break, then the fixed template
    // with the source text substituted in.
    const prompt = `${data.prompt || Prompt_AgentQA.description}\n${replaceVariable(
      Prompt_AgentQA.fixedText,
      { text }
    )}`;

    // request LLM to get QA
    const messages: ChatMessageItemType[] = [
      {
        role: 'user',
        content: prompt
      }
    ];

    // NOTE(review): the second argument is presumably a timeout in ms — confirm against getAIApi.
    const ai = getAIApi(undefined, 600000);
    const chatResponse = await ai.chat.completions.create({
      model,
      temperature: 0.3,
      messages,
      stream: false
    });

    // `choices` may be missing or empty; fall back to an empty answer instead of throwing.
    const answer = chatResponse.choices?.[0]?.message?.content || '';

    const qaArr = formatSplitText(answer, text); // formatted QA pairs

    // get vector and insert
    const { insertLen } = await pushDataToDatasetCollection({
      teamId: data.teamId,
      tmbId: data.tmbId,
      collectionId: data.collectionId,
      trainingMode: TrainingModeEnum.chunk,
      data: qaArr.map((item) => ({
        ...item,
        chunkIndex: data.chunkIndex
      })),
      billId: data.billId
    });

    // delete data from training
    await MongoDatasetTraining.findByIdAndDelete(data._id);

    addLog.info(`QA Training Finish`, {
      time: `${(Date.now() - startTime) / 1000}s`,
      splitLength: qaArr.length,
      usage: chatResponse.usage
    });

    // add bill
    if (insertLen > 0) {
      pushQABill({
        teamId: data.teamId,
        tmbId: data.tmbId,
        inputTokens: chatResponse.usage?.prompt_tokens || 0,
        outputTokens: chatResponse.usage?.completion_tokens || 0,
        billId: data.billId,
        model
      });
    } else {
      addLog.info(`QA result 0:`, { answer });
    }

    reduceQueue();
    generateQA();
  } catch (err: any) {
    reduceQueue();

    // log
    if (err?.response) {
      addLog.info('openai error: 生成QA错误', {
        status: err.response?.status,
        stateusText: err.response?.statusText,
        data: err.response?.data
      });
    } else {
      console.log(err);
      addLog.error(getErrText(err, '生成 QA 错误'));
    }

    // message error or openai account error
    if (
      err?.message === 'invalid message format' ||
      err?.response?.data?.error?.type === 'invalid_request_error' ||
      err?.code === 500
    ) {
      addLog.info('invalid message format', {
        text
      });
      try {
        // Park the record: a lock time far in the future keeps it from matching
        // the "lock older than six minutes" query used to claim work.
        await MongoDatasetTraining.findByIdAndUpdate(data._id, {
          lockTime: new Date('2998/5/5')
        });
      } catch (error) {
        // Best effort: if parking fails the record simply becomes claimable again later.
      }
      return generateQA();
    }

    // Any other failure: retry after a short pause.
    setTimeout(() => {
      generateQA();
    }, 1000);
  }
}
/**
 * Parse a model answer laid out as `Q1: … A1: … Q2: … A2: …` into dataset items.
 *
 * Each question/answer pair with a non-empty question becomes one item whose
 * single default index is the question followed by the trimmed answer (with
 * whitespace after each line break inside the answer collapsed).
 *
 * When no pair can be extracted, the original text is split into chunks of
 * length 512 instead and every chunk becomes an item with an empty answer.
 *
 * @param text    The answer returned by the model.
 * @param rawText The source text the model was asked about; used for the fallback split.
 * @returns The parsed QA items, or the fallback chunk items.
 */
function formatSplitText(text: string, rawText: string) {
  // Turn literal backslash-n sequences into real line breaks.
  const normalized = text.replace(/\\n/g, '\n');

  // One `Qn: question An: answer` pair per match. The answer runs up to the next
  // `Qn:` marker or the end of the text. The lookahead requires the full marker:
  // stopping at any capital "Q" would cut off answers such as "SQL is …".
  const pairPattern = /Q\d+:(\s*)(.*)(\s*)A\d+:(\s*)([\s\S]*?)(?=Q\d+:|$)/g;

  const result: PushDatasetDataChunkProps[] = [];

  for (const match of normalized.matchAll(pairPattern)) {
    const q = match[2] || '';
    const a = match[5] || '';
    if (!q) continue;

    result.push({
      q,
      a,
      indexes: [
        {
          defaultIndex: true,
          type: DatasetDataIndexTypeEnum.qa,
          text: `${q}\n${a.trim().replace(/\n\s*/g, '\n')}`
        }
      ]
    });
  }

  // empty result. direct split chunk
  if (result.length === 0) {
    const { chunks } = splitText2Chunks({ text: rawText, chunkLen: 512, countTokens: false });
    for (const chunk of chunks) {
      result.push({
        q: chunk,
        a: '',
        indexes: [
          {
            defaultIndex: true,
            type: DatasetDataIndexTypeEnum.chunk,
            text: chunk
          }
        ]
      });
    }
  }

  return result;
}