Add training bill and file upload rate (#339)

Archer authored 2023-09-21 21:02:44 +08:00, committed by GitHub
parent e7e0677291, commit 814c5b3d3c
41 changed files with 401 additions and 263 deletions

View File

@@ -65,7 +65,7 @@
},
"ExtractModel": {
"model": "gpt-3.5-turbo-16k",
"functionCall": false,
"functionCall": true,
"name": "GPT35-16k",
"maxToken": 16000,
"price": 0,
@@ -73,7 +73,7 @@
},
"CQModel": {
"model": "gpt-3.5-turbo-16k",
"functionCall": false,
"functionCall": true,
"name": "GPT35-16k",
"maxToken": 16000,
"price": 0,

View File

@@ -132,7 +132,8 @@
"Confirm to delete the data": "Confirm to delete the data?",
"Export": "Export",
"Queue Desc": "This data refers to the current amount of training for the entire system. FastGPT uses queued training, and if you have too much data to train, you may need to wait for a while",
"System Data Queue": "Data Queue"
"System Data Queue": "Data Queue",
"Training Name": "Dataset Training"
},
"file": {
"Click to download CSV template": "Click to download CSV template",

View File

@@ -132,7 +132,8 @@
"Confirm to delete the data": "确认删除该数据?",
"Export": "导出",
"Queue Desc": "该数据是指整个系统当前待训练的数量。{{title}} 采用排队训练的方式,如果待训练的数据过多,可能需要等待一段时间",
"System Data Queue": "排队长度"
"System Data Queue": "排队长度",
"Training Name": "数据训练"
},
"file": {
"Click to download CSV template": "点击下载 CSV 模板",

client/src/api/common/bill/index.d.ts (new file)
View File

@@ -0,0 +1,3 @@
export type CreateTrainingBillType = {
name: string;
};

View File

@@ -0,0 +1,5 @@
import { GET, POST, PUT, DELETE } from '@/api/request';
import { CreateTrainingBillType } from './index.d';
export const postCreateTrainingBill = (data: CreateTrainingBillType) =>
POST<string>(`/common/bill/createTrainingBill`, data);

View File

@@ -7,6 +7,7 @@ export type PushDataProps = {
data: DatasetItemType[];
mode: `${TrainingModeEnum}`;
prompt?: string;
billId?: string;
};
export type PushDataResponse = {
insertLen: number;

View File

@@ -6,3 +6,5 @@ export type GetFileListProps = RequestPaging & {
};
export type UpdateFileProps = { id: string; name?: string; datasetUsed?: boolean };
export type MarkFileUsedProps = { fileIds: string[] };

View File

@@ -2,7 +2,7 @@ import { GET, POST, PUT, DELETE } from '@/api/request';
import type { DatasetFileItemType } from '@/types/core/dataset/file';
import type { GSFileInfoType } from '@/types/common/file';
import type { GetFileListProps, UpdateFileProps } from './file.d';
import type { GetFileListProps, UpdateFileProps, MarkFileUsedProps } from './file.d';
export const getDatasetFiles = (data: GetFileListProps) =>
POST<DatasetFileItemType[]>(`/core/dataset/file/list`, data);
@@ -14,3 +14,6 @@ export const delDatasetEmptyFiles = (kbId: string) =>
DELETE(`/core/dataset/file/delEmptyFiles`, { kbId });
export const updateDatasetFile = (data: UpdateFileProps) => PUT(`/core/dataset/file/update`, data);
export const putMarkFilesUsed = (data: MarkFileUsedProps) =>
PUT(`/core/dataset/file/markUsed`, data);

View File

@@ -5,7 +5,8 @@ export enum OAuthEnum {
export enum BillSourceEnum {
fastgpt = 'fastgpt',
api = 'api',
shareLink = 'shareLink'
shareLink = 'shareLink',
training = 'training'
}
export enum PageTypeEnum {
login = 'login',
@@ -16,7 +17,8 @@ export enum PageTypeEnum {
export const BillSourceMap: Record<`${BillSourceEnum}`, string> = {
[BillSourceEnum.fastgpt]: '在线使用',
[BillSourceEnum.api]: 'Api',
[BillSourceEnum.shareLink]: '免登录链接'
[BillSourceEnum.shareLink]: '免登录链接',
[BillSourceEnum.training]: '数据训练'
};
export enum PromotionEnum {

View File

@@ -1,4 +1,4 @@
import React from 'react';
import React, { useMemo } from 'react';
import {
ModalBody,
Flex,
@@ -20,6 +20,10 @@ import { useTranslation } from 'react-i18next';
const BillDetail = ({ bill, onClose }: { bill: UserBillType; onClose: () => void }) => {
const { t } = useTranslation();
const filterBillList = useMemo(
() => bill.list.filter((item) => item && item.moduleName),
[bill.list]
);
return (
<MyModal isOpen={true} onClose={onClose} title={t('user.Bill Detail')}>
@@ -34,7 +38,7 @@ const BillDetail = ({ bill, onClose }: { bill: UserBillType; onClose: () => void
</Flex>
<Flex alignItems={'center'} pb={4}>
<Box flex={'0 0 80px'}>:</Box>
<Box>{bill.appName}</Box>
<Box>{t(bill.appName) || '-'}</Box>
</Flex>
<Flex alignItems={'center'} pb={4}>
<Box flex={'0 0 80px'}>:</Box>
@@ -59,7 +63,7 @@ const BillDetail = ({ bill, onClose }: { bill: UserBillType; onClose: () => void
</Tr>
</Thead>
<Tbody>
{bill.list.map((item, i) => (
{filterBillList.map((item, i) => (
<Tr key={i}>
<Td>{item.moduleName}</Td>
<Td>{item.model}</Td>

View File

@@ -68,7 +68,7 @@ const BillTable = () => {
<Tr key={item.id}>
<Td>{dayjs(item.time).format('YYYY/MM/DD HH:mm:ss')}</Td>
<Td>{BillSourceMap[item.source]}</Td>
<Td>{item.appName || '-'}</Td>
<Td>{t(item.appName) || '-'}</Td>
<Td>{item.total}</Td>
<Td>
<Button size={'sm'} variant={'base'} onClick={() => setBillDetail(item)}>

View File

@@ -6,7 +6,7 @@ import { sseResponseEventEnum } from '@/constants/chat';
import { sseResponse } from '@/service/utils/tools';
import { AppModuleItemType } from '@/types/app';
import { dispatchModules } from '../openapi/v1/chat/completions';
import { pushTaskBill } from '@/service/events/pushBill';
import { pushTaskBill } from '@/service/common/bill/push';
import { BillSourceEnum } from '@/constants/user';
import { ChatItemType } from '@/types/chat';

View File

@@ -0,0 +1,46 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, Bill } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { BillSourceEnum } from '@/constants/user';
import { CreateTrainingBillType } from '@/api/common/bill/index.d';
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
const { name } = req.body as CreateTrainingBillType;
const { userId } = await authUser({ req, authToken: true });
await connectToDatabase();
const { _id } = await Bill.create({
userId,
appName: name,
source: BillSourceEnum.training,
list: [
{
moduleName: '索引生成',
model: 'embedding',
amount: 0,
tokenLen: 0
},
{
moduleName: 'QA 拆分',
model: global.qaModel.name,
amount: 0,
tokenLen: 0
}
],
total: 0
});
jsonRes<string>(res, {
data: _id
});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}
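
For orientation, this is roughly the document the endpoint above creates. A sketch only: the id value is hypothetical, and the two pre-seeded list entries are the slots later incremented by listIndex 0 and 1 in the billing helpers further down.

// hypothetical example of the created training bill (not part of this commit)
const exampleTrainingBill = {
  userId: '650c0f0a1d2e3f4a5b6c7d8e', // hypothetical ObjectId
  appName: 'dataset.Training Name',   // i18n key sent by the client, translated when rendered
  source: 'training',                 // BillSourceEnum.training
  total: 0,
  list: [
    { moduleName: '索引生成', amount: 0, model: 'embedding', tokenLen: 0 }, // listIndex 0: vector generation
    { moduleName: 'QA 拆分', amount: 0, model: 'GPT35-16k', tokenLen: 0 }   // listIndex 1: QA split (global.qaModel.name)
  ]
};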

View File

@@ -11,6 +11,7 @@ import { DatasetDataItemType } from '@/types/core/dataset/data';
import { countPromptTokens } from '@/utils/common/tiktoken';
export type Props = {
billId?: string;
kbId: string;
data: DatasetDataItemType;
};
@@ -19,63 +20,14 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
try {
await connectToDatabase();
const { kbId, data = { q: '', a: '' } } = req.body as Props;
if (!kbId || !data?.q) {
throw new Error('缺少参数');
}
// 凭证校验
const { userId } = await authUser({ req });
// auth kb
const kb = await authKb({ kbId, userId });
const q = data?.q?.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
const a = data?.a?.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
// token check
const token = countPromptTokens(q, 'system');
if (token > getVectorModel(kb.vectorModel).maxToken) {
throw new Error('Over Tokens');
}
const { rows: existsRows } = await PgClient.query(`
SELECT COUNT(*) > 0 AS exists
FROM ${PgDatasetTableName}
WHERE md5(q)=md5('${q}') AND md5(a)=md5('${a}') AND user_id='${userId}' AND kb_id='${kbId}'
`);
const exists = existsRows[0]?.exists || false;
if (exists) {
throw new Error('已经存在完全一致的数据');
}
const { vectors } = await getVector({
model: kb.vectorModel,
input: [q],
userId
});
const response = await insertData2Dataset({
userId,
kbId,
data: [
{
q,
a,
source: data.source,
vector: vectors[0]
}
]
});
// @ts-ignore
const id = response?.rows?.[0]?.id || '';
jsonRes(res, {
data: id
data: await getVectorAndInsertDataset({
...req.body,
userId
})
});
} catch (err) {
jsonRes(res, {
@@ -84,3 +36,59 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
});
}
});
export async function getVectorAndInsertDataset(
props: Props & { userId: string }
): Promise<string> {
const { kbId, data, userId, billId } = props;
if (!kbId || !data?.q) {
return Promise.reject('缺少参数');
}
// auth kb
const kb = await authKb({ kbId, userId });
const q = data?.q?.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
const a = data?.a?.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
// token check
const token = countPromptTokens(q, 'system');
if (token > getVectorModel(kb.vectorModel).maxToken) {
return Promise.reject('Over Tokens');
}
const { rows: existsRows } = await PgClient.query(`
SELECT COUNT(*) > 0 AS exists
FROM ${PgDatasetTableName}
WHERE md5(q)=md5('${q}') AND md5(a)=md5('${a}') AND user_id='${userId}' AND kb_id='${kbId}'
`);
const exists = existsRows[0]?.exists || false;
if (exists) {
return Promise.reject('已经存在完全一致的数据');
}
const { vectors } = await getVector({
model: kb.vectorModel,
input: [q],
userId,
billId
});
const response = await insertData2Dataset({
userId,
kbId,
data: [
{
...data,
q,
a,
vector: vectors[0]
}
]
});
// @ts-ignore
return response?.rows?.[0]?.id || '';
}
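
Extracting the logic into getVectorAndInsertDataset lets other server code reuse it and attribute the embedding cost to an existing training bill. A minimal usage sketch, assuming an authenticated userId, kbId and billId are already in scope:

// hypothetical server-side call; data fields follow DatasetDataItemType as used above
const insertedId = await getVectorAndInsertDataset({
  userId,
  kbId,
  billId, // optional: when present, the vector cost is added to this training bill
  data: { q: 'What is FastGPT?', a: 'An LLM knowledge-base application.', source: 'manual' }
});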

View File

@@ -19,7 +19,7 @@ const modeMap = {
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { kbId, data, mode = TrainingModeEnum.index, prompt } = req.body as PushDataProps;
const { kbId, data, mode = TrainingModeEnum.index } = req.body as PushDataProps;
if (!kbId || !Array.isArray(data)) {
throw new Error('KbId or data is empty');
@@ -40,11 +40,8 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
jsonRes<PushDataResponse>(res, {
data: await pushDataToKb({
kbId,
data,
userId,
mode,
prompt
...req.body,
userId
})
});
} catch (err) {
@@ -60,7 +57,8 @@ export async function pushDataToKb({
kbId,
data,
mode,
prompt
prompt,
billId
}: { userId: string } & PushDataProps): Promise<PushDataResponse> {
const [kb, vectorModel] = await Promise.all([
authKb({
@@ -150,6 +148,7 @@ export async function pushDataToKb({
kbId,
mode,
prompt,
billId,
vectorModel: vectorModel.model
}))
);
@@ -163,6 +162,9 @@ export async function pushDataToKb({
export const config = {
api: {
bodyParser: {
sizeLimit: '10mb'
},
responseLimit: '12mb'
}
};
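
On the wire, a push request now simply carries the optional billId alongside the existing fields. A hedged client-side sketch using the wrapper imported elsewhere in this commit (exact prop shape assumed from the handler above):

// hypothetical call; billId comes from postCreateTrainingBill and is
// forwarded into each TrainingData record for later billing
await postChunks2Dataset({
  kbId,
  mode: TrainingModeEnum.index,
  billId,
  data: [{ q: 'chunk text', a: '', source: 'file.pdf' }]
});
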

View File

@@ -1,10 +1,11 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, KB, App, TrainingData } from '@/service/mongo';
import { connectToDatabase, KB, TrainingData } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { PgClient } from '@/service/pg';
import { PgDatasetTableName } from '@/constants/plugin';
import { GridFSStorage } from '@/service/lib/gridfs';
import { Types } from 'mongoose';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -25,7 +26,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
// delete training data
await TrainingData.deleteMany({
userId,
kbId: { $in: deletedIds }
kbId: { $in: deletedIds.map((id) => new Types.ObjectId(id)) }
});
// delete all pg data

View File

@@ -0,0 +1,38 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { GridFSStorage } from '@/service/lib/gridfs';
import { MarkFileUsedProps } from '@/api/core/dataset/file.d';
import { Types } from 'mongoose';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
await connectToDatabase();
const { fileIds } = req.body as MarkFileUsedProps;
const { userId } = await authUser({ req, authToken: true });
const gridFs = new GridFSStorage('dataset', userId);
const collection = gridFs.Collection();
await collection.updateMany(
{
_id: { $in: fileIds.map((id) => new Types.ObjectId(id)) },
['metadata.userId']: userId
},
{
$set: {
['metadata.datasetUsed']: true
}
}
);
jsonRes(res, {});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@@ -3,11 +3,12 @@ import { jsonRes } from '@/service/response';
import { authBalanceByUid, authUser } from '@/service/utils/auth';
import { withNextCors } from '@/service/utils/tools';
import { getAIChatApi, axiosConfig } from '@/service/lib/openai';
import { pushGenerateVectorBill } from '@/service/events/pushBill';
import { pushGenerateVectorBill } from '@/service/common/bill/push';
type Props = {
model: string;
input: string[];
billId?: string;
};
type Response = {
tokenLen: number;
@@ -38,7 +39,8 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
export async function getVector({
model = 'text-embedding-ada-002',
userId,
input
input,
billId
}: { userId?: string } & Props) {
userId && (await authBalanceByUid(userId));
@@ -82,7 +84,8 @@ export async function getVector({
pushGenerateVectorBill({
userId,
tokenLen: result.tokenLen,
model
model,
billId
});
return result;

View File

@@ -23,7 +23,7 @@ import { type ChatCompletionRequestMessage } from 'openai';
import { TaskResponseKeyEnum } from '@/constants/chat';
import { FlowModuleTypeEnum, initModuleType } from '@/constants/flow';
import { AppModuleItemType, RunningModuleItemType } from '@/types/app';
import { pushTaskBill } from '@/service/events/pushBill';
import { pushTaskBill } from '@/service/common/bill/push';
import { BillSourceEnum } from '@/constants/user';
import { ChatHistoryItemResType } from '@/types/chat';
import { UserModelSchema } from '@/types/mongoSchema';

View File

@@ -38,3 +38,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
});
}
}
export const config = {
api: {
responseLimit: '32mb'
}
};

View File

@@ -84,7 +84,7 @@ const defaultQAModel = {
maxToken: 16000,
price: 0
};
const defaultExtractModel: FunctionModelItemType = {
export const defaultExtractModel: FunctionModelItemType = {
model: 'gpt-3.5-turbo-16k',
name: 'GPT35-16k',
maxToken: 16000,
@@ -92,7 +92,7 @@ const defaultExtractModel: FunctionModelItemType = {
prompt: '',
functionCall: true
};
const defaultCQModel: FunctionModelItemType = {
export const defaultCQModel: FunctionModelItemType = {
model: 'gpt-3.5-turbo-16k',
name: 'GPT35-16k',
maxToken: 16000,

View File

@@ -26,7 +26,7 @@ import { QuestionOutlineIcon } from '@chakra-ui/icons';
import { TrainingModeEnum } from '@/constants/plugin';
import FileSelect, { type FileItemType } from './FileSelect';
import { useDatasetStore } from '@/store/dataset';
import { updateDatasetFile } from '@/api/core/dataset/file';
import { putMarkFilesUsed } from '@/api/core/dataset/file';
import { chunksUpload } from '@/utils/web/core/dataset';
const fileExtension = '.txt, .doc, .docx, .pdf, .md';
@@ -66,14 +66,7 @@ const ChunkImport = ({ kbId }: { kbId: string }) => {
const chunks = files.map((file) => file.chunks).flat();
// mark the file is used
await Promise.all(
files.map((file) =>
updateDatasetFile({
id: file.id,
datasetUsed: true
})
)
);
await putMarkFilesUsed({ fileIds: files.map((file) => file.id) });
// upload data
const { insertLen } = await chunksUpload({

View File

@@ -10,7 +10,7 @@ import { TrainingModeEnum } from '@/constants/plugin';
import FileSelect, { type FileItemType } from './FileSelect';
import { useRouter } from 'next/router';
import { useDatasetStore } from '@/store/dataset';
import { updateDatasetFile } from '@/api/core/dataset/file';
import { putMarkFilesUsed } from '@/api/core/dataset/file';
import { chunksUpload } from '@/utils/web/core/dataset';
const fileExtension = '.csv';
@@ -39,14 +39,7 @@ const CsvImport = ({ kbId }: { kbId: string }) => {
const { mutate: onclickUpload, isLoading: uploading } = useMutation({
mutationFn: async () => {
// mark the file is used
await Promise.all(
files.map((file) =>
updateDatasetFile({
id: file.id,
datasetUsed: true
})
)
);
await putMarkFilesUsed({ fileIds: files.map((file) => file.id) });
const chunks = files
.map((file) => file.chunks)

View File

@@ -183,7 +183,7 @@ const FileSelect = ({
}
setSelectingText(undefined);
},
[chunkLen, onPushFiles, t, toast]
[chunkLen, kbDetail._id, onPushFiles, t, toast]
);
const onUrlFetch = useCallback(
(e: FetchResultItem[]) => {

View File

@@ -15,7 +15,7 @@ import { QuestionOutlineIcon, InfoOutlineIcon } from '@chakra-ui/icons';
import { TrainingModeEnum } from '@/constants/plugin';
import FileSelect, { type FileItemType } from './FileSelect';
import { useRouter } from 'next/router';
import { updateDatasetFile } from '@/api/core/dataset/file';
import { putMarkFilesUsed } from '@/api/core/dataset/file';
import { Prompt_AgentQA } from '@/prompts/core/agent';
import { replaceVariable } from '@/utils/common/tools/text';
import { chunksUpload } from '@/utils/web/core/dataset';
@@ -65,14 +65,7 @@ const QAImport = ({ kbId }: { kbId: string }) => {
const chunks = files.map((file) => file.chunks).flat();
// mark the file is used
await Promise.all(
files.map((file) =>
updateDatasetFile({
id: file.id,
datasetUsed: true
})
)
);
await putMarkFilesUsed({ fileIds: files.map((file) => file.id) });
// upload data
const { insertLen } = await chunksUpload({
@@ -80,6 +73,7 @@ const QAImport = ({ kbId }: { kbId: string }) => {
chunks,
mode: TrainingModeEnum.qa,
prompt: previewQAPrompt,
rate: 10,
onUploading: (insertLen) => {
setSuccessChunks(insertLen);
}

View File

@@ -1,9 +1,60 @@
import { connectToDatabase, Bill, User, OutLink } from '../mongo';
import { Bill, User, OutLink } from '@/service/mongo';
import { BillSourceEnum } from '@/constants/user';
import { getModel } from '../utils/data';
import { getModel } from '@/service/utils/data';
import { ChatHistoryItemResType } from '@/types/chat';
import { formatPrice } from '@/utils/user';
import { addLog } from '../utils/tools';
import { addLog } from '@/service/utils/tools';
import type { BillListItemType, CreateBillType } from '@/types/common/bill';
async function createBill(data: CreateBillType) {
try {
await Promise.all([
User.findByIdAndUpdate(data.userId, {
$inc: { balance: -data.total }
}),
Bill.create(data)
]);
} catch (error) {
addLog.error(`createBill error`, error);
}
}
async function concatBill({
billId,
total,
listIndex,
tokens = 0,
userId
}: {
billId?: string;
total: number;
listIndex?: number;
tokens?: number;
userId: string;
}) {
if (!billId) return;
try {
await Promise.all([
Bill.findOneAndUpdate(
{
_id: billId,
userId
},
{
$inc: {
total,
...(listIndex !== undefined && {
[`list.${listIndex}.amount`]: total,
[`list.${listIndex}.tokenLen`]: tokens
})
}
}
),
User.findByIdAndUpdate(userId, {
$inc: { balance: -total }
})
]);
} catch (error) {}
}
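
For orientation, a sketch (hypothetical totals and token counts) of how the two pushers below use this helper: listIndex maps onto the list pre-created by createTrainingBill, and amounts are accumulated with $inc instead of creating new Bill documents.

await concatBill({ billId, userId, total: vectorTotal, tokens: vectorTokenLen, listIndex: 0 }); // 索引生成 slot
await concatBill({ billId, userId, total: qaTotal, tokens: qaTotalTokens, listIndex: 1 });      // QA 拆分 slot
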
export const pushTaskBill = async ({
appName,
@@ -24,7 +75,7 @@ export const pushTaskBill = async ({
const total = response.reduce((sum, item) => sum + item.price, 0);
await Promise.allSettled([
Bill.create({
createBill({
userId,
appName,
appId,
@@ -37,9 +88,6 @@ export const pushTaskBill = async ({
tokenLen: item.tokens
}))
}),
User.findByIdAndUpdate(userId, {
$inc: { balance: -total }
}),
...(shareId
? [
updateShareChatBill({
@@ -83,71 +131,66 @@ export const updateShareChatBill = async ({
export const pushQABill = async ({
userId,
totalTokens,
appName
billId
}: {
userId: string;
totalTokens: number;
appName: string;
billId: string;
}) => {
addLog.info('splitData generate success', { totalTokens });
let billId;
try {
await connectToDatabase();
// 获取模型单价格, 都是用 gpt35 拆分
const unitPrice = global.qaModel.price || 3;
// 计算价格
const total = unitPrice * totalTokens;
// 插入 Bill 记录
const res = await Bill.create({
concatBill({
billId,
userId,
appName,
tokenLen: totalTokens,
total
});
billId = res._id;
// 账号扣费
await User.findByIdAndUpdate(userId, {
$inc: { balance: -total }
total,
tokens: totalTokens,
listIndex: 1
});
} catch (err) {
addLog.error('Create completions bill error', err);
billId && Bill.findByIdAndDelete(billId);
}
};
export const pushGenerateVectorBill = async ({
billId,
userId,
tokenLen,
model
}: {
billId?: string;
userId: string;
tokenLen: number;
model: string;
}) => {
let billId;
try {
await connectToDatabase();
// 计算价格. 至少为1
const vectorModel =
global.vectorModels.find((item) => item.model === model) || global.vectorModels[0];
const unitPrice = vectorModel.price || 0.2;
let total = unitPrice * tokenLen;
total = total > 1 ? total : 1;
try {
// 计算价格. 至少为1
const vectorModel =
global.vectorModels.find((item) => item.model === model) || global.vectorModels[0];
const unitPrice = vectorModel.price || 0.2;
let total = unitPrice * tokenLen;
total = total > 1 ? total : 1;
// 插入 Bill 记录
const res = await Bill.create({
// 插入 Bill 记录
if (billId) {
concatBill({
userId,
total,
billId,
tokens: tokenLen,
listIndex: 0
});
} else {
createBill({
userId,
model: vectorModel.model,
appName: '索引生成',
total,
source: BillSourceEnum.fastgpt,
list: [
{
moduleName: '索引生成',
@@ -157,18 +200,9 @@ export const pushGenerateVectorBill = async ({
}
]
});
billId = res._id;
// 账号扣费
await User.findByIdAndUpdate(userId, {
$inc: { balance: -total }
});
} catch (err) {
addLog.error('Create generateVector bill error', err);
billId && Bill.findByIdAndDelete(billId);
}
} catch (error) {
console.log(error);
} catch (err) {
addLog.error('Create generateVector bill error', err);
}
};

View File

@@ -1,6 +1,6 @@
import { Schema, model, models, Model } from 'mongoose';
import { BillSchema as BillType } from '@/types/mongoSchema';
import { BillSourceEnum, BillSourceMap } from '@/constants/user';
import { BillSchema as BillType } from '@/types/common/bill';
import { BillSourceMap } from '@/constants/user';
const BillSchema = new Schema({
userId: {
@@ -28,7 +28,7 @@ const BillSchema = new Schema({
source: {
type: String,
enum: Object.keys(BillSourceMap),
default: BillSourceEnum.fastgpt
required: true
},
list: {
type: Array,

View File

@@ -1,18 +1,16 @@
import { TrainingData } from '@/service/mongo';
import { pushQABill } from '@/service/events/pushBill';
import { pushDataToKb } from '@/pages/api/core/dataset/data/pushData';
import { pushQABill } from '@/service/common/bill/push';
import { TrainingModeEnum } from '@/constants/plugin';
import { ERROR_ENUM } from '../errorCode';
import { sendInform } from '@/pages/api/user/inform/send';
import { authBalanceByUid } from '../utils/auth';
import { axiosConfig, getAIChatApi } from '../lib/openai';
import { ChatCompletionRequestMessage } from 'openai';
import { gptMessage2ChatType } from '@/utils/adapt';
import { addLog } from '../utils/tools';
import { splitText2Chunks } from '@/utils/file';
import { countMessagesTokens } from '@/utils/common/tiktoken';
import { replaceVariable } from '@/utils/common/tools/text';
import { Prompt_AgentQA } from '@/prompts/core/agent';
import { pushDataToKb } from '@/pages/api/core/dataset/data/pushData';
const reduceQueue = () => {
global.qaQueueLen = global.qaQueueLen > 0 ? global.qaQueueLen - 1 : 0;
@@ -41,7 +39,8 @@ export async function generateQA(): Promise<any> {
prompt: 1,
q: 1,
source: 1,
file_id: 1
file_id: 1,
billId: 1
});
// task preemption
@@ -61,89 +60,67 @@ export async function generateQA(): Promise<any> {
const chatAPI = getAIChatApi();
// 请求 chatgpt 获取回答
const response = await Promise.all(
[data.q].map((text) => {
const messages: ChatCompletionRequestMessage[] = [
{
role: 'user',
content: data.prompt
? replaceVariable(data.prompt, { text })
: replaceVariable(Prompt_AgentQA.prompt, {
theme: Prompt_AgentQA.defaultTheme,
text
})
}
];
const modelTokenLimit = global.qaModel.maxToken || 16000;
const promptsToken = countMessagesTokens({
messages: gptMessage2ChatType(messages)
});
const maxToken = modelTokenLimit - promptsToken;
// request LLM to get QA
const text = data.q;
const messages: ChatCompletionRequestMessage[] = [
{
role: 'user',
content: data.prompt
? replaceVariable(data.prompt, { text })
: replaceVariable(Prompt_AgentQA.prompt, {
theme: Prompt_AgentQA.defaultTheme,
text
})
}
];
return chatAPI
.createChatCompletion(
{
model: global.qaModel.model,
temperature: 0.01,
messages,
stream: false,
max_tokens: maxToken
},
{
timeout: 480000,
...axiosConfig()
}
)
.then((res) => {
const answer = res.data.choices?.[0].message?.content;
const totalTokens = res.data.usage?.total_tokens || 0;
const result = formatSplitText(answer || ''); // 格式化后的QA对
console.log(`split result length: `, result.length);
// 计费
if (result.length > 0) {
pushQABill({
userId: data.userId,
totalTokens,
appName: 'QA 拆分'
});
} else {
addLog.info(`QA result 0:`, { answer });
}
return {
rawContent: answer,
result
};
})
.catch((err) => {
console.log('QA拆分错误');
console.log(err.response?.status, err.response?.statusText, err.response?.data);
return Promise.reject(err);
});
})
const { data: chatResponse } = await chatAPI.createChatCompletion(
{
model: global.qaModel.model,
temperature: 0.01,
messages,
stream: false
},
{
timeout: 480000,
...axiosConfig()
}
);
const answer = chatResponse.choices?.[0].message?.content;
const totalTokens = chatResponse.usage?.total_tokens || 0;
const responseList = response.map((item) => item.result).flat();
const qaArr = formatSplitText(answer || ''); // 格式化后的QA对
// 创建 向量生成 队列
// get vector and insert
await pushDataToKb({
kbId,
data: responseList.map((item) => ({
data: qaArr.map((item) => ({
...item,
source: data.source,
file_id: data.file_id
})),
userId,
mode: TrainingModeEnum.index
mode: TrainingModeEnum.index,
billId: data.billId
});
// delete data from training
await TrainingData.findByIdAndDelete(data._id);
console.log(`split result length: `, qaArr.length);
console.log('生成QA成功time:', `${(Date.now() - startTime) / 1000}s`);
// 计费
if (qaArr.length > 0) {
pushQABill({
userId: data.userId,
totalTokens,
billId: data.billId
});
} else {
addLog.info(`QA result 0:`, { answer });
}
reduceQueue();
generateQA();
} catch (err: any) {

View File

@@ -39,7 +39,8 @@ export async function generateVector(): Promise<any> {
a: 1,
source: 1,
file_id: 1,
vectorModel: 1
vectorModel: 1,
billId: 1
});
// task preemption
@@ -64,7 +65,8 @@ export async function generateVector(): Promise<any> {
const { vectors } = await getVector({
model: data.vectorModel,
input: dataItems.map((item) => item.q),
userId
userId,
billId: data.billId
});
// 生成结果插入到 pg

View File

@@ -53,6 +53,10 @@ const TrainingDataSchema = new Schema({
file_id: {
type: String,
default: ''
},
billId: {
type: String,
default: ''
}
});

View File

@@ -10,6 +10,7 @@ import { FlowModuleTypeEnum } from '@/constants/flow';
import { ModuleDispatchProps } from '@/types/core/modules';
import { replaceVariable } from '@/utils/common/tools/text';
import { Prompt_CQJson } from '@/prompts/core/agent';
import { defaultCQModel } from '@/pages/api/system/getInitData';
type Props = ModuleDispatchProps<{
systemPrompt?: string;
@@ -36,7 +37,7 @@ export const dispatchClassifyQuestion = async (props: Props): Promise<CQResponse
return Promise.reject('Input is empty');
}
const cqModel = global.cqModel;
const cqModel = global.cqModel || defaultCQModel;
const { arg, tokens } = await (async () => {
if (cqModel.functionCall) {
@@ -156,7 +157,7 @@ Human:${userChatInput}`
},
{
timeout: 480000,
...axiosConfig()
...axiosConfig(userOpenaiAccount)
}
);
const answer = data.choices?.[0].message?.content || '';

View File

@@ -9,6 +9,7 @@ import { FlowModuleTypeEnum } from '@/constants/flow';
import { ModuleDispatchProps } from '@/types/core/modules';
import { Prompt_ExtractJson } from '@/prompts/core/agent';
import { replaceVariable } from '@/utils/common/tools/text';
import { defaultExtractModel } from '@/pages/api/system/getInitData';
type Props = ModuleDispatchProps<{
history?: ChatItemType[];
@@ -36,7 +37,7 @@ export async function dispatchContentExtract(props: Props): Promise<Response> {
return Promise.reject('Input is empty');
}
const extractModel = global.extractModel;
const extractModel = global.extractModel || defaultExtractModel;
const { arg, tokens } = await (async () => {
if (extractModel.functionCall) {
@@ -191,7 +192,7 @@ Human: ${content}`
},
{
timeout: 480000,
...axiosConfig()
...axiosConfig(userOpenaiAccount)
}
);
const answer = data.choices?.[0].message?.content || '';

View File

@@ -8,7 +8,7 @@ import { textAdaptGptResponse } from '@/utils/adapt';
import { getAIChatApi, axiosConfig } from '@/service/lib/openai';
import { TaskResponseKeyEnum } from '@/constants/chat';
import { getChatModel } from '@/service/utils/data';
import { countModelPrice } from '@/service/events/pushBill';
import { countModelPrice } from '@/service/common/bill/push';
import { ChatModelItemType } from '@/types/model';
import { textCensor } from '@/api/service/plugins';
import { ChatCompletionRequestMessageRoleEnum } from 'openai';

View File

@@ -2,7 +2,7 @@ import { PgClient } from '@/service/pg';
import type { ChatHistoryItemResType } from '@/types/chat';
import { TaskResponseKeyEnum } from '@/constants/chat';
import { getVector } from '@/pages/api/openapi/plugin/vector';
import { countModelPrice } from '@/service/events/pushBill';
import { countModelPrice } from '@/service/common/bill/push';
import type { SelectedDatasetType } from '@/types/core/dataset';
import type { QuoteItemType } from '@/types/chat';
import { PgDatasetTableName } from '@/constants/plugin';

View File

@@ -130,7 +130,7 @@ export * from './models/chat';
export * from './models/chatItem';
export * from './models/app';
export * from './models/user';
export * from './models/bill';
export * from './common/bill/schema';
export * from './models/pay';
export * from './models/trainingData';
export * from './models/openapi';

View File

@@ -178,7 +178,7 @@ export const insertData2Dataset = ({
values: data.map((item) => [
{ key: 'user_id', value: userId },
{ key: 'kb_id', value: kbId },
{ key: 'source', value: item.source?.slice(0, 30)?.trim() || '' },
{ key: 'source', value: item.source?.slice(0, 60)?.trim() || '' },
{ key: 'file_id', value: item.file_id || '' },
{ key: 'q', value: item.q.replace(/'/g, '"') },
{ key: 'a', value: item.a.replace(/'/g, '"') },

client/src/types/common/bill.d.ts (new file)
View File

@@ -0,0 +1,23 @@
import { BillSourceEnum } from '@/constants/user';
import type { BillListItemType } from '@/types/common/bill';
export type BillListItemType = {
moduleName: string;
amount: number;
model?: string;
tokenLen?: number;
};
export type CreateBillType = {
userId: string;
appName: string;
appId?: string;
total: number;
source: `${BillSourceEnum}`;
list: BillListItemType[];
};
export type BillSchema = CreateBillType & {
_id: string;
time: Date;
};

View File

@@ -1,7 +1,7 @@
import type { ChatItemType } from './chat';
import { ModelNameEnum, ChatModelType, EmbeddingModelType } from '@/constants/model';
import type { DataType } from './data';
import { BillSourceEnum, InformTypeEnum } from '@/constants/user';
import { InformTypeEnum } from '@/constants/user';
import { TrainingModeEnum } from '@/constants/plugin';
import type { AppModuleItemType } from './app';
import { ChatSourceEnum } from '@/constants/chat';
@@ -70,6 +70,7 @@ export interface TrainingDataSchema {
a: string;
source: string;
file_id: string;
billId: string;
}
export interface ChatSchema {
@@ -102,23 +103,6 @@ export interface ChatItemSchema extends ChatItemType {
};
}
export type BillListItemType = {
moduleName: string;
amount: number;
model?: string;
tokenLen?: number;
};
export interface BillSchema {
_id: string;
userId: string;
appName: string;
appId?: string;
source: `${BillSourceEnum}`;
time: Date;
total: number;
list: BillListItemType[];
}
export interface PaySchema {
_id: string;
userId: string;

View File

@@ -1,5 +1,7 @@
import { BillSourceEnum } from '@/constants/user';
import type { BillSchema, UserModelSchema } from './mongoSchema';
import type { UserModelSchema } from './mongoSchema';
import type { BillSchema } from '@/types/common/bill';
export interface UserType {
_id: string;
username: string;

View File

@@ -1,5 +1,5 @@
import { formatPrice } from './user';
import type { BillSchema } from '../types/mongoSchema';
import type { BillSchema } from '@/types/common/bill';
import type { UserBillType } from '@/types/user';
import { ChatItemType } from '@/types/chat';
import { ChatCompletionRequestMessageRoleEnum } from 'openai';

View File

@@ -1,3 +1,4 @@
import { postCreateTrainingBill } from '@/api/common/bill';
import { postChunks2Dataset } from '@/api/core/dataset/data';
import { TrainingModeEnum } from '@/constants/plugin';
import type { DatasetDataItemType } from '@/types/core/dataset/data';
@@ -8,7 +9,7 @@ export async function chunksUpload({
mode,
chunks,
prompt,
rate = 200,
rate = 50,
onUploading
}: {
kbId: string;
@@ -18,12 +19,16 @@ export async function chunksUpload({
rate?: number;
onUploading?: (insertLen: number, total: number) => void;
}) {
// create training bill
const billId = await postCreateTrainingBill({ name: 'dataset.Training Name' });
async function upload(data: DatasetDataItemType[]) {
return postChunks2Dataset({
kbId,
data,
mode,
prompt
prompt,
billId
});
}
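
The batching loop itself falls outside this hunk; a minimal sketch, assuming the rest of chunksUpload slices chunks into requests of at most rate items and reports progress through onUploading:

// hypothetical continuation of chunksUpload (not part of this commit)
let insertLen = 0;
for (let i = 0; i < chunks.length; i += rate) {
  const { insertLen: len } = await upload(chunks.slice(i, i + rate));
  insertLen += len;
  onUploading?.(insertLen, chunks.length);
}
return { insertLen };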