mirror of
https://github.com/labring/FastGPT.git
synced 2025-07-23 21:13:50 +00:00
training queue
This commit is contained in:
@@ -1,9 +1,6 @@
|
||||
# proxy
|
||||
# AXIOS_PROXY_HOST=127.0.0.1
|
||||
# AXIOS_PROXY_PORT=7890
|
||||
# 是否开启队列任务。 1-开启,0-关闭(请求parentUrl去执行任务,单机时直接填1)
|
||||
queueTask=1
|
||||
parentUrl=https://hostname/api/openapi/startEvents
|
||||
# email
|
||||
MY_MAIL=xxx@qq.com
|
||||
MAILE_CODE=xxx
|
||||
@@ -21,7 +18,8 @@ SENSITIVE_CHECK=1
|
||||
# openai
|
||||
# OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
# OPENAI_BASE_URL_AUTH=可选的安全凭证(不需要的时候,记得去掉)
|
||||
OPENAIKEY=sk-xxx
|
||||
OPENAIKEY=sk-xxx # 对话用的key
|
||||
OPENAI_TRAINING_KEY=sk-xxx # 训练用的key
|
||||
GPT4KEY=sk-xxx
|
||||
# claude
|
||||
CLAUDE_BASE_URL=calude模型请求地址
|
||||
|
@@ -39,9 +39,6 @@ services:
|
||||
# proxy(可选)
|
||||
- AXIOS_PROXY_HOST=127.0.0.1
|
||||
- AXIOS_PROXY_PORT=7890
|
||||
# 是否开启队列任务。 1-开启,0-关闭(请求 parentUrl 去执行任务,单机时直接填1)
|
||||
- queueTask=1
|
||||
- parentUrl=https://hostname/api/openapi/startEvents
|
||||
# 发送邮箱验证码配置。用的是QQ邮箱。参考 nodeMail 获取MAILE_CODE,自行百度。
|
||||
- MY_MAIL=xxxx@qq.com
|
||||
- MAILE_CODE=xxxx
|
||||
@@ -66,7 +63,8 @@ services:
|
||||
- PG_PASSWORD=1234 # POSTGRES_PASSWORD
|
||||
- PG_DB_NAME=fastgpt # POSTGRES_DB
|
||||
# openai
|
||||
- OPENAIKEY=sk-xxxxx
|
||||
- OPENAIKEY=sk-xxxxx # 对话用的key
|
||||
- OPENAI_TRAINING_KEY=sk-xxx # 训练用的key
|
||||
- GPT4KEY=sk-xxx
|
||||
- OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
- OPENAI_BASE_URL_AUTH=可选的安全凭证
|
||||
|
@@ -36,7 +36,6 @@ mongo pg
|
||||
AXIOS_PROXY_HOST=127.0.0.1
|
||||
AXIOS_PROXY_PORT_FAST=7890
|
||||
AXIOS_PROXY_PORT_NORMAL=7890
|
||||
queueTask=1
|
||||
# email
|
||||
MY_MAIL= {Your Mail}
|
||||
MAILE_CODE={Yoir Mail code}
|
||||
@@ -48,7 +47,8 @@ aliTemplateCode=SMS_xxx
|
||||
# token
|
||||
TOKEN_KEY=sswada
|
||||
# openai
|
||||
OPENAIKEY={Your openapi key}
|
||||
OPENAIKEY=sk-xxx # 对话用的key
|
||||
OPENAI_TRAINING_KEY=sk-xxx # 训练用的key
|
||||
# db
|
||||
MONGODB_URI=mongodb://username:password@0.0.0.0:27017/test?authSource=admin
|
||||
PG_HOST=0.0.0.0
|
||||
|
@@ -10,9 +10,6 @@
|
||||
# proxy(可选)
|
||||
AXIOS_PROXY_HOST=127.0.0.1
|
||||
AXIOS_PROXY_PORT=7890
|
||||
# 是否开启队列任务。 1-开启,0-关闭(请求parentUrl去执行任务,单机时直接填1)
|
||||
queueTask=1
|
||||
parentUrl=https://hostname/api/openapi/startEvents
|
||||
# email
|
||||
MY_MAIL=xxx@qq.com
|
||||
MAILE_CODE=xxx
|
||||
@@ -30,7 +27,8 @@ SENSITIVE_CHECK=1
|
||||
# openai
|
||||
# OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
# OPENAI_BASE_URL_AUTH=可选的安全凭证(不需要的时候,记得去掉)
|
||||
OPENAIKEY=sk-xxx
|
||||
OPENAIKEY=sk-xxx # 对话用的key
|
||||
OPENAI_TRAINING_KEY=sk-xxx # 训练用的key
|
||||
GPT4KEY=sk-xxx
|
||||
# claude
|
||||
CLAUDE_BASE_URL=calude模型请求地址
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import { GET, POST, PUT, DELETE } from '../request';
|
||||
import type { KbItemType } from '@/types/plugin';
|
||||
import { RequestPaging } from '@/types/index';
|
||||
import { SplitTextTypEnum } from '@/constants/plugin';
|
||||
import { TrainingTypeEnum } from '@/constants/plugin';
|
||||
import { KbDataItemType } from '@/types/plugin';
|
||||
|
||||
export type KbUpdateParams = { id: string; name: string; tags: string; avatar: string };
|
||||
@@ -34,11 +34,11 @@ export const getExportDataList = (kbId: string) =>
|
||||
/**
|
||||
* 获取模型正在拆分数据的数量
|
||||
*/
|
||||
export const getTrainingData = (kbId: string) =>
|
||||
GET<{
|
||||
splitDataQueue: number;
|
||||
embeddingQueue: number;
|
||||
}>(`/plugins/kb/data/getTrainingData?kbId=${kbId}`);
|
||||
export const getTrainingData = (data: { kbId: string; init: boolean }) =>
|
||||
POST<{
|
||||
qaListLen: number;
|
||||
vectorListLen: number;
|
||||
}>(`/plugins/kb/data/getTrainingData`, data);
|
||||
|
||||
export const getKbDataItemById = (dataId: string) =>
|
||||
GET(`/plugins/kb/data/getDataById`, { dataId });
|
||||
@@ -69,5 +69,5 @@ export const postSplitData = (data: {
|
||||
kbId: string;
|
||||
chunks: string[];
|
||||
prompt: string;
|
||||
mode: `${SplitTextTypEnum}`;
|
||||
mode: `${TrainingTypeEnum}`;
|
||||
}) => POST(`/openapi/text/splitText`, data);
|
||||
|
@@ -108,27 +108,27 @@ export const ModelDataStatusMap: Record<`${ModelDataStatusEnum}`, string> = {
|
||||
|
||||
/* 知识库搜索时的配置 */
|
||||
// 搜索方式
|
||||
export enum ModelVectorSearchModeEnum {
|
||||
export enum appVectorSearchModeEnum {
|
||||
hightSimilarity = 'hightSimilarity', // 高相似度+禁止回复
|
||||
lowSimilarity = 'lowSimilarity', // 低相似度
|
||||
noContext = 'noContex' // 高相似度+无上下文回复
|
||||
}
|
||||
export const ModelVectorSearchModeMap: Record<
|
||||
`${ModelVectorSearchModeEnum}`,
|
||||
`${appVectorSearchModeEnum}`,
|
||||
{
|
||||
text: string;
|
||||
similarity: number;
|
||||
}
|
||||
> = {
|
||||
[ModelVectorSearchModeEnum.hightSimilarity]: {
|
||||
[appVectorSearchModeEnum.hightSimilarity]: {
|
||||
text: '高相似度, 无匹配时拒绝回复',
|
||||
similarity: 0.18
|
||||
},
|
||||
[ModelVectorSearchModeEnum.noContext]: {
|
||||
[appVectorSearchModeEnum.noContext]: {
|
||||
text: '高相似度,无匹配时直接回复',
|
||||
similarity: 0.18
|
||||
},
|
||||
[ModelVectorSearchModeEnum.lowSimilarity]: {
|
||||
[appVectorSearchModeEnum.lowSimilarity]: {
|
||||
text: '低相似度匹配',
|
||||
similarity: 0.7
|
||||
}
|
||||
@@ -143,7 +143,7 @@ export const defaultModel: ModelSchema = {
|
||||
updateTime: Date.now(),
|
||||
chat: {
|
||||
relatedKbs: [],
|
||||
searchMode: ModelVectorSearchModeEnum.hightSimilarity,
|
||||
searchMode: appVectorSearchModeEnum.hightSimilarity,
|
||||
systemPrompt: '',
|
||||
temperature: 0,
|
||||
chatModel: OpenAiChatEnum.GPT35
|
||||
|
@@ -1,14 +1,4 @@
|
||||
export enum SplitTextTypEnum {
|
||||
export enum TrainingTypeEnum {
|
||||
'qa' = 'qa',
|
||||
'subsection' = 'subsection'
|
||||
}
|
||||
|
||||
export enum PluginTypeEnum {
|
||||
LLM = 'LLM',
|
||||
Text = 'Text',
|
||||
Function = 'Function'
|
||||
}
|
||||
|
||||
export enum PluginParamsTypeEnum {
|
||||
'Text' = 'text'
|
||||
}
|
||||
|
@@ -5,7 +5,7 @@ import { PgClient } from '@/service/pg';
|
||||
import { withNextCors } from '@/service/utils/tools';
|
||||
import type { ChatItemSimpleType } from '@/types/chat';
|
||||
import type { ModelSchema } from '@/types/mongoSchema';
|
||||
import { ModelVectorSearchModeEnum } from '@/constants/model';
|
||||
import { appVectorSearchModeEnum } from '@/constants/model';
|
||||
import { authModel } from '@/service/utils/auth';
|
||||
import { ChatModelMap } from '@/constants/model';
|
||||
import { ChatRoleEnum } from '@/constants/chat';
|
||||
@@ -92,7 +92,8 @@ export async function appKbSearch({
|
||||
// get vector
|
||||
const promptVectors = await openaiEmbedding({
|
||||
userId,
|
||||
input
|
||||
input,
|
||||
type: 'chat'
|
||||
});
|
||||
|
||||
// search kb
|
||||
@@ -138,7 +139,7 @@ export async function appKbSearch({
|
||||
obj: ChatRoleEnum.System,
|
||||
value: model.chat.systemPrompt
|
||||
}
|
||||
: model.chat.searchMode === ModelVectorSearchModeEnum.noContext
|
||||
: model.chat.searchMode === appVectorSearchModeEnum.noContext
|
||||
? {
|
||||
obj: ChatRoleEnum.System,
|
||||
value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.`
|
||||
@@ -176,7 +177,7 @@ export async function appKbSearch({
|
||||
const systemPrompt = sliceResult.flat().join('\n').trim();
|
||||
|
||||
/* 高相似度+不回复 */
|
||||
if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity) {
|
||||
if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.hightSimilarity) {
|
||||
return {
|
||||
code: 201,
|
||||
rawSearch: [],
|
||||
@@ -190,7 +191,7 @@ export async function appKbSearch({
|
||||
};
|
||||
}
|
||||
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
|
||||
if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) {
|
||||
if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.noContext) {
|
||||
return {
|
||||
code: 200,
|
||||
rawSearch: [],
|
||||
|
@@ -1,84 +1,36 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import type { KbDataItemType } from '@/types/plugin';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { connectToDatabase, TrainingData } from '@/service/mongo';
|
||||
import { authUser } from '@/service/utils/auth';
|
||||
import { generateVector } from '@/service/events/generateVector';
|
||||
import { PgClient, insertKbItem } from '@/service/pg';
|
||||
import { PgClient } from '@/service/pg';
|
||||
import { authKb } from '@/service/utils/auth';
|
||||
import { withNextCors } from '@/service/utils/tools';
|
||||
|
||||
interface Props {
|
||||
kbId: string;
|
||||
data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[];
|
||||
}
|
||||
|
||||
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const {
|
||||
kbId,
|
||||
data,
|
||||
formatLineBreak = true
|
||||
} = req.body as {
|
||||
kbId: string;
|
||||
formatLineBreak?: boolean;
|
||||
data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[];
|
||||
};
|
||||
const { kbId, data } = req.body as Props;
|
||||
|
||||
if (!kbId || !Array.isArray(data)) {
|
||||
throw new Error('缺少参数');
|
||||
}
|
||||
|
||||
await connectToDatabase();
|
||||
|
||||
// 凭证校验
|
||||
const { userId } = await authUser({ req });
|
||||
|
||||
await authKb({
|
||||
userId,
|
||||
kbId
|
||||
});
|
||||
|
||||
// 过滤重复的内容
|
||||
const searchRes = await Promise.allSettled(
|
||||
data.map(async ({ q, a = '' }) => {
|
||||
if (!q) {
|
||||
return Promise.reject('q为空');
|
||||
}
|
||||
|
||||
if (formatLineBreak) {
|
||||
q = q.replace(/\\n/g, '\n');
|
||||
a = a.replace(/\\n/g, '\n');
|
||||
}
|
||||
|
||||
// Exactly the same data, not push
|
||||
try {
|
||||
const count = await PgClient.count('modelData', {
|
||||
where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]]
|
||||
});
|
||||
if (count > 0) {
|
||||
return Promise.reject('已经存在');
|
||||
}
|
||||
} catch (error) {
|
||||
error;
|
||||
}
|
||||
return Promise.resolve({
|
||||
q,
|
||||
a
|
||||
});
|
||||
})
|
||||
);
|
||||
const filterData = searchRes
|
||||
.filter((item) => item.status === 'fulfilled')
|
||||
.map<{ q: string; a: string }>((item: any) => item.value);
|
||||
|
||||
// 插入记录
|
||||
const insertRes = await insertKbItem({
|
||||
userId,
|
||||
kbId,
|
||||
data: filterData
|
||||
});
|
||||
|
||||
generateVector();
|
||||
|
||||
jsonRes(res, {
|
||||
message: `共插入 ${insertRes.rowCount} 条数据`,
|
||||
data: insertRes.rowCount
|
||||
data: await pushDataToKb({
|
||||
kbId,
|
||||
data,
|
||||
userId
|
||||
})
|
||||
});
|
||||
} catch (err) {
|
||||
jsonRes(res, {
|
||||
@@ -88,6 +40,32 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
|
||||
}
|
||||
});
|
||||
|
||||
export async function pushDataToKb({ userId, kbId, data }: { userId: string } & Props) {
|
||||
await authKb({
|
||||
userId,
|
||||
kbId
|
||||
});
|
||||
|
||||
if (data.length === 0) {
|
||||
return {
|
||||
trainingId: ''
|
||||
};
|
||||
}
|
||||
|
||||
// 插入记录
|
||||
const { _id } = await TrainingData.create({
|
||||
userId,
|
||||
kbId,
|
||||
vectorList: data
|
||||
});
|
||||
|
||||
generateVector(_id);
|
||||
|
||||
return {
|
||||
trainingId: _id
|
||||
};
|
||||
}
|
||||
|
||||
export const config = {
|
||||
api: {
|
||||
bodyParser: {
|
||||
|
@@ -5,10 +5,11 @@ import { ModelDataStatusEnum } from '@/constants/model';
|
||||
import { generateVector } from '@/service/events/generateVector';
|
||||
import { PgClient } from '@/service/pg';
|
||||
import { withNextCors } from '@/service/utils/tools';
|
||||
import { openaiEmbedding } from '../plugin/openaiEmbedding';
|
||||
|
||||
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const { dataId, a, q } = req.body as { dataId: string; a: string; q?: string };
|
||||
const { dataId, a = '', q = '' } = req.body as { dataId: string; a?: string; q?: string };
|
||||
|
||||
if (!dataId) {
|
||||
throw new Error('缺少参数');
|
||||
@@ -17,22 +18,24 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
|
||||
// 凭证校验
|
||||
const { userId } = await authUser({ req });
|
||||
|
||||
// get vector
|
||||
const vector = await (async () => {
|
||||
if (q) {
|
||||
return openaiEmbedding({
|
||||
userId,
|
||||
input: [q],
|
||||
type: 'chat'
|
||||
});
|
||||
}
|
||||
return [];
|
||||
})();
|
||||
|
||||
// 更新 pg 内容.仅修改a,不需要更新向量。
|
||||
await PgClient.update('modelData', {
|
||||
where: [['id', dataId], 'AND', ['user_id', userId]],
|
||||
values: [
|
||||
{ key: 'a', value: a },
|
||||
...(q
|
||||
? [
|
||||
{ key: 'q', value: q },
|
||||
{ key: 'status', value: ModelDataStatusEnum.waiting }
|
||||
]
|
||||
: [])
|
||||
]
|
||||
values: [{ key: 'a', value: a }, ...(q ? [{ key: 'q', value: `${vector[0]}` }] : [])]
|
||||
});
|
||||
|
||||
q && generateVector();
|
||||
|
||||
jsonRes(res);
|
||||
} catch (err) {
|
||||
jsonRes(res, {
|
||||
|
@@ -1,30 +1,31 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { authUser } from '@/service/utils/auth';
|
||||
import { PgClient } from '@/service/pg';
|
||||
import { withNextCors } from '@/service/utils/tools';
|
||||
import { getApiKey } from '@/service/utils/auth';
|
||||
import { getOpenAIApi } from '@/service/utils/chat/openai';
|
||||
import { embeddingModel } from '@/constants/model';
|
||||
import { axiosConfig } from '@/service/utils/tools';
|
||||
import { pushGenerateVectorBill } from '@/service/events/pushBill';
|
||||
import { ApiKeyType } from '@/service/utils/auth';
|
||||
|
||||
type Props = {
|
||||
input: string[];
|
||||
type?: ApiKeyType;
|
||||
};
|
||||
type Response = number[][];
|
||||
|
||||
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const { userId } = await authUser({ req });
|
||||
let { input } = req.query as Props;
|
||||
let { input, type } = req.query as Props;
|
||||
|
||||
if (!Array.isArray(input)) {
|
||||
throw new Error('缺少参数');
|
||||
}
|
||||
|
||||
jsonRes<Response>(res, {
|
||||
data: await openaiEmbedding({ userId, input, mustPay: true })
|
||||
data: await openaiEmbedding({ userId, input, mustPay: true, type })
|
||||
});
|
||||
} catch (err) {
|
||||
console.log(err);
|
||||
@@ -38,12 +39,14 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
|
||||
export async function openaiEmbedding({
|
||||
userId,
|
||||
input,
|
||||
mustPay = false
|
||||
mustPay = false,
|
||||
type = 'chat'
|
||||
}: { userId: string; mustPay?: boolean } & Props) {
|
||||
const { userOpenAiKey, systemAuthKey } = await getApiKey({
|
||||
model: 'gpt-3.5-turbo',
|
||||
userId,
|
||||
mustPay
|
||||
mustPay,
|
||||
type
|
||||
});
|
||||
|
||||
// 获取 chatAPI
|
||||
|
@@ -1,19 +0,0 @@
|
||||
// Next.js API route support: https://nextjs.org/docs/api-routes/introduction
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { generateQA } from '@/service/events/generateQA';
|
||||
import { generateVector } from '@/service/events/generateVector';
|
||||
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
try {
|
||||
generateQA();
|
||||
generateVector();
|
||||
|
||||
jsonRes(res);
|
||||
} catch (err) {
|
||||
jsonRes(res, {
|
||||
code: 500,
|
||||
error: err
|
||||
});
|
||||
}
|
||||
}
|
@@ -17,7 +17,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
const { input } = req.body as TextPluginRequestParams;
|
||||
|
||||
const response = await axios({
|
||||
...axiosConfig(getSystemOpenAiKey()),
|
||||
...axiosConfig(getSystemOpenAiKey('chat')),
|
||||
method: 'POST',
|
||||
url: `/moderations`,
|
||||
data: {
|
||||
|
@@ -1,12 +1,11 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase, SplitData } from '@/service/mongo';
|
||||
import { connectToDatabase, TrainingData } from '@/service/mongo';
|
||||
import { authKb, authUser } from '@/service/utils/auth';
|
||||
import { generateVector } from '@/service/events/generateVector';
|
||||
import { generateQA } from '@/service/events/generateQA';
|
||||
import { insertKbItem } from '@/service/pg';
|
||||
import { SplitTextTypEnum } from '@/constants/plugin';
|
||||
import { TrainingTypeEnum } from '@/constants/plugin';
|
||||
import { withNextCors } from '@/service/utils/tools';
|
||||
import { pushDataToKb } from '../kb/pushData';
|
||||
|
||||
/* split text */
|
||||
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
@@ -15,7 +14,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
|
||||
kbId: string;
|
||||
chunks: string[];
|
||||
prompt: string;
|
||||
mode: `${SplitTextTypEnum}`;
|
||||
mode: `${TrainingTypeEnum}`;
|
||||
};
|
||||
if (!chunks || !kbId || !prompt) {
|
||||
throw new Error('参数错误');
|
||||
@@ -30,29 +29,26 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
|
||||
userId
|
||||
});
|
||||
|
||||
if (mode === SplitTextTypEnum.qa) {
|
||||
if (mode === TrainingTypeEnum.qa) {
|
||||
// 批量QA拆分插入数据
|
||||
await SplitData.create({
|
||||
const { _id } = await TrainingData.create({
|
||||
userId,
|
||||
kbId,
|
||||
textList: chunks,
|
||||
qaList: chunks,
|
||||
prompt
|
||||
});
|
||||
|
||||
generateQA();
|
||||
} else if (mode === SplitTextTypEnum.subsection) {
|
||||
// 待优化,直接调用另一个接口
|
||||
// 插入记录
|
||||
await insertKbItem({
|
||||
userId,
|
||||
generateQA(_id);
|
||||
} else if (mode === TrainingTypeEnum.subsection) {
|
||||
// 分段导入,直接插入向量队列
|
||||
const response = await pushDataToKb({
|
||||
kbId,
|
||||
data: chunks.map((item) => ({
|
||||
q: item,
|
||||
a: ''
|
||||
}))
|
||||
data: chunks.map((item) => ({ q: item, a: '' })),
|
||||
userId
|
||||
});
|
||||
|
||||
generateVector();
|
||||
return jsonRes(res, {
|
||||
data: response
|
||||
});
|
||||
}
|
||||
|
||||
jsonRes(res);
|
||||
|
@@ -1,14 +1,15 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase, SplitData, Model } from '@/service/mongo';
|
||||
import { connectToDatabase, TrainingData } from '@/service/mongo';
|
||||
import { authUser } from '@/service/utils/auth';
|
||||
import { ModelDataStatusEnum } from '@/constants/model';
|
||||
import { PgClient } from '@/service/pg';
|
||||
import { Types } from 'mongoose';
|
||||
import { generateQA } from '@/service/events/generateQA';
|
||||
import { generateVector } from '@/service/events/generateVector';
|
||||
|
||||
/* 拆分数据成QA */
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||
try {
|
||||
const { kbId } = req.query as { kbId: string };
|
||||
const { kbId, init = false } = req.body as { kbId: string; init: boolean };
|
||||
if (!kbId) {
|
||||
throw new Error('参数错误');
|
||||
}
|
||||
@@ -17,29 +18,43 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
const { userId } = await authUser({ req, authToken: true });
|
||||
|
||||
// split queue data
|
||||
const data = await SplitData.find({
|
||||
userId,
|
||||
kbId,
|
||||
textList: { $exists: true, $not: { $size: 0 } }
|
||||
});
|
||||
|
||||
// embedding queue data
|
||||
const embeddingData = await PgClient.count('modelData', {
|
||||
where: [
|
||||
['user_id', userId],
|
||||
'AND',
|
||||
['kb_id', kbId],
|
||||
'AND',
|
||||
['status', ModelDataStatusEnum.waiting]
|
||||
]
|
||||
});
|
||||
const result = await TrainingData.aggregate([
|
||||
{ $match: { userId: new Types.ObjectId(userId), kbId: new Types.ObjectId(kbId) } },
|
||||
{
|
||||
$project: {
|
||||
qaListLength: { $size: { $ifNull: ['$qaList', []] } },
|
||||
vectorListLength: { $size: { $ifNull: ['$vectorList', []] } }
|
||||
}
|
||||
},
|
||||
{
|
||||
$group: {
|
||||
_id: null,
|
||||
totalQaListLength: { $sum: '$qaListLength' },
|
||||
totalVectorListLength: { $sum: '$vectorListLength' }
|
||||
}
|
||||
}
|
||||
]);
|
||||
|
||||
jsonRes(res, {
|
||||
data: {
|
||||
splitDataQueue: data.map((item) => item.textList).flat().length,
|
||||
embeddingQueue: embeddingData
|
||||
qaListLen: result[0]?.totalQaListLength || 0,
|
||||
vectorListLen: result[0]?.totalVectorListLength || 0
|
||||
}
|
||||
});
|
||||
|
||||
if (init) {
|
||||
const list = await TrainingData.find(
|
||||
{
|
||||
userId,
|
||||
kbId
|
||||
},
|
||||
'_id'
|
||||
);
|
||||
list.forEach((item) => {
|
||||
generateQA(item._id);
|
||||
generateVector(item._id);
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
jsonRes(res, {
|
||||
code: 500,
|
||||
|
@@ -91,9 +91,9 @@ const DataCard = ({ kbId }: { kbId: string }) => {
|
||||
onClose: onCloseSelectCsvModal
|
||||
} = useDisclosure();
|
||||
|
||||
const { data: { splitDataQueue = 0, embeddingQueue = 0 } = {}, refetch } = useQuery(
|
||||
const { data: { qaListLen = 0, vectorListLen = 0 } = {}, refetch } = useQuery(
|
||||
['getModelSplitDataList'],
|
||||
() => getTrainingData(kbId),
|
||||
() => getTrainingData({ kbId, init: false }),
|
||||
{
|
||||
onError(err) {
|
||||
console.log(err);
|
||||
@@ -113,7 +113,7 @@ const DataCard = ({ kbId }: { kbId: string }) => {
|
||||
// interval get data
|
||||
useQuery(['refetchData'], () => refetchData(pageNum), {
|
||||
refetchInterval: 5000,
|
||||
enabled: splitDataQueue > 0 || embeddingQueue > 0
|
||||
enabled: qaListLen > 0 || vectorListLen > 0
|
||||
});
|
||||
|
||||
// get al data and export csv
|
||||
@@ -161,7 +161,10 @@ const DataCard = ({ kbId }: { kbId: string }) => {
|
||||
variant={'outline'}
|
||||
mr={[2, 4]}
|
||||
size={'sm'}
|
||||
onClick={() => refetchData(pageNum)}
|
||||
onClick={() => {
|
||||
refetchData(pageNum);
|
||||
getTrainingData({ kbId, init: true });
|
||||
}}
|
||||
/>
|
||||
<Button
|
||||
variant={'outline'}
|
||||
@@ -194,10 +197,10 @@ const DataCard = ({ kbId }: { kbId: string }) => {
|
||||
</Menu>
|
||||
</Flex>
|
||||
<Flex mt={4}>
|
||||
{(splitDataQueue > 0 || embeddingQueue > 0) && (
|
||||
{(qaListLen > 0 || vectorListLen > 0) && (
|
||||
<Box fontSize={'xs'}>
|
||||
{splitDataQueue > 0 ? `${splitDataQueue}条数据正在拆分,` : ''}
|
||||
{embeddingQueue > 0 ? `${embeddingQueue}条数据正在生成索引,` : ''}
|
||||
{qaListLen > 0 ? `${qaListLen}条数据正在拆分,` : ''}
|
||||
{vectorListLen > 0 ? `${vectorListLen}条数据正在生成索引,` : ''}
|
||||
请耐心等待...
|
||||
</Box>
|
||||
)}
|
||||
|
@@ -20,7 +20,8 @@ import { useMutation } from '@tanstack/react-query';
|
||||
import { postSplitData } from '@/api/plugins/kb';
|
||||
import Radio from '@/components/Radio';
|
||||
import { splitText_token } from '@/utils/file';
|
||||
import { SplitTextTypEnum } from '@/constants/plugin';
|
||||
import { TrainingTypeEnum } from '@/constants/plugin';
|
||||
import { getErrText } from '@/utils/tools';
|
||||
|
||||
const fileExtension = '.txt,.doc,.docx,.pdf,.md';
|
||||
|
||||
@@ -52,7 +53,7 @@ const SelectFileModal = ({
|
||||
const { toast } = useToast();
|
||||
const [prompt, setPrompt] = useState('');
|
||||
const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
|
||||
const [mode, setMode] = useState<`${SplitTextTypEnum}`>(SplitTextTypEnum.subsection);
|
||||
const [mode, setMode] = useState<`${TrainingTypeEnum}`>(TrainingTypeEnum.subsection);
|
||||
const [fileTextArr, setFileTextArr] = useState<string[]>(['']);
|
||||
const [splitRes, setSplitRes] = useState<{ tokens: number; chunks: string[] }>({
|
||||
tokens: 0,
|
||||
@@ -113,8 +114,9 @@ const SelectFileModal = ({
|
||||
prompt: `下面是"${prompt || '一段长文本'}"`,
|
||||
mode
|
||||
});
|
||||
|
||||
toast({
|
||||
title: '导入数据成功,需要一段拆解和训练',
|
||||
title: '导入数据成功,需要一段拆解和训练. 重复数据会自动删除',
|
||||
status: 'success'
|
||||
});
|
||||
onClose();
|
||||
@@ -130,27 +132,35 @@ const SelectFileModal = ({
|
||||
|
||||
const onclickImport = useCallback(async () => {
|
||||
setBtnLoading(true);
|
||||
let promise = Promise.resolve();
|
||||
try {
|
||||
let promise = Promise.resolve();
|
||||
|
||||
const splitRes = fileTextArr
|
||||
.filter((item) => item)
|
||||
.map((item) =>
|
||||
splitText_token({
|
||||
text: item,
|
||||
...modeMap[mode]
|
||||
})
|
||||
const splitRes = await Promise.all(
|
||||
fileTextArr
|
||||
.filter((item) => item)
|
||||
.map((item) =>
|
||||
splitText_token({
|
||||
text: item,
|
||||
...modeMap[mode]
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
setSplitRes({
|
||||
tokens: splitRes.reduce((sum, item) => sum + item.tokens, 0),
|
||||
chunks: splitRes.map((item) => item.chunks).flat()
|
||||
});
|
||||
setSplitRes({
|
||||
tokens: splitRes.reduce((sum, item) => sum + item.tokens, 0),
|
||||
chunks: splitRes.map((item) => item.chunks).flat()
|
||||
});
|
||||
|
||||
await promise;
|
||||
openConfirm(mutate)();
|
||||
} catch (error) {
|
||||
toast({
|
||||
status: 'warning',
|
||||
title: getErrText(error, '拆分文本异常')
|
||||
});
|
||||
}
|
||||
setBtnLoading(false);
|
||||
|
||||
await promise;
|
||||
openConfirm(mutate)();
|
||||
}, [fileTextArr, mode, mutate, openConfirm]);
|
||||
}, [fileTextArr, mode, mutate, openConfirm, toast]);
|
||||
|
||||
return (
|
||||
<Modal isOpen={true} onClose={onClose} isCentered>
|
||||
|
@@ -53,10 +53,11 @@ function responseError(err: any) {
|
||||
}
|
||||
|
||||
/* 创建请求实例 */
|
||||
const instance = axios.create({
|
||||
export const instance = axios.create({
|
||||
timeout: 60000, // 超时时间
|
||||
baseURL: `http://localhost:${process.env.PORT || 3000}/api`,
|
||||
headers: {
|
||||
'content-type': 'application/json'
|
||||
rootkey: process.env.ROOT_KEY
|
||||
}
|
||||
});
|
||||
|
||||
@@ -75,7 +76,6 @@ function request(url: string, data: any, config: ConfigType, method: Method): an
|
||||
|
||||
return instance
|
||||
.request({
|
||||
baseURL: `http://localhost:${process.env.PORT || 3000}/api`,
|
||||
url,
|
||||
method,
|
||||
data: method === 'GET' ? null : data,
|
||||
@@ -93,18 +93,30 @@ function request(url: string, data: any, config: ConfigType, method: Method): an
|
||||
* @param {Object} config
|
||||
* @returns
|
||||
*/
|
||||
export function GET<T>(url: string, params = {}, config: ConfigType = {}): Promise<T> {
|
||||
export function GET<T = { data: any }>(
|
||||
url: string,
|
||||
params = {},
|
||||
config: ConfigType = {}
|
||||
): Promise<T> {
|
||||
return request(url, params, config, 'GET');
|
||||
}
|
||||
|
||||
export function POST<T>(url: string, data = {}, config: ConfigType = {}): Promise<T> {
|
||||
export function POST<T = { data: any }>(
|
||||
url: string,
|
||||
data = {},
|
||||
config: ConfigType = {}
|
||||
): Promise<T> {
|
||||
return request(url, data, config, 'POST');
|
||||
}
|
||||
|
||||
export function PUT<T>(url: string, data = {}, config: ConfigType = {}): Promise<T> {
|
||||
export function PUT<T = { data: any }>(
|
||||
url: string,
|
||||
data = {},
|
||||
config: ConfigType = {}
|
||||
): Promise<T> {
|
||||
return request(url, data, config, 'PUT');
|
||||
}
|
||||
|
||||
export function DELETE<T>(url: string, config: ConfigType = {}): Promise<T> {
|
||||
export function DELETE<T = { data: any }>(url: string, config: ConfigType = {}): Promise<T> {
|
||||
return request(url, {}, config, 'DELETE');
|
||||
}
|
||||
|
@@ -1,75 +1,55 @@
|
||||
import { SplitData } from '@/service/mongo';
|
||||
import { TrainingData } from '@/service/mongo';
|
||||
import { getApiKey } from '../utils/auth';
|
||||
import { OpenAiChatEnum } from '@/constants/model';
|
||||
import { pushSplitDataBill } from '@/service/events/pushBill';
|
||||
import { generateVector } from './generateVector';
|
||||
import { openaiError2 } from '../errorCode';
|
||||
import { insertKbItem } from '@/service/pg';
|
||||
import { SplitDataSchema } from '@/types/mongoSchema';
|
||||
import { modelServiceToolMap } from '../utils/chat';
|
||||
import { ChatRoleEnum } from '@/constants/chat';
|
||||
import { getErrText } from '@/utils/tools';
|
||||
import { BillTypeEnum } from '@/constants/user';
|
||||
import { pushDataToKb } from '@/pages/api/openapi/kb/pushData';
|
||||
import { ERROR_ENUM } from '../errorCode';
|
||||
|
||||
export async function generateQA(next = false): Promise<any> {
|
||||
if (process.env.queueTask !== '1') {
|
||||
try {
|
||||
fetch(process.env.parentUrl || '');
|
||||
} catch (error) {
|
||||
console.log('parentUrl fetch error', error);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (global.generatingQA === true && !next) return;
|
||||
|
||||
global.generatingQA = true;
|
||||
|
||||
let dataId = null;
|
||||
// 每次最多选 1 组
|
||||
const listLen = 1;
|
||||
|
||||
export async function generateQA(trainingId: string): Promise<any> {
|
||||
try {
|
||||
// 找出一个需要生成的 dataItem
|
||||
const data = await SplitData.aggregate([
|
||||
{ $match: { textList: { $exists: true, $ne: [] } } },
|
||||
{ $sample: { size: 1 } }
|
||||
]);
|
||||
// 找出一个需要生成的 dataItem (4分钟锁)
|
||||
const data = await TrainingData.findOneAndUpdate(
|
||||
{
|
||||
_id: trainingId,
|
||||
lockTime: { $lte: Date.now() - 4 * 60 * 1000 }
|
||||
},
|
||||
{
|
||||
lockTime: new Date()
|
||||
}
|
||||
);
|
||||
|
||||
const dataItem: SplitDataSchema = data[0];
|
||||
|
||||
if (!dataItem) {
|
||||
console.log('没有需要生成 QA 的数据');
|
||||
global.generatingQA = false;
|
||||
return;
|
||||
}
|
||||
|
||||
dataId = dataItem._id;
|
||||
|
||||
// 获取 5 个源文本
|
||||
const textList: string[] = dataItem.textList.slice(-5);
|
||||
|
||||
// 获取 openapi Key
|
||||
let userOpenAiKey = '',
|
||||
systemAuthKey = '';
|
||||
try {
|
||||
const key = await getApiKey({ model: OpenAiChatEnum.GPT35, userId: dataItem.userId });
|
||||
userOpenAiKey = key.userOpenAiKey;
|
||||
systemAuthKey = key.systemAuthKey;
|
||||
} catch (err: any) {
|
||||
// 余额不够了, 清空该记录
|
||||
await SplitData.findByIdAndUpdate(dataItem._id, {
|
||||
textList: [],
|
||||
errorText: getErrText(err, '获取 OpenAi Key 失败')
|
||||
if (!data || data.qaList.length === 0) {
|
||||
await TrainingData.findOneAndDelete({
|
||||
_id: trainingId,
|
||||
qaList: [],
|
||||
vectorList: []
|
||||
});
|
||||
generateQA(true);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`正在生成一组QA, 包含 ${textList.length} 组文本。ID: ${dataItem._id}`);
|
||||
const qaList: string[] = data.qaList.slice(-listLen);
|
||||
|
||||
// 余额校验并获取 openapi Key
|
||||
const { userOpenAiKey, systemAuthKey } = await getApiKey({
|
||||
model: OpenAiChatEnum.GPT35,
|
||||
userId: data.userId,
|
||||
type: 'training'
|
||||
});
|
||||
|
||||
console.log(`正在生成一组QA, 包含 ${qaList.length} 组文本。ID: ${data._id}`);
|
||||
|
||||
const startTime = Date.now();
|
||||
|
||||
// 请求 chatgpt 获取回答
|
||||
const response = await Promise.allSettled(
|
||||
textList.map((text) =>
|
||||
const response = await Promise.all(
|
||||
qaList.map((text) =>
|
||||
modelServiceToolMap[OpenAiChatEnum.GPT35]
|
||||
.chatCompletion({
|
||||
apiKey: userOpenAiKey || systemAuthKey,
|
||||
@@ -78,7 +58,7 @@ export async function generateQA(next = false): Promise<any> {
|
||||
{
|
||||
obj: ChatRoleEnum.System,
|
||||
value: `你是出题人
|
||||
${dataItem.prompt || '下面是"一段长文本"'}
|
||||
${data.prompt || '下面是"一段长文本"'}
|
||||
从中选出5至20个题目和答案.答案详细.按格式返回: Q1:
|
||||
A1:
|
||||
Q2:
|
||||
@@ -98,7 +78,7 @@ A2:
|
||||
// 计费
|
||||
pushSplitDataBill({
|
||||
isPay: !userOpenAiKey && result.length > 0,
|
||||
userId: dataItem.userId,
|
||||
userId: data.userId,
|
||||
type: BillTypeEnum.QA,
|
||||
textLen: responseMessages.map((item) => item.value).join('').length,
|
||||
totalTokens
|
||||
@@ -116,57 +96,59 @@ A2:
|
||||
)
|
||||
);
|
||||
|
||||
// 获取成功的回答
|
||||
const successResponse: {
|
||||
rawContent: string;
|
||||
result: {
|
||||
q: string;
|
||||
a: string;
|
||||
}[];
|
||||
}[] = response.filter((item) => item.status === 'fulfilled').map((item: any) => item.value);
|
||||
const responseList = response.map((item) => item.result).flat();
|
||||
|
||||
const resultList = successResponse.map((item) => item.result).flat();
|
||||
// 创建 向量生成 队列
|
||||
pushDataToKb({
|
||||
kbId: data.kbId,
|
||||
data: responseList,
|
||||
userId: data.userId
|
||||
});
|
||||
|
||||
await Promise.allSettled([
|
||||
// 删掉后5个数据
|
||||
SplitData.findByIdAndUpdate(dataItem._id, {
|
||||
textList: dataItem.textList.slice(0, -5)
|
||||
}),
|
||||
// 生成的内容插入 pg
|
||||
insertKbItem({
|
||||
userId: dataItem.userId,
|
||||
kbId: dataItem.kbId,
|
||||
data: resultList
|
||||
})
|
||||
]);
|
||||
console.log('生成QA成功,time:', `${(Date.now() - startTime) / 1000}s`);
|
||||
|
||||
generateQA(true);
|
||||
generateVector();
|
||||
} catch (error: any) {
|
||||
// log
|
||||
if (error?.response) {
|
||||
console.log('openai error: 生成QA错误');
|
||||
console.log(error.response?.status, error.response?.statusText, error.response?.data);
|
||||
// 删除 QA 队列。如果小于 n 条,整个数据删掉。 如果大于 n 条,仅删数组后 n 个
|
||||
if (data.vectorList.length <= listLen) {
|
||||
await TrainingData.findByIdAndDelete(data._id);
|
||||
} else {
|
||||
console.log('生成QA错误:', error);
|
||||
await TrainingData.findByIdAndUpdate(data._id, {
|
||||
qaList: data.qaList.slice(0, -listLen),
|
||||
lockTime: new Date('2000/1/1')
|
||||
});
|
||||
}
|
||||
|
||||
// 没有余额或者凭证错误时,拒绝任务
|
||||
if (dataId && openaiError2[error?.response?.data?.error?.type]) {
|
||||
console.log(openaiError2[error?.response?.data?.error?.type], '删除QA任务');
|
||||
console.log('生成QA成功,time:', `${(Date.now() - startTime) / 1000}s`);
|
||||
|
||||
await SplitData.findByIdAndUpdate(dataId, {
|
||||
textList: [],
|
||||
errorText: 'api 余额不足'
|
||||
});
|
||||
generateQA(trainingId);
|
||||
} catch (err: any) {
|
||||
// log
|
||||
if (err?.response) {
|
||||
console.log('openai error: 生成QA错误');
|
||||
console.log(err.response?.status, err.response?.statusText, err.response?.data);
|
||||
} else {
|
||||
console.log('生成QA错误:', err);
|
||||
}
|
||||
|
||||
generateQA(true);
|
||||
// openai 账号异常或者账号余额不足,删除任务
|
||||
if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) {
|
||||
console.log('余额不足,删除向量生成任务');
|
||||
await TrainingData.findByIdAndDelete(trainingId);
|
||||
return;
|
||||
}
|
||||
|
||||
// unlock
|
||||
await TrainingData.findByIdAndUpdate(trainingId, {
|
||||
lockTime: new Date('2000/1/1')
|
||||
});
|
||||
|
||||
// 频率限制
|
||||
if (err?.response?.statusText === 'Too Many Requests') {
|
||||
console.log('生成向量次数限制,30s后尝试');
|
||||
return setTimeout(() => {
|
||||
generateQA(trainingId);
|
||||
}, 30000);
|
||||
}
|
||||
|
||||
setTimeout(() => {
|
||||
generateQA(true);
|
||||
generateQA(trainingId);
|
||||
}, 1000);
|
||||
}
|
||||
}
|
||||
|
@@ -1,107 +1,137 @@
|
||||
import { getApiKey } from '../utils/auth';
|
||||
import { openaiError2 } from '../errorCode';
|
||||
import { PgClient } from '@/service/pg';
|
||||
import { getErrText } from '@/utils/tools';
|
||||
import { insertKbItem, PgClient } from '@/service/pg';
|
||||
import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding';
|
||||
import { TrainingData } from '../models/trainingData';
|
||||
import { ERROR_ENUM } from '../errorCode';
|
||||
|
||||
export async function generateVector(next = false): Promise<any> {
|
||||
if (process.env.queueTask !== '1') {
|
||||
try {
|
||||
fetch(process.env.parentUrl || '');
|
||||
} catch (error) {
|
||||
console.log('parentUrl fetch error', error);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (global.generatingVector && !next) return;
|
||||
|
||||
global.generatingVector = true;
|
||||
let dataId = null;
|
||||
// 每次最多选 5 组
|
||||
const listLen = 5;
|
||||
|
||||
/* 索引生成队列。每导入一次,就是一个单独的线程 */
|
||||
export async function generateVector(trainingId: string): Promise<any> {
|
||||
try {
|
||||
// 从找出一个 status = waiting 的数据
|
||||
const searchRes = await PgClient.select('modelData', {
|
||||
fields: ['id', 'q', 'user_id'],
|
||||
where: [['status', 'waiting']],
|
||||
limit: 1
|
||||
});
|
||||
|
||||
if (searchRes.rowCount === 0) {
|
||||
console.log('没有需要生成 【向量】 的数据');
|
||||
global.generatingVector = false;
|
||||
return;
|
||||
}
|
||||
|
||||
const dataItem: { id: string; q: string; userId: string } = {
|
||||
id: searchRes.rows[0].id,
|
||||
q: searchRes.rows[0].q,
|
||||
userId: searchRes.rows[0].user_id
|
||||
};
|
||||
|
||||
dataId = dataItem.id;
|
||||
|
||||
// 获取 openapi Key
|
||||
try {
|
||||
await getApiKey({ model: 'gpt-3.5-turbo', userId: dataItem.userId });
|
||||
} catch (err: any) {
|
||||
await PgClient.delete('modelData', {
|
||||
where: [['id', dataId]]
|
||||
});
|
||||
getErrText(err, '获取 OpenAi Key 失败');
|
||||
return generateVector(true);
|
||||
}
|
||||
|
||||
// 生成词向量
|
||||
const vectors = await openaiEmbedding({
|
||||
input: [dataItem.q],
|
||||
userId: dataItem.userId
|
||||
});
|
||||
|
||||
// 更新 pg 向量和状态数据
|
||||
await PgClient.update('modelData', {
|
||||
values: [
|
||||
{ key: 'vector', value: `[${vectors[0]}]` },
|
||||
{ key: 'status', value: `ready` }
|
||||
],
|
||||
where: [['id', dataId]]
|
||||
});
|
||||
|
||||
console.log(`生成向量成功: ${dataItem.id}`);
|
||||
|
||||
generateVector(true);
|
||||
} catch (error: any) {
|
||||
// log
|
||||
if (error?.response) {
|
||||
console.log('openai error: 生成向量错误');
|
||||
console.log(error.response?.status, error.response?.statusText, error.response?.data);
|
||||
} else {
|
||||
console.log('生成向量错误:', error);
|
||||
}
|
||||
|
||||
// 没有余额或者凭证错误时,拒绝任务
|
||||
if (dataId && openaiError2[error?.response?.data?.error?.type]) {
|
||||
console.log('删除向量生成任务记录');
|
||||
try {
|
||||
await PgClient.delete('modelData', {
|
||||
where: [['id', dataId]]
|
||||
});
|
||||
} catch (error) {
|
||||
error;
|
||||
// 找出一个需要生成的 dataItem (2分钟锁)
|
||||
const data = await TrainingData.findOneAndUpdate(
|
||||
{
|
||||
_id: trainingId,
|
||||
lockTime: { $lte: Date.now() - 2 * 60 * 1000 }
|
||||
},
|
||||
{
|
||||
lockTime: new Date()
|
||||
}
|
||||
generateVector(true);
|
||||
);
|
||||
|
||||
if (!data) {
|
||||
await TrainingData.findOneAndDelete({
|
||||
_id: trainingId,
|
||||
qaList: [],
|
||||
vectorList: []
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (error?.response?.statusText === 'Too Many Requests') {
|
||||
console.log('生成向量次数限制,1分钟后尝试');
|
||||
// 限制次数,1分钟后再试
|
||||
setTimeout(() => {
|
||||
generateVector(true);
|
||||
}, 60000);
|
||||
|
||||
const userId = String(data.userId);
|
||||
const kbId = String(data.kbId);
|
||||
|
||||
const dataItems: { q: string; a: string }[] = data.vectorList.slice(-listLen).map((item) => ({
|
||||
q: item.q,
|
||||
a: item.a
|
||||
}));
|
||||
|
||||
// 过滤重复的 qa 内容
|
||||
const searchRes = await Promise.allSettled(
|
||||
dataItems.map(async ({ q, a = '' }) => {
|
||||
if (!q) {
|
||||
return Promise.reject('q为空');
|
||||
}
|
||||
|
||||
q = q.replace(/\\n/g, '\n');
|
||||
a = a.replace(/\\n/g, '\n');
|
||||
|
||||
// Exactly the same data, not push
|
||||
try {
|
||||
const count = await PgClient.count('modelData', {
|
||||
where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]]
|
||||
});
|
||||
if (count > 0) {
|
||||
return Promise.reject('已经存在');
|
||||
}
|
||||
} catch (error) {
|
||||
error;
|
||||
}
|
||||
return Promise.resolve({
|
||||
q,
|
||||
a
|
||||
});
|
||||
})
|
||||
);
|
||||
const filterData = searchRes
|
||||
.filter((item) => item.status === 'fulfilled')
|
||||
.map<{ q: string; a: string }>((item: any) => item.value);
|
||||
|
||||
if (filterData.length > 0) {
|
||||
// 生成词向量
|
||||
const vectors = await openaiEmbedding({
|
||||
input: filterData.map((item) => item.q),
|
||||
userId,
|
||||
type: 'training'
|
||||
});
|
||||
|
||||
// 生成结果插入到 pg
|
||||
await insertKbItem({
|
||||
userId,
|
||||
kbId,
|
||||
data: vectors.map((vector, i) => ({
|
||||
q: filterData[i].q,
|
||||
a: filterData[i].a,
|
||||
vector
|
||||
}))
|
||||
});
|
||||
}
|
||||
|
||||
// 删除 mongo 训练队列. 如果小于 n 条,整个数据删掉。 如果大于 n 条,仅删数组后 n 个
|
||||
if (data.vectorList.length <= listLen) {
|
||||
await TrainingData.findByIdAndDelete(trainingId);
|
||||
console.log(`全部向量生成完毕: ${trainingId}`);
|
||||
} else {
|
||||
await TrainingData.findByIdAndUpdate(trainingId, {
|
||||
vectorList: data.vectorList.slice(0, -listLen),
|
||||
lockTime: new Date('2000/1/1')
|
||||
});
|
||||
console.log(`生成向量成功: ${trainingId}`);
|
||||
generateVector(trainingId);
|
||||
}
|
||||
} catch (err: any) {
|
||||
// log
|
||||
if (err?.response) {
|
||||
console.log('openai error: 生成向量错误');
|
||||
console.log(err.response?.status, err.response?.statusText, err.response?.data);
|
||||
} else {
|
||||
console.log('生成向量错误:', err);
|
||||
}
|
||||
|
||||
// openai 账号异常或者账号余额不足,删除任务
|
||||
if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) {
|
||||
console.log('余额不足,删除向量生成任务');
|
||||
await TrainingData.findByIdAndDelete(trainingId);
|
||||
return;
|
||||
}
|
||||
|
||||
// unlock
|
||||
await TrainingData.findByIdAndUpdate(trainingId, {
|
||||
lockTime: new Date('2000/1/1')
|
||||
});
|
||||
|
||||
// 频率限制
|
||||
if (err?.response?.statusText === 'Too Many Requests') {
|
||||
console.log('生成向量次数限制,30s后尝试');
|
||||
return setTimeout(() => {
|
||||
generateVector(trainingId);
|
||||
}, 30000);
|
||||
}
|
||||
|
||||
setTimeout(() => {
|
||||
generateVector(true);
|
||||
generateVector(trainingId);
|
||||
}, 1000);
|
||||
}
|
||||
}
|
||||
|
@@ -2,7 +2,7 @@ import { Schema, model, models, Model as MongoModel } from 'mongoose';
|
||||
import { ModelSchema as ModelType } from '@/types/mongoSchema';
|
||||
import {
|
||||
ModelVectorSearchModeMap,
|
||||
ModelVectorSearchModeEnum,
|
||||
appVectorSearchModeEnum,
|
||||
ChatModelMap,
|
||||
OpenAiChatEnum
|
||||
} from '@/constants/model';
|
||||
@@ -40,7 +40,7 @@ const ModelSchema = new Schema({
|
||||
// knowledge base search mode
|
||||
type: String,
|
||||
enum: Object.keys(ModelVectorSearchModeMap),
|
||||
default: ModelVectorSearchModeEnum.hightSimilarity
|
||||
default: appVectorSearchModeEnum.hightSimilarity
|
||||
},
|
||||
systemPrompt: {
|
||||
// 系统提示词
|
||||
|
@@ -1,32 +0,0 @@
|
||||
/* 模型的知识库 */
|
||||
import { Schema, model, models, Model as MongoModel } from 'mongoose';
|
||||
import { SplitDataSchema as SplitDataType } from '@/types/mongoSchema';
|
||||
|
||||
const SplitDataSchema = new Schema({
|
||||
userId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'user',
|
||||
required: true
|
||||
},
|
||||
prompt: {
|
||||
// 拆分时的提示词
|
||||
type: String,
|
||||
required: true
|
||||
},
|
||||
kbId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'kb',
|
||||
required: true
|
||||
},
|
||||
textList: {
|
||||
type: [String],
|
||||
default: []
|
||||
},
|
||||
errorText: {
|
||||
type: String,
|
||||
default: ''
|
||||
}
|
||||
});
|
||||
|
||||
export const SplitData: MongoModel<SplitDataType> =
|
||||
models['splitData'] || model('splitData', SplitDataSchema);
|
38
src/service/models/trainingData.ts
Normal file
38
src/service/models/trainingData.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
/* 模型的知识库 */
|
||||
import { Schema, model, models, Model as MongoModel } from 'mongoose';
|
||||
import { TrainingDataSchema as TrainingDateType } from '@/types/mongoSchema';
|
||||
|
||||
// pgList and vectorList, Only one of them will work
|
||||
|
||||
const TrainingDataSchema = new Schema({
|
||||
userId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'user',
|
||||
required: true
|
||||
},
|
||||
kbId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'kb',
|
||||
required: true
|
||||
},
|
||||
lockTime: {
|
||||
type: Date,
|
||||
default: () => new Date('2000/1/1')
|
||||
},
|
||||
vectorList: {
|
||||
type: [{ q: String, a: String }],
|
||||
default: []
|
||||
},
|
||||
prompt: {
|
||||
// 拆分时的提示词
|
||||
type: String,
|
||||
default: ''
|
||||
},
|
||||
qaList: {
|
||||
type: [String],
|
||||
default: []
|
||||
}
|
||||
});
|
||||
|
||||
export const TrainingData: MongoModel<TrainingDateType> =
|
||||
models['trainingData'] || model('trainingData', TrainingDataSchema);
|
@@ -2,6 +2,7 @@ import mongoose from 'mongoose';
|
||||
import { generateQA } from './events/generateQA';
|
||||
import { generateVector } from './events/generateVector';
|
||||
import tunnel from 'tunnel';
|
||||
import { TrainingData } from './mongo';
|
||||
|
||||
/**
|
||||
* 连接 MongoDB 数据库
|
||||
@@ -27,9 +28,6 @@ export async function connectToDatabase(): Promise<void> {
|
||||
global.mongodb = null;
|
||||
}
|
||||
|
||||
generateQA();
|
||||
generateVector();
|
||||
|
||||
// 创建代理对象
|
||||
if (process.env.AXIOS_PROXY_HOST && process.env.AXIOS_PROXY_PORT) {
|
||||
global.httpsAgent = tunnel.httpsOverHttp({
|
||||
@@ -39,6 +37,34 @@ export async function connectToDatabase(): Promise<void> {
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
startTrain();
|
||||
// 5 分钟后解锁不正常的数据,并触发开始训练
|
||||
setTimeout(async () => {
|
||||
await TrainingData.updateMany(
|
||||
{
|
||||
lockTime: { $lte: Date.now() - 5 * 60 * 1000 }
|
||||
},
|
||||
{
|
||||
lockTime: new Date('2000/1/1')
|
||||
}
|
||||
);
|
||||
startTrain();
|
||||
}, 5 * 60 * 1000);
|
||||
}
|
||||
|
||||
async function startTrain() {
|
||||
const qa = await TrainingData.find({
|
||||
qaList: { $exists: true, $ne: [] }
|
||||
});
|
||||
|
||||
qa.map((item) => generateQA(String(item._id)));
|
||||
|
||||
const vector = await TrainingData.find({
|
||||
vectorList: { $exists: true, $ne: [] }
|
||||
});
|
||||
|
||||
vector.map((item) => generateVector(String(item._id)));
|
||||
}
|
||||
|
||||
export * from './models/authCode';
|
||||
@@ -47,7 +73,7 @@ export * from './models/model';
|
||||
export * from './models/user';
|
||||
export * from './models/bill';
|
||||
export * from './models/pay';
|
||||
export * from './models/splitData';
|
||||
export * from './models/trainingData';
|
||||
export * from './models/openapi';
|
||||
export * from './models/promotionRecord';
|
||||
export * from './models/collection';
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import { Pool } from 'pg';
|
||||
import type { QueryResultRow } from 'pg';
|
||||
import { ModelDataStatusEnum } from '@/constants/model';
|
||||
|
||||
export const connectPg = async () => {
|
||||
if (global.pgClient) {
|
||||
@@ -168,6 +169,7 @@ export const insertKbItem = ({
|
||||
userId: string;
|
||||
kbId: string;
|
||||
data: {
|
||||
vector: number[];
|
||||
q: string;
|
||||
a: string;
|
||||
}[];
|
||||
@@ -178,7 +180,8 @@ export const insertKbItem = ({
|
||||
{ key: 'kb_id', value: kbId },
|
||||
{ key: 'q', value: item.q },
|
||||
{ key: 'a', value: item.a },
|
||||
{ key: 'status', value: 'waiting' }
|
||||
{ key: 'vector', value: `[${item.vector}]` },
|
||||
{ key: 'status', value: ModelDataStatusEnum.ready }
|
||||
])
|
||||
});
|
||||
};
|
||||
|
@@ -5,12 +5,14 @@ import { Chat, Model, OpenApi, User, ShareChat, KB } from '../mongo';
|
||||
import type { ModelSchema } from '@/types/mongoSchema';
|
||||
import type { ChatItemSimpleType } from '@/types/chat';
|
||||
import mongoose from 'mongoose';
|
||||
import { ClaudeEnum, defaultModel } from '@/constants/model';
|
||||
import { ClaudeEnum, defaultModel, embeddingModel, EmbeddingModelType } from '@/constants/model';
|
||||
import { formatPrice } from '@/utils/user';
|
||||
import { ERROR_ENUM } from '../errorCode';
|
||||
import { ChatModelType, OpenAiChatEnum } from '@/constants/model';
|
||||
import { hashPassword } from '@/service/utils/tools';
|
||||
|
||||
export type ApiKeyType = 'training' | 'chat';
|
||||
|
||||
export const parseCookie = (cookie?: string): Promise<string> => {
|
||||
return new Promise((resolve, reject) => {
|
||||
// 获取 cookie
|
||||
@@ -118,9 +120,15 @@ export const authUser = async ({
|
||||
};
|
||||
|
||||
/* random get openai api key */
|
||||
export const getSystemOpenAiKey = () => {
|
||||
export const getSystemOpenAiKey = (type: ApiKeyType) => {
|
||||
const keys = (() => {
|
||||
if (type === 'training') {
|
||||
return process.env.OPENAI_TRAINING_KEY?.split(',') || [];
|
||||
}
|
||||
return process.env.OPENAIKEY?.split(',') || [];
|
||||
})();
|
||||
|
||||
// 纯字符串类型
|
||||
const keys = process.env.OPENAIKEY?.split(',') || [];
|
||||
const i = Math.floor(Math.random() * keys.length);
|
||||
return keys[i] || (process.env.OPENAIKEY as string);
|
||||
};
|
||||
@@ -129,11 +137,13 @@ export const getSystemOpenAiKey = () => {
|
||||
export const getApiKey = async ({
|
||||
model,
|
||||
userId,
|
||||
mustPay = false
|
||||
mustPay = false,
|
||||
type = 'chat'
|
||||
}: {
|
||||
model: ChatModelType;
|
||||
userId: string;
|
||||
mustPay?: boolean;
|
||||
type?: ApiKeyType;
|
||||
}) => {
|
||||
const user = await User.findById(userId);
|
||||
if (!user) {
|
||||
@@ -143,7 +153,7 @@ export const getApiKey = async ({
|
||||
const keyMap = {
|
||||
[OpenAiChatEnum.GPT35]: {
|
||||
userOpenAiKey: user.openaiKey || '',
|
||||
systemAuthKey: getSystemOpenAiKey() as string
|
||||
systemAuthKey: getSystemOpenAiKey(type) as string
|
||||
},
|
||||
[OpenAiChatEnum.GPT4]: {
|
||||
userOpenAiKey: user.openaiKey || '',
|
||||
|
2
src/types/index.d.ts
vendored
2
src/types/index.d.ts
vendored
@@ -5,8 +5,6 @@ import type { Pool } from 'pg';
|
||||
declare global {
|
||||
var mongodb: Mongoose | string | null;
|
||||
var pgClient: Pool | null;
|
||||
var generatingQA: boolean;
|
||||
var generatingVector: boolean;
|
||||
var httpsAgent: Agent;
|
||||
var particlesJS: any;
|
||||
var grecaptcha: any;
|
||||
|
2
src/types/model.d.ts
vendored
2
src/types/model.d.ts
vendored
@@ -1,6 +1,6 @@
|
||||
import { ModelStatusEnum } from '@/constants/model';
|
||||
import type { ModelSchema, kbSchema } from './mongoSchema';
|
||||
import { ChatModelType, ModelVectorSearchModeEnum } from '@/constants/model';
|
||||
import { ChatModelType, appVectorSearchModeEnum } from '@/constants/model';
|
||||
|
||||
export type ModelListItemType = {
|
||||
_id: string;
|
||||
|
12
src/types/mongoSchema.d.ts
vendored
12
src/types/mongoSchema.d.ts
vendored
@@ -2,12 +2,13 @@ import type { ChatItemType } from './chat';
|
||||
import {
|
||||
ModelStatusEnum,
|
||||
ModelNameEnum,
|
||||
ModelVectorSearchModeEnum,
|
||||
appVectorSearchModeEnum,
|
||||
ChatModelType,
|
||||
EmbeddingModelType
|
||||
} from '@/constants/model';
|
||||
import type { DataType } from './data';
|
||||
import { BillTypeEnum } from '@/constants/user';
|
||||
import { TrainingTypeEnum } from '@/constants/plugin';
|
||||
|
||||
export interface UserModelSchema {
|
||||
_id: string;
|
||||
@@ -44,7 +45,7 @@ export interface ModelSchema {
|
||||
updateTime: number;
|
||||
chat: {
|
||||
relatedKbs: string[];
|
||||
searchMode: `${ModelVectorSearchModeEnum}`;
|
||||
searchMode: `${appVectorSearchModeEnum}`;
|
||||
systemPrompt: string;
|
||||
temperature: number;
|
||||
chatModel: ChatModelType; // 聊天时用的模型,训练后就是训练的模型
|
||||
@@ -68,13 +69,14 @@ export interface CollectionSchema {
|
||||
|
||||
export type ModelDataType = 0 | 1;
|
||||
|
||||
export interface SplitDataSchema {
|
||||
export interface TrainingDataSchema {
|
||||
_id: string;
|
||||
userId: string;
|
||||
kbId: string;
|
||||
lockTime: Date;
|
||||
vectorList: { q: string; a: string }[];
|
||||
prompt: string;
|
||||
errorText: string;
|
||||
textList: string[];
|
||||
qaList: string[];
|
||||
}
|
||||
|
||||
export interface ChatSchema {
|
||||
|
15
src/types/plugin.d.ts
vendored
15
src/types/plugin.d.ts
vendored
@@ -1,5 +1,4 @@
|
||||
import type { kbSchema } from './mongoSchema';
|
||||
import { PluginTypeEnum } from '@/constants/plugin';
|
||||
|
||||
/* kb type */
|
||||
export interface KbItemType extends kbSchema {
|
||||
@@ -16,20 +15,6 @@ export interface KbDataItemType {
|
||||
userId: string;
|
||||
}
|
||||
|
||||
/* plugin */
|
||||
export interface PluginConfig {
|
||||
name: string;
|
||||
desc: string;
|
||||
url: string;
|
||||
category: `${PluginTypeEnum}`;
|
||||
uniPrice: 22; // 1k token
|
||||
params: [
|
||||
{
|
||||
type: '';
|
||||
}
|
||||
];
|
||||
}
|
||||
|
||||
export type TextPluginRequestParams = {
|
||||
input: string;
|
||||
};
|
||||
|
@@ -145,7 +145,7 @@ export const fileDownload = ({
|
||||
* slideLen - The size of the before and after Text
|
||||
* maxLen > slideLen
|
||||
*/
|
||||
export const splitText_token = ({
|
||||
export const splitText_token = async ({
|
||||
text,
|
||||
maxLen,
|
||||
slideLen
|
||||
@@ -154,32 +154,39 @@ export const splitText_token = ({
|
||||
maxLen: number;
|
||||
slideLen: number;
|
||||
}) => {
|
||||
const enc = getOpenAiEncMap()['gpt-3.5-turbo'];
|
||||
// filter empty text. encode sentence
|
||||
const encodeText = enc.encode(text);
|
||||
try {
|
||||
const enc = getOpenAiEncMap()['gpt-3.5-turbo'];
|
||||
// filter empty text. encode sentence
|
||||
const encodeText = enc.encode(text);
|
||||
|
||||
const chunks: string[] = [];
|
||||
let tokens = 0;
|
||||
const chunks: string[] = [];
|
||||
let tokens = 0;
|
||||
|
||||
let startIndex = 0;
|
||||
let endIndex = Math.min(startIndex + maxLen, encodeText.length);
|
||||
let chunkEncodeArr = encodeText.slice(startIndex, endIndex);
|
||||
let startIndex = 0;
|
||||
let endIndex = Math.min(startIndex + maxLen, encodeText.length);
|
||||
let chunkEncodeArr = encodeText.slice(startIndex, endIndex);
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
while (startIndex < encodeText.length) {
|
||||
tokens += chunkEncodeArr.length;
|
||||
chunks.push(decoder.decode(enc.decode(chunkEncodeArr)));
|
||||
while (startIndex < encodeText.length) {
|
||||
tokens += chunkEncodeArr.length;
|
||||
chunks.push(decoder.decode(enc.decode(chunkEncodeArr)));
|
||||
|
||||
startIndex += maxLen - slideLen;
|
||||
endIndex = Math.min(startIndex + maxLen, encodeText.length);
|
||||
chunkEncodeArr = encodeText.slice(Math.min(encodeText.length - slideLen, startIndex), endIndex);
|
||||
startIndex += maxLen - slideLen;
|
||||
endIndex = Math.min(startIndex + maxLen, encodeText.length);
|
||||
chunkEncodeArr = encodeText.slice(
|
||||
Math.min(encodeText.length - slideLen, startIndex),
|
||||
endIndex
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
chunks,
|
||||
tokens
|
||||
};
|
||||
} catch (error) {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
|
||||
return {
|
||||
chunks,
|
||||
tokens
|
||||
};
|
||||
};
|
||||
|
||||
export const fileToBase64 = (file: File) => {
|
||||
|
@@ -26,7 +26,7 @@ export const authGoogleToken = async (data: {
|
||||
const res = await axios.post<{ score?: number }>(
|
||||
`https://www.recaptcha.net/recaptcha/api/siteverify?${Obj2Query(data)}`
|
||||
);
|
||||
if (res.data.score && res.data.score >= 0.9) {
|
||||
if (res.data.score && res.data.score >= 0.5) {
|
||||
return Promise.resolve('');
|
||||
}
|
||||
return Promise.reject('非法环境');
|
||||
|
Reference in New Issue
Block a user