perf: open push data api

This commit is contained in:
archer
2023-08-29 10:40:48 +08:00
parent 19d7edb585
commit e0de04dddb
6 changed files with 51 additions and 77 deletions

View File

@@ -8,6 +8,7 @@ import { PgTrainingTableName, TrainingModeEnum } from '@/constants/plugin';
import { startQueue } from '@/service/utils/tools';
import { PgClient } from '@/service/pg';
import { modelToolMap } from '@/utils/plugin';
import { getVectorModel } from '@/service/utils/data';
export type DateItemType = { a: string; q: string; source?: string };
@@ -22,17 +23,25 @@ export type Response = {
insertLen: number;
};
const modeMaxToken = {
[TrainingModeEnum.index]: 6000,
[TrainingModeEnum.qa]: 12000
const modeMap = {
[TrainingModeEnum.index]: true,
[TrainingModeEnum.qa]: true
};
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { kbId, data, mode, prompt } = req.body as Props;
const { kbId, data, mode = TrainingModeEnum.index, prompt } = req.body as Props;
if (!kbId || !Array.isArray(data)) {
throw new Error('缺少参数');
throw new Error('KbId or data is empty');
}
if (modeMap[mode] === undefined) {
throw new Error('Mode is error');
}
if (data.length > 500) {
throw new Error('Data is too long, max 500');
}
await connectToDatabase();
@@ -64,25 +73,42 @@ export async function pushDataToKb({
mode,
prompt
}: { userId: string } & Props): Promise<Response> {
await authKb({
userId,
kbId
});
const [kb, vectorModel] = await Promise.all([
authKb({
userId,
kbId
}),
(async () => {
if (mode === TrainingModeEnum.index) {
const vectorModel = (await KB.findById(kbId, 'vectorModel'))?.vectorModel;
return getVectorModel(vectorModel || global.vectorModels[0].model);
}
return global.vectorModels[0];
})()
]);
const modeMaxToken = {
[TrainingModeEnum.index]: vectorModel.maxToken,
[TrainingModeEnum.qa]: global.qaModel.maxToken * 0.8
};
// 过滤重复的 qa 内容
const set = new Set();
const filterData: DateItemType[] = [];
data.forEach((item) => {
if (!item.q) return;
const text = item.q + item.a;
// count token
// count q token
const token = modelToolMap.countTokens({
model: 'gpt-3.5-turbo',
messages: [{ obj: 'System', value: item.q }]
});
if (token > modeMaxToken[TrainingModeEnum.qa]) {
if (token > modeMaxToken[mode]) {
return;
}
@@ -138,15 +164,8 @@ export async function pushDataToKb({
.filter((item) => item.status === 'fulfilled')
.map<DateItemType>((item: any) => item.value);
const vectorModel = await (async () => {
if (mode === TrainingModeEnum.index) {
return (await KB.findById(kbId, 'vectorModel'))?.vectorModel || global.vectorModels[0].model;
}
return global.vectorModels[0].model;
})();
// 插入记录
await TrainingData.insertMany(
const insertRes = await TrainingData.insertMany(
insertData.map((item) => ({
q: item.q,
a: item.a,
@@ -155,21 +174,21 @@ export async function pushDataToKb({
kbId,
mode,
prompt,
vectorModel
vectorModel: vectorModel.model
}))
);
insertData.length > 0 && startQueue();
insertRes.length > 0 && startQueue();
return {
insertLen: insertData.length
insertLen: insertRes.length
};
}
export const config = {
api: {
bodyParser: {
sizeLimit: '20mb'
sizeLimit: '12mb'
}
}
};

View File

@@ -1,51 +0,0 @@
// Next.js API route support: https://nextjs.org/docs/api-routes/introduction
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { authUser } from '@/service/utils/auth';
import axios from 'axios';
import { axiosConfig } from '@/service/ai/openai';
/** Request body accepted by this route; it is handed unchanged to `sensitiveCheck`. */
export type Props = {
// Text that is posted as `input` to the `/moderations` endpoint.
input: string;
};
/**
 * API route entry point for the content check.
 *
 * Authenticates the caller with `authUser`, passes the request body to
 * `sensitiveCheck`, and replies through `jsonRes` with the check result as both
 * `data` and `message`. Any error thrown along the way (authentication failure
 * or a rejected check) is reported through `jsonRes` with `code: 500`.
 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  try {
    await authUser({ req });

    const checkResult = await sensitiveCheck(req.body);

    jsonRes(res, { data: checkResult, message: checkResult });
  } catch (err) {
    jsonRes(res, { code: 500, error: err });
  }
}
/**
 * Posts `input` to the `/moderations` endpoint and inspects the category scores
 * of the first result.
 *
 * @param input - Text to be checked.
 * @returns An empty string when no category score is above 0.2 (or when the
 *   response carries no scores at all).
 * @throws Rejects with the string '您的内容不合规' as soon as any score exceeds 0.2.
 */
export async function sensitiveCheck({ input }: Props) {
  const { data: payload } = await axios({
    ...axiosConfig(),
    method: 'POST',
    url: `/moderations`,
    data: { input }
  });

  // A missing `results` / `category_scores` is treated as "nothing flagged".
  const scores = (payload.results?.[0]?.category_scores as Record<string, number>) || {};
  const isFlagged = Object.values(scores).some((score) => score > 0.2);

  if (isFlagged) {
    return Promise.reject('您的内容不合规');
  }
  return '';
}

View File

@@ -66,7 +66,7 @@ const ChunkImport = ({ kbId }: { kbId: string }) => {
// subsection import
let success = 0;
const step = 500;
const step = 300;
for (let i = 0; i < chunks.length; i += step) {
const { insertLen } = await postKbDataFromList({
kbId,

View File

@@ -54,7 +54,7 @@ const CsvImport = ({ kbId }: { kbId: string }) => {
// subsection import
let success = 0;
const step = 500;
const step = 300;
for (let i = 0; i < filterChunks.length; i += step) {
const { insertLen } = await postKbDataFromList({
kbId,

View File

@@ -53,7 +53,7 @@ const QAImport = ({ kbId }: { kbId: string }) => {
// subsection import
let success = 0;
const step = 300;
const step = 200;
for (let i = 0; i < chunks.length; i += step) {
const { insertLen } = await postKbDataFromList({
kbId,

View File

@@ -156,6 +156,12 @@ const Info = (
</Box>
<Box flex={[1, '0 0 300px']}>{getValues('vectorModel').name}</Box>
</Flex>
<Flex mt={8} w={'100%'} alignItems={'center'}>
<Box flex={['0 0 90px', '0 0 160px']} w={0}>
MaxTokens
</Box>
<Box flex={[1, '0 0 300px']}>{getValues('vectorModel').maxToken}</Box>
</Flex>
<Flex mt={5} w={'100%'} alignItems={'center'}>
<Box flex={['0 0 90px', '0 0 160px']} w={0}>