mirror of https://github.com/labring/FastGPT.git
Add bill of training and rate of file upload (#339)
@@ -6,7 +6,7 @@ import { sseResponseEventEnum } from '@/constants/chat';
 import { sseResponse } from '@/service/utils/tools';
 import { AppModuleItemType } from '@/types/app';
 import { dispatchModules } from '../openapi/v1/chat/completions';
-import { pushTaskBill } from '@/service/events/pushBill';
+import { pushTaskBill } from '@/service/common/bill/push';
 import { BillSourceEnum } from '@/constants/user';
 import { ChatItemType } from '@/types/chat';
 
client/src/pages/api/common/bill/createTrainingBill.ts (new file, 46 lines)
@@ -0,0 +1,46 @@
+import type { NextApiRequest, NextApiResponse } from 'next';
+import { jsonRes } from '@/service/response';
+import { connectToDatabase, Bill } from '@/service/mongo';
+import { authUser } from '@/service/utils/auth';
+import { BillSourceEnum } from '@/constants/user';
+import { CreateTrainingBillType } from '@/api/common/bill/index.d';
+
+export default async function handler(req: NextApiRequest, res: NextApiResponse) {
+  try {
+    const { name } = req.body as CreateTrainingBillType;
+
+    const { userId } = await authUser({ req, authToken: true });
+
+    await connectToDatabase();
+
+    const { _id } = await Bill.create({
+      userId,
+      appName: name,
+      source: BillSourceEnum.training,
+      list: [
+        {
+          moduleName: '索引生成',
+          model: 'embedding',
+          amount: 0,
+          tokenLen: 0
+        },
+        {
+          moduleName: 'QA 拆分',
+          model: global.qaModel.name,
+          amount: 0,
+          tokenLen: 0
+        }
+      ],
+      total: 0
+    });
+
+    jsonRes<string>(res, {
+      data: _id
+    });
+  } catch (err) {
+    jsonRes(res, {
+      code: 500,
+      error: err
+    });
+  }
+}
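
For orientation, a minimal client-side sketch of how this new endpoint could be used: create a training bill before a data import starts and keep the returned _id for later requests. The fetch path is derived from the file location; the helper name and the presence of a valid auth cookie are assumptions, not part of the commit.

// Hypothetical usage sketch (not part of this commit).
export async function createTrainingBill(name: string): Promise<string> {
  // The handler reads `name` from the body (CreateTrainingBillType) and
  // authenticates via the user's token, assumed here to be sent as a cookie.
  const res = await fetch('/api/common/bill/createTrainingBill', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ name })
  });
  const json = await res.json();
  return json.data; // jsonRes wraps the created bill _id in `data`
}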
@@ -11,6 +11,7 @@ import { DatasetDataItemType } from '@/types/core/dataset/data';
 import { countPromptTokens } from '@/utils/common/tiktoken';
 
 export type Props = {
+  billId?: string;
   kbId: string;
   data: DatasetDataItemType;
 };
@@ -19,63 +20,14 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
   try {
     await connectToDatabase();
 
-    const { kbId, data = { q: '', a: '' } } = req.body as Props;
-
-    if (!kbId || !data?.q) {
-      throw new Error('缺少参数');
-    }
-
     // 凭证校验
     const { userId } = await authUser({ req });
 
-    // auth kb
-    const kb = await authKb({ kbId, userId });
-
-    const q = data?.q?.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
-    const a = data?.a?.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
-
-    // token check
-    const token = countPromptTokens(q, 'system');
-
-    if (token > getVectorModel(kb.vectorModel).maxToken) {
-      throw new Error('Over Tokens');
-    }
-
-    const { rows: existsRows } = await PgClient.query(`
-    SELECT COUNT(*) > 0 AS exists
-    FROM ${PgDatasetTableName}
-    WHERE md5(q)=md5('${q}') AND md5(a)=md5('${a}') AND user_id='${userId}' AND kb_id='${kbId}'
-    `);
-    const exists = existsRows[0]?.exists || false;
-
-    if (exists) {
-      throw new Error('已经存在完全一致的数据');
-    }
-
-    const { vectors } = await getVector({
-      model: kb.vectorModel,
-      input: [q],
-      userId
-    });
-
-    const response = await insertData2Dataset({
-      userId,
-      kbId,
-      data: [
-        {
-          q,
-          a,
-          source: data.source,
-          vector: vectors[0]
-        }
-      ]
-    });
-
-    // @ts-ignore
-    const id = response?.rows?.[0]?.id || '';
-
     jsonRes(res, {
-      data: id
+      data: await getVectorAndInsertDataset({
+        ...req.body,
+        userId
+      })
     });
   } catch (err) {
     jsonRes(res, {
@@ -84,3 +36,59 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
     });
   }
 });
+
+export async function getVectorAndInsertDataset(
+  props: Props & { userId: string }
+): Promise<string> {
+  const { kbId, data, userId, billId } = props;
+  if (!kbId || !data?.q) {
+    return Promise.reject('缺少参数');
+  }
+
+  // auth kb
+  const kb = await authKb({ kbId, userId });
+
+  const q = data?.q?.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
+  const a = data?.a?.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
+
+  // token check
+  const token = countPromptTokens(q, 'system');
+
+  if (token > getVectorModel(kb.vectorModel).maxToken) {
+    return Promise.reject('Over Tokens');
+  }
+
+  const { rows: existsRows } = await PgClient.query(`
+  SELECT COUNT(*) > 0 AS exists
+  FROM ${PgDatasetTableName}
+  WHERE md5(q)=md5('${q}') AND md5(a)=md5('${a}') AND user_id='${userId}' AND kb_id='${kbId}'
+  `);
+  const exists = existsRows[0]?.exists || false;
+
+  if (exists) {
+    return Promise.reject('已经存在完全一致的数据');
+  }
+
+  const { vectors } = await getVector({
+    model: kb.vectorModel,
+    input: [q],
+    userId,
+    billId
+  });
+
+  const response = await insertData2Dataset({
+    userId,
+    kbId,
+    data: [
+      {
+        ...data,
+        q,
+        a,
+        vector: vectors[0]
+      }
+    ]
+  });
+
+  // @ts-ignore
+  return response?.rows?.[0]?.id || '';
+}
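
With this refactor the route handler only authenticates and delegates, while the exported getVectorAndInsertDataset can be reused by other server-side code and carries an optional billId into getVector. A rough sketch of such a reuse; the import path is assumed because the diff does not show this file's name, and the q/a values are placeholders.

// Rough server-side sketch (import path and values are assumptions).
import { getVectorAndInsertDataset } from '@/pages/api/core/dataset/data/insertData';

export async function insertOneRecord(userId: string, kbId: string, billId?: string) {
  // billId is optional; when present it is passed down to getVector so the
  // embedding tokens are charged against that bill.
  return getVectorAndInsertDataset({
    userId,
    kbId,
    billId,
    data: { q: 'What is FastGPT?', a: 'An LLM application platform.' }
  });
}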
@@ -19,7 +19,7 @@ const modeMap = {
 
 export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
   try {
-    const { kbId, data, mode = TrainingModeEnum.index, prompt } = req.body as PushDataProps;
+    const { kbId, data, mode = TrainingModeEnum.index } = req.body as PushDataProps;
 
     if (!kbId || !Array.isArray(data)) {
       throw new Error('KbId or data is empty');
@@ -40,11 +40,8 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
 
     jsonRes<PushDataResponse>(res, {
       data: await pushDataToKb({
-        kbId,
-        data,
-        userId,
-        mode,
-        prompt
+        ...req.body,
+        userId
       })
     });
   } catch (err) {
@@ -60,7 +57,8 @@ export async function pushDataToKb({
   kbId,
   data,
   mode,
-  prompt
+  prompt,
+  billId
 }: { userId: string } & PushDataProps): Promise<PushDataResponse> {
   const [kb, vectorModel] = await Promise.all([
     authKb({
@@ -150,6 +148,7 @@ export async function pushDataToKb({
       kbId,
       mode,
      prompt,
+      billId,
       vectorModel: vectorModel.model
     }))
   );
@@ -163,6 +162,9 @@ export async function pushDataToKb({
 
 export const config = {
   api: {
+    bodyParser: {
+      sizeLimit: '10mb'
+    },
     responseLimit: '12mb'
   }
 };
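
The added config block raises the request body limit to 10mb, and because the handler now spreads ...req.body into pushDataToKb, a caller can include billId directly in the request body. A hedged client-side sketch of batching chunks under that limit; the endpoint path, helper name, and batch size are assumptions.

// Sketch only: batch dataset chunks and attach the training billId.
type Chunk = { q: string; a?: string; source?: string };

export async function pushChunksInBatches(
  kbId: string,
  billId: string,
  chunks: Chunk[],
  batchSize = 50 // keep each request comfortably under the 10mb bodyParser limit
): Promise<void> {
  for (let i = 0; i < chunks.length; i += batchSize) {
    await fetch('/api/core/dataset/data/pushData', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ kbId, billId, data: chunks.slice(i, i + batchSize) })
    });
  }
}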
@@ -1,10 +1,11 @@
 import type { NextApiRequest, NextApiResponse } from 'next';
 import { jsonRes } from '@/service/response';
-import { connectToDatabase, KB, App, TrainingData } from '@/service/mongo';
+import { connectToDatabase, KB, TrainingData } from '@/service/mongo';
 import { authUser } from '@/service/utils/auth';
 import { PgClient } from '@/service/pg';
 import { PgDatasetTableName } from '@/constants/plugin';
 import { GridFSStorage } from '@/service/lib/gridfs';
+import { Types } from 'mongoose';
 
 export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
   try {
@@ -25,7 +26,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
     // delete training data
     await TrainingData.deleteMany({
       userId,
-      kbId: { $in: deletedIds }
+      kbId: { $in: deletedIds.map((id) => new Types.ObjectId(id)) }
     });
 
     // delete all pg data
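
The functional change above is the explicit ObjectId cast, presumably because plain string ids were not matching the ObjectId values stored on TrainingData. A minimal illustration of the pattern; the example id is made up.

// Illustration only: build the $in filter from explicit ObjectIds so the
// match does not depend on implicit schema-level casting.
import { Types } from 'mongoose';

const deletedIds = ['650f1b2c8a9d3e0012ab34cd']; // fabricated example id
const filter = {
  kbId: { $in: deletedIds.map((id) => new Types.ObjectId(id)) }
};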
client/src/pages/api/core/dataset/file/markUsed.ts (new file, 38 lines)
@@ -0,0 +1,38 @@
+import type { NextApiRequest, NextApiResponse } from 'next';
+import { jsonRes } from '@/service/response';
+import { connectToDatabase } from '@/service/mongo';
+import { authUser } from '@/service/utils/auth';
+import { GridFSStorage } from '@/service/lib/gridfs';
+import { MarkFileUsedProps } from '@/api/core/dataset/file.d';
+import { Types } from 'mongoose';
+
+export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
+  try {
+    await connectToDatabase();
+
+    const { fileIds } = req.body as MarkFileUsedProps;
+    const { userId } = await authUser({ req, authToken: true });
+
+    const gridFs = new GridFSStorage('dataset', userId);
+    const collection = gridFs.Collection();
+
+    await collection.updateMany(
+      {
+        _id: { $in: fileIds.map((id) => new Types.ObjectId(id)) },
+        ['metadata.userId']: userId
+      },
+      {
+        $set: {
+          ['metadata.datasetUsed']: true
+        }
+      }
+    );
+
+    jsonRes(res, {});
+  } catch (err) {
+    jsonRes(res, {
+      code: 500,
+      error: err
+    });
+  }
+}
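
A hedged sketch of the client side of this endpoint: once a file's chunks have been pushed into a dataset, its GridFS file ids are sent here so metadata.datasetUsed is set to true, presumably so unused uploads can be told apart later. The fetch path follows the file location; the helper name is an assumption.

// Sketch only: mark uploaded files as used by a dataset.
// The handler reads { fileIds } (MarkFileUsedProps) from the body.
export async function markFilesUsed(fileIds: string[]): Promise<void> {
  await fetch('/api/core/dataset/file/markUsed', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ fileIds })
  });
}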
@@ -3,11 +3,12 @@ import { jsonRes } from '@/service/response';
 import { authBalanceByUid, authUser } from '@/service/utils/auth';
 import { withNextCors } from '@/service/utils/tools';
 import { getAIChatApi, axiosConfig } from '@/service/lib/openai';
-import { pushGenerateVectorBill } from '@/service/events/pushBill';
+import { pushGenerateVectorBill } from '@/service/common/bill/push';
 
 type Props = {
   model: string;
   input: string[];
+  billId?: string;
 };
 type Response = {
   tokenLen: number;
@@ -38,7 +39,8 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
 export async function getVector({
   model = 'text-embedding-ada-002',
   userId,
-  input
+  input,
+  billId
 }: { userId?: string } & Props) {
   userId && (await authBalanceByUid(userId));
 
@@ -82,7 +84,8 @@ export async function getVector({
   pushGenerateVectorBill({
     userId,
     tokenLen: result.tokenLen,
-    model
+    model,
+    billId
   });
 
   return result;
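
Taken together with createTrainingBill, the optional billId now flows end to end: it is created once, passed through pushDataToKb or getVectorAndInsertDataset, and forwarded by getVector to pushGenerateVectorBill so embedding token usage lands on that bill. A rough sketch of the final hop; the import path is assumed and the arguments are placeholders.

// Rough sketch (import path and arguments are assumptions).
import { getVector } from '@/pages/api/openapi/plugin/vector';

export async function embedWithBill(
  userId: string,
  model: string,
  text: string,
  billId?: string
) {
  // getVector pushes a vector bill internally; billId ties it to the training bill.
  const { vectors } = await getVector({ model, input: [text], userId, billId });
  return vectors[0];
}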
@@ -23,7 +23,7 @@ import { type ChatCompletionRequestMessage } from 'openai';
 import { TaskResponseKeyEnum } from '@/constants/chat';
 import { FlowModuleTypeEnum, initModuleType } from '@/constants/flow';
 import { AppModuleItemType, RunningModuleItemType } from '@/types/app';
-import { pushTaskBill } from '@/service/events/pushBill';
+import { pushTaskBill } from '@/service/common/bill/push';
 import { BillSourceEnum } from '@/constants/user';
 import { ChatHistoryItemResType } from '@/types/chat';
 import { UserModelSchema } from '@/types/mongoSchema';
@@ -38,3 +38,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
     });
   }
 }
+export const config = {
+  api: {
+    responseLimit: '32mb'
+  }
+};
@@ -84,7 +84,7 @@ const defaultQAModel = {
   maxToken: 16000,
   price: 0
 };
-const defaultExtractModel: FunctionModelItemType = {
+export const defaultExtractModel: FunctionModelItemType = {
   model: 'gpt-3.5-turbo-16k',
   name: 'GPT35-16k',
   maxToken: 16000,
@@ -92,7 +92,7 @@ const defaultExtractModel: FunctionModelItemType = {
   prompt: '',
   functionCall: true
 };
-const defaultCQModel: FunctionModelItemType = {
+export const defaultCQModel: FunctionModelItemType = {
   model: 'gpt-3.5-turbo-16k',
   name: 'GPT35-16k',
   maxToken: 16000,
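
defaultExtractModel and defaultCQModel are now exported, presumably so other modules can fall back to them when no override is configured. A short illustration; the import path and the global fields are assumptions, not from this commit.

// Illustration only (import path and globals are assumptions).
import { defaultExtractModel, defaultCQModel } from '@/service/utils/systemConfig';

const extractModel = global.extractModel ?? defaultExtractModel;
const cqModel = global.cqModel ?? defaultCQModel;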