Add bill of training and rate of file upload (#339)

This commit is contained in:
Archer
2023-09-21 21:02:44 +08:00
committed by GitHub
parent e7e0677291
commit 814c5b3d3c
41 changed files with 401 additions and 263 deletions

View File

@@ -6,7 +6,7 @@ import { sseResponseEventEnum } from '@/constants/chat';
import { sseResponse } from '@/service/utils/tools';
import { AppModuleItemType } from '@/types/app';
import { dispatchModules } from '../openapi/v1/chat/completions';
import { pushTaskBill } from '@/service/events/pushBill';
import { pushTaskBill } from '@/service/common/bill/push';
import { BillSourceEnum } from '@/constants/user';
import { ChatItemType } from '@/types/chat';

View File

@@ -0,0 +1,46 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, Bill } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { BillSourceEnum } from '@/constants/user';
import { CreateTrainingBillType } from '@/api/common/bill/index.d';
/**
 * API route: creates an empty training bill for the authenticated user.
 *
 * Reads `name` from the request body and stores it as the bill's `appName`.
 * The bill is created with two zeroed line items (index generation and QA
 * split) and a total of 0. Responds with the `_id` of the created bill, or
 * with code 500 and the caught error when any step throws.
 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  try {
    const body = req.body as CreateTrainingBillType;
    const appName = body.name;

    const auth = await authUser({ req, authToken: true });
    const userId = auth.userId;

    await connectToDatabase();

    // Every line item starts with no usage recorded.
    const emptyUsage = { amount: 0, tokenLen: 0 };
    const list = [
      { moduleName: '索引生成', model: 'embedding', ...emptyUsage },
      { moduleName: 'QA 拆分', model: global.qaModel.name, ...emptyUsage }
    ];

    const bill = await Bill.create({
      userId,
      appName,
      source: BillSourceEnum.training,
      list,
      total: 0
    });

    jsonRes<string>(res, { data: bill._id });
  } catch (error) {
    jsonRes(res, { code: 500, error });
  }
}

View File

@@ -11,6 +11,7 @@ import { DatasetDataItemType } from '@/types/core/dataset/data';
import { countPromptTokens } from '@/utils/common/tiktoken';
export type Props = {
billId?: string;
kbId: string;
data: DatasetDataItemType;
};
@@ -19,63 +20,14 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
try {
await connectToDatabase();
const { kbId, data = { q: '', a: '' } } = req.body as Props;
if (!kbId || !data?.q) {
throw new Error('缺少参数');
}
// 凭证校验
const { userId } = await authUser({ req });
// auth kb
const kb = await authKb({ kbId, userId });
const q = data?.q?.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
const a = data?.a?.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
// token check
const token = countPromptTokens(q, 'system');
if (token > getVectorModel(kb.vectorModel).maxToken) {
throw new Error('Over Tokens');
}
const { rows: existsRows } = await PgClient.query(`
SELECT COUNT(*) > 0 AS exists
FROM ${PgDatasetTableName}
WHERE md5(q)=md5('${q}') AND md5(a)=md5('${a}') AND user_id='${userId}' AND kb_id='${kbId}'
`);
const exists = existsRows[0]?.exists || false;
if (exists) {
throw new Error('已经存在完全一致的数据');
}
const { vectors } = await getVector({
model: kb.vectorModel,
input: [q],
userId
});
const response = await insertData2Dataset({
userId,
kbId,
data: [
{
q,
a,
source: data.source,
vector: vectors[0]
}
]
});
// @ts-ignore
const id = response?.rows?.[0]?.id || '';
jsonRes(res, {
data: id
data: await getVectorAndInsertDataset({
...req.body,
userId
})
});
} catch (err) {
jsonRes(res, {
@@ -84,3 +36,59 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
});
}
});
/**
 * Embeds a single question/answer item and stores it in the dataset table.
 *
 * Steps, in order: check required params, authorize the knowledge base for
 * `userId`, normalize `q` and `a`, enforce the vector model's `maxToken` on
 * `q`, reject exact duplicates, compute the vector for `q`, insert the row.
 *
 * @param props - `kbId`, the `data` item, the authenticated `userId`, and an
 *   optional `billId` that is forwarded to `getVector`.
 * @returns The id of the inserted row, or `''` when the insert response
 *   carries no id.
 * @throws Rejects with a string message when `kbId` or `data.q` is missing,
 *   when `q` exceeds the model's `maxToken`, or when an identical row already
 *   exists. Errors from the called helpers propagate unchanged.
 */
export async function getVectorAndInsertDataset(
  props: Props & { userId: string }
): Promise<string> {
  const { kbId, data, userId, billId } = props;

  if (!kbId || !data?.q) {
    return Promise.reject('缺少参数');
  }

  // auth kb
  const kb = await authKb({ kbId, userId });

  // Turn literal "\n" sequences into real newlines, trim, and swap single
  // quotes for double quotes (the values are embedded in a quoted SQL literal
  // below). A missing value becomes '' so that an absent answer is not
  // interpolated into the query as the text "undefined".
  const normalize = (text?: string): string =>
    (text ?? '').replace(/\\n/g, '\n').trim().replace(/'/g, '"');

  const q = normalize(data.q);
  const a = normalize(data.a);

  // token check
  const token = countPromptTokens(q, 'system');

  if (token > getVectorModel(kb.vectorModel).maxToken) {
    return Promise.reject('Over Tokens');
  }

  // Duplicate check: same q and a for this user and knowledge base.
  // NOTE(review): values are interpolated straight into the SQL text; q/a only
  // have single quotes replaced above and userId/kbId are used as-is. Consider
  // a parameterized query if PgClient supports one — TODO confirm.
  const { rows: existsRows } = await PgClient.query(`
    SELECT COUNT(*) > 0 AS exists
    FROM ${PgDatasetTableName}
    WHERE md5(q)=md5('${q}') AND md5(a)=md5('${a}') AND user_id='${userId}' AND kb_id='${kbId}'
  `);
  const exists = existsRows[0]?.exists || false;

  if (exists) {
    return Promise.reject('已经存在完全一致的数据');
  }

  const { vectors } = await getVector({
    model: kb.vectorModel,
    input: [q],
    userId,
    billId
  });

  const response = await insertData2Dataset({
    userId,
    kbId,
    data: [
      {
        // Keep any extra fields on the item; q/a are replaced by the
        // normalized values.
        ...data,
        q,
        a,
        vector: vectors[0]
      }
    ]
  });

  // @ts-ignore
  return response?.rows?.[0]?.id || '';
}

View File

@@ -19,7 +19,7 @@ const modeMap = {
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { kbId, data, mode = TrainingModeEnum.index, prompt } = req.body as PushDataProps;
const { kbId, data, mode = TrainingModeEnum.index } = req.body as PushDataProps;
if (!kbId || !Array.isArray(data)) {
throw new Error('KbId or data is empty');
@@ -40,11 +40,8 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
jsonRes<PushDataResponse>(res, {
data: await pushDataToKb({
kbId,
data,
userId,
mode,
prompt
...req.body,
userId
})
});
} catch (err) {
@@ -60,7 +57,8 @@ export async function pushDataToKb({
kbId,
data,
mode,
prompt
prompt,
billId
}: { userId: string } & PushDataProps): Promise<PushDataResponse> {
const [kb, vectorModel] = await Promise.all([
authKb({
@@ -150,6 +148,7 @@ export async function pushDataToKb({
kbId,
mode,
prompt,
billId,
vectorModel: vectorModel.model
}))
);
@@ -163,6 +162,9 @@ export async function pushDataToKb({
/**
 * Next.js API route config for this endpoint: request bodies are parsed up to
 * 10mb and the response size limit is set to 12mb.
 */
export const config = {
  api: {
    bodyParser: { sizeLimit: '10mb' },
    responseLimit: '12mb'
  }
};

View File

@@ -1,10 +1,11 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, KB, App, TrainingData } from '@/service/mongo';
import { connectToDatabase, KB, TrainingData } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { PgClient } from '@/service/pg';
import { PgDatasetTableName } from '@/constants/plugin';
import { GridFSStorage } from '@/service/lib/gridfs';
import { Types } from 'mongoose';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -25,7 +26,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
// delete training data
await TrainingData.deleteMany({
userId,
kbId: { $in: deletedIds }
kbId: { $in: deletedIds.map((id) => new Types.ObjectId(id)) }
});
// delete all pg data

View File

@@ -0,0 +1,38 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { GridFSStorage } from '@/service/lib/gridfs';
import { MarkFileUsedProps } from '@/api/core/dataset/file.d';
import { Types } from 'mongoose';
/**
 * API route: flags uploaded dataset files as used.
 *
 * Takes `fileIds` from the request body and, for the authenticated user, sets
 * `metadata.datasetUsed` to true on every file in the 'dataset' GridFS
 * collection whose `_id` is in `fileIds` and whose `metadata.userId` matches.
 * Responds with an empty payload on success, or with code 500 and the caught
 * error when any step throws.
 */
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
  try {
    await connectToDatabase();

    const body = req.body as MarkFileUsedProps;
    const fileIds = body.fileIds;

    const auth = await authUser({ req, authToken: true });
    const userId = auth.userId;

    const files = new GridFSStorage('dataset', userId).Collection();

    // Only files owned by the caller are matched.
    const objectIds = fileIds.map((id) => new Types.ObjectId(id));
    const filter = {
      _id: { $in: objectIds },
      'metadata.userId': userId
    };
    const update = {
      $set: { 'metadata.datasetUsed': true }
    };

    await files.updateMany(filter, update);

    jsonRes(res, {});
  } catch (error) {
    jsonRes(res, { code: 500, error });
  }
}

View File

@@ -3,11 +3,12 @@ import { jsonRes } from '@/service/response';
import { authBalanceByUid, authUser } from '@/service/utils/auth';
import { withNextCors } from '@/service/utils/tools';
import { getAIChatApi, axiosConfig } from '@/service/lib/openai';
import { pushGenerateVectorBill } from '@/service/events/pushBill';
import { pushGenerateVectorBill } from '@/service/common/bill/push';
type Props = {
model: string;
input: string[];
billId?: string;
};
type Response = {
tokenLen: number;
@@ -38,7 +39,8 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
export async function getVector({
model = 'text-embedding-ada-002',
userId,
input
input,
billId
}: { userId?: string } & Props) {
userId && (await authBalanceByUid(userId));
@@ -82,7 +84,8 @@ export async function getVector({
pushGenerateVectorBill({
userId,
tokenLen: result.tokenLen,
model
model,
billId
});
return result;

View File

@@ -23,7 +23,7 @@ import { type ChatCompletionRequestMessage } from 'openai';
import { TaskResponseKeyEnum } from '@/constants/chat';
import { FlowModuleTypeEnum, initModuleType } from '@/constants/flow';
import { AppModuleItemType, RunningModuleItemType } from '@/types/app';
import { pushTaskBill } from '@/service/events/pushBill';
import { pushTaskBill } from '@/service/common/bill/push';
import { BillSourceEnum } from '@/constants/user';
import { ChatHistoryItemResType } from '@/types/chat';
import { UserModelSchema } from '@/types/mongoSchema';

View File

@@ -38,3 +38,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
});
}
}
/** Next.js API route config for this endpoint: response size limit set to 32mb. */
export const config = { api: { responseLimit: '32mb' } };

View File

@@ -84,7 +84,7 @@ const defaultQAModel = {
maxToken: 16000,
price: 0
};
const defaultExtractModel: FunctionModelItemType = {
export const defaultExtractModel: FunctionModelItemType = {
model: 'gpt-3.5-turbo-16k',
name: 'GPT35-16k',
maxToken: 16000,
@@ -92,7 +92,7 @@ const defaultExtractModel: FunctionModelItemType = {
prompt: '',
functionCall: true
};
const defaultCQModel: FunctionModelItemType = {
export const defaultCQModel: FunctionModelItemType = {
model: 'gpt-3.5-turbo-16k',
name: 'GPT35-16k',
maxToken: 16000,