training queue

This commit is contained in:
archer
2023-05-26 23:08:25 +08:00
parent 69f32a0861
commit dc1c1d1355
32 changed files with 528 additions and 493 deletions

View File

@@ -1,9 +1,6 @@
# proxy # proxy
# AXIOS_PROXY_HOST=127.0.0.1 # AXIOS_PROXY_HOST=127.0.0.1
# AXIOS_PROXY_PORT=7890 # AXIOS_PROXY_PORT=7890
# 是否开启队列任务。 1-开启0-关闭请求parentUrl去执行任务,单机时直接填1
queueTask=1
parentUrl=https://hostname/api/openapi/startEvents
# email # email
MY_MAIL=xxx@qq.com MY_MAIL=xxx@qq.com
MAILE_CODE=xxx MAILE_CODE=xxx
@@ -21,7 +18,8 @@ SENSITIVE_CHECK=1
# openai # openai
# OPENAI_BASE_URL=https://api.openai.com/v1 # OPENAI_BASE_URL=https://api.openai.com/v1
# OPENAI_BASE_URL_AUTH=可选的安全凭证(不需要的时候,记得去掉) # OPENAI_BASE_URL_AUTH=可选的安全凭证(不需要的时候,记得去掉)
OPENAIKEY=sk-xxx OPENAIKEY=sk-xxx # 对话用的key
OPENAI_TRAINING_KEY=sk-xxx # 训练用的key
GPT4KEY=sk-xxx GPT4KEY=sk-xxx
# claude # claude
CLAUDE_BASE_URL=claude模型请求地址 CLAUDE_BASE_URL=claude模型请求地址

View File

@@ -39,9 +39,6 @@ services:
# proxy可选 # proxy可选
- AXIOS_PROXY_HOST=127.0.0.1 - AXIOS_PROXY_HOST=127.0.0.1
- AXIOS_PROXY_PORT=7890 - AXIOS_PROXY_PORT=7890
# 是否开启队列任务。 1-开启0-关闭(请求 parentUrl 去执行任务,单机时直接填1
- queueTask=1
- parentUrl=https://hostname/api/openapi/startEvents
# 发送邮箱验证码配置。用的是QQ邮箱。参考 nodeMail 获取MAILE_CODE自行百度。 # 发送邮箱验证码配置。用的是QQ邮箱。参考 nodeMail 获取MAILE_CODE自行百度。
- MY_MAIL=xxxx@qq.com - MY_MAIL=xxxx@qq.com
- MAILE_CODE=xxxx - MAILE_CODE=xxxx
@@ -66,7 +63,8 @@ services:
- PG_PASSWORD=1234 # POSTGRES_PASSWORD - PG_PASSWORD=1234 # POSTGRES_PASSWORD
- PG_DB_NAME=fastgpt # POSTGRES_DB - PG_DB_NAME=fastgpt # POSTGRES_DB
# openai # openai
- OPENAIKEY=sk-xxxxx - OPENAIKEY=sk-xxxxx # 对话用的key
- OPENAI_TRAINING_KEY=sk-xxx # 训练用的key
- GPT4KEY=sk-xxx - GPT4KEY=sk-xxx
- OPENAI_BASE_URL=https://api.openai.com/v1 - OPENAI_BASE_URL=https://api.openai.com/v1
- OPENAI_BASE_URL_AUTH=可选的安全凭证 - OPENAI_BASE_URL_AUTH=可选的安全凭证

View File

@@ -36,7 +36,6 @@ mongo pg
AXIOS_PROXY_HOST=127.0.0.1 AXIOS_PROXY_HOST=127.0.0.1
AXIOS_PROXY_PORT_FAST=7890 AXIOS_PROXY_PORT_FAST=7890
AXIOS_PROXY_PORT_NORMAL=7890 AXIOS_PROXY_PORT_NORMAL=7890
queueTask=1
# email # email
MY_MAIL= {Your Mail} MY_MAIL= {Your Mail}
MAILE_CODE={Your Mail code} MAILE_CODE={Your Mail code}
@@ -48,7 +47,8 @@ aliTemplateCode=SMS_xxx
# token # token
TOKEN_KEY=sswada TOKEN_KEY=sswada
# openai # openai
OPENAIKEY={Your openapi key} OPENAIKEY=sk-xxx # 对话用的key
OPENAI_TRAINING_KEY=sk-xxx # 训练用的key
# db # db
MONGODB_URI=mongodb://username:password@0.0.0.0:27017/test?authSource=admin MONGODB_URI=mongodb://username:password@0.0.0.0:27017/test?authSource=admin
PG_HOST=0.0.0.0 PG_HOST=0.0.0.0

View File

@@ -10,9 +10,6 @@
# proxy可选 # proxy可选
AXIOS_PROXY_HOST=127.0.0.1 AXIOS_PROXY_HOST=127.0.0.1
AXIOS_PROXY_PORT=7890 AXIOS_PROXY_PORT=7890
# 是否开启队列任务。 1-开启0-关闭请求parentUrl去执行任务,单机时直接填1
queueTask=1
parentUrl=https://hostname/api/openapi/startEvents
# email # email
MY_MAIL=xxx@qq.com MY_MAIL=xxx@qq.com
MAILE_CODE=xxx MAILE_CODE=xxx
@@ -30,7 +27,8 @@ SENSITIVE_CHECK=1
# openai # openai
# OPENAI_BASE_URL=https://api.openai.com/v1 # OPENAI_BASE_URL=https://api.openai.com/v1
# OPENAI_BASE_URL_AUTH=可选的安全凭证(不需要的时候,记得去掉) # OPENAI_BASE_URL_AUTH=可选的安全凭证(不需要的时候,记得去掉)
OPENAIKEY=sk-xxx OPENAIKEY=sk-xxx # 对话用的key
OPENAI_TRAINING_KEY=sk-xxx # 训练用的key
GPT4KEY=sk-xxx GPT4KEY=sk-xxx
# claude # claude
CLAUDE_BASE_URL=claude模型请求地址 CLAUDE_BASE_URL=claude模型请求地址

View File

@@ -1,7 +1,7 @@
import { GET, POST, PUT, DELETE } from '../request'; import { GET, POST, PUT, DELETE } from '../request';
import type { KbItemType } from '@/types/plugin'; import type { KbItemType } from '@/types/plugin';
import { RequestPaging } from '@/types/index'; import { RequestPaging } from '@/types/index';
import { SplitTextTypEnum } from '@/constants/plugin'; import { TrainingTypeEnum } from '@/constants/plugin';
import { KbDataItemType } from '@/types/plugin'; import { KbDataItemType } from '@/types/plugin';
export type KbUpdateParams = { id: string; name: string; tags: string; avatar: string }; export type KbUpdateParams = { id: string; name: string; tags: string; avatar: string };
@@ -34,11 +34,11 @@ export const getExportDataList = (kbId: string) =>
/** /**
* 获取模型正在拆分数据的数量 * 获取模型正在拆分数据的数量
*/ */
export const getTrainingData = (kbId: string) => export const getTrainingData = (data: { kbId: string; init: boolean }) =>
GET<{ POST<{
splitDataQueue: number; qaListLen: number;
embeddingQueue: number; vectorListLen: number;
}>(`/plugins/kb/data/getTrainingData?kbId=${kbId}`); }>(`/plugins/kb/data/getTrainingData`, data);
export const getKbDataItemById = (dataId: string) => export const getKbDataItemById = (dataId: string) =>
GET(`/plugins/kb/data/getDataById`, { dataId }); GET(`/plugins/kb/data/getDataById`, { dataId });
@@ -69,5 +69,5 @@ export const postSplitData = (data: {
kbId: string; kbId: string;
chunks: string[]; chunks: string[];
prompt: string; prompt: string;
mode: `${SplitTextTypEnum}`; mode: `${TrainingTypeEnum}`;
}) => POST(`/openapi/text/splitText`, data); }) => POST(`/openapi/text/splitText`, data);

View File

@@ -108,27 +108,27 @@ export const ModelDataStatusMap: Record<`${ModelDataStatusEnum}`, string> = {
/* 知识库搜索时的配置 */ /* 知识库搜索时的配置 */
// 搜索方式 // 搜索方式
export enum ModelVectorSearchModeEnum { export enum appVectorSearchModeEnum {
hightSimilarity = 'hightSimilarity', // 高相似度+禁止回复 hightSimilarity = 'hightSimilarity', // 高相似度+禁止回复
lowSimilarity = 'lowSimilarity', // 低相似度 lowSimilarity = 'lowSimilarity', // 低相似度
noContext = 'noContex' // 高相似度+无上下文回复 noContext = 'noContex' // 高相似度+无上下文回复
} }
export const ModelVectorSearchModeMap: Record< export const ModelVectorSearchModeMap: Record<
`${ModelVectorSearchModeEnum}`, `${appVectorSearchModeEnum}`,
{ {
text: string; text: string;
similarity: number; similarity: number;
} }
> = { > = {
[ModelVectorSearchModeEnum.hightSimilarity]: { [appVectorSearchModeEnum.hightSimilarity]: {
text: '高相似度, 无匹配时拒绝回复', text: '高相似度, 无匹配时拒绝回复',
similarity: 0.18 similarity: 0.18
}, },
[ModelVectorSearchModeEnum.noContext]: { [appVectorSearchModeEnum.noContext]: {
text: '高相似度,无匹配时直接回复', text: '高相似度,无匹配时直接回复',
similarity: 0.18 similarity: 0.18
}, },
[ModelVectorSearchModeEnum.lowSimilarity]: { [appVectorSearchModeEnum.lowSimilarity]: {
text: '低相似度匹配', text: '低相似度匹配',
similarity: 0.7 similarity: 0.7
} }
@@ -143,7 +143,7 @@ export const defaultModel: ModelSchema = {
updateTime: Date.now(), updateTime: Date.now(),
chat: { chat: {
relatedKbs: [], relatedKbs: [],
searchMode: ModelVectorSearchModeEnum.hightSimilarity, searchMode: appVectorSearchModeEnum.hightSimilarity,
systemPrompt: '', systemPrompt: '',
temperature: 0, temperature: 0,
chatModel: OpenAiChatEnum.GPT35 chatModel: OpenAiChatEnum.GPT35

View File

@@ -1,14 +1,4 @@
export enum SplitTextTypEnum { export enum TrainingTypeEnum {
'qa' = 'qa', 'qa' = 'qa',
'subsection' = 'subsection' 'subsection' = 'subsection'
} }
export enum PluginTypeEnum {
LLM = 'LLM',
Text = 'Text',
Function = 'Function'
}
export enum PluginParamsTypeEnum {
'Text' = 'text'
}

View File

@@ -5,7 +5,7 @@ import { PgClient } from '@/service/pg';
import { withNextCors } from '@/service/utils/tools'; import { withNextCors } from '@/service/utils/tools';
import type { ChatItemSimpleType } from '@/types/chat'; import type { ChatItemSimpleType } from '@/types/chat';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { ModelVectorSearchModeEnum } from '@/constants/model'; import { appVectorSearchModeEnum } from '@/constants/model';
import { authModel } from '@/service/utils/auth'; import { authModel } from '@/service/utils/auth';
import { ChatModelMap } from '@/constants/model'; import { ChatModelMap } from '@/constants/model';
import { ChatRoleEnum } from '@/constants/chat'; import { ChatRoleEnum } from '@/constants/chat';
@@ -92,7 +92,8 @@ export async function appKbSearch({
// get vector // get vector
const promptVectors = await openaiEmbedding({ const promptVectors = await openaiEmbedding({
userId, userId,
input input,
type: 'chat'
}); });
// search kb // search kb
@@ -138,7 +139,7 @@ export async function appKbSearch({
obj: ChatRoleEnum.System, obj: ChatRoleEnum.System,
value: model.chat.systemPrompt value: model.chat.systemPrompt
} }
: model.chat.searchMode === ModelVectorSearchModeEnum.noContext : model.chat.searchMode === appVectorSearchModeEnum.noContext
? { ? {
obj: ChatRoleEnum.System, obj: ChatRoleEnum.System,
value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.` value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.`
@@ -176,7 +177,7 @@ export async function appKbSearch({
const systemPrompt = sliceResult.flat().join('\n').trim(); const systemPrompt = sliceResult.flat().join('\n').trim();
/* 高相似度+不回复 */ /* 高相似度+不回复 */
if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity) { if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.hightSimilarity) {
return { return {
code: 201, code: 201,
rawSearch: [], rawSearch: [],
@@ -190,7 +191,7 @@ export async function appKbSearch({
}; };
} }
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */ /* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) { if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.noContext) {
return { return {
code: 200, code: 200,
rawSearch: [], rawSearch: [],

View File

@@ -1,84 +1,36 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import type { KbDataItemType } from '@/types/plugin'; import type { KbDataItemType } from '@/types/plugin';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo'; import { connectToDatabase, TrainingData } from '@/service/mongo';
import { authUser } from '@/service/utils/auth'; import { authUser } from '@/service/utils/auth';
import { generateVector } from '@/service/events/generateVector'; import { generateVector } from '@/service/events/generateVector';
import { PgClient, insertKbItem } from '@/service/pg'; import { PgClient } from '@/service/pg';
import { authKb } from '@/service/utils/auth'; import { authKb } from '@/service/utils/auth';
import { withNextCors } from '@/service/utils/tools'; import { withNextCors } from '@/service/utils/tools';
interface Props {
kbId: string;
data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[];
}
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
const { const { kbId, data } = req.body as Props;
kbId,
data,
formatLineBreak = true
} = req.body as {
kbId: string;
formatLineBreak?: boolean;
data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[];
};
if (!kbId || !Array.isArray(data)) { if (!kbId || !Array.isArray(data)) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
await connectToDatabase(); await connectToDatabase();
// 凭证校验 // 凭证校验
const { userId } = await authUser({ req }); const { userId } = await authUser({ req });
await authKb({
userId,
kbId
});
// 过滤重复的内容
const searchRes = await Promise.allSettled(
data.map(async ({ q, a = '' }) => {
if (!q) {
return Promise.reject('q为空');
}
if (formatLineBreak) {
q = q.replace(/\\n/g, '\n');
a = a.replace(/\\n/g, '\n');
}
// Exactly the same data, not push
try {
const count = await PgClient.count('modelData', {
where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]]
});
if (count > 0) {
return Promise.reject('已经存在');
}
} catch (error) {
error;
}
return Promise.resolve({
q,
a
});
})
);
const filterData = searchRes
.filter((item) => item.status === 'fulfilled')
.map<{ q: string; a: string }>((item: any) => item.value);
// 插入记录
const insertRes = await insertKbItem({
userId,
kbId,
data: filterData
});
generateVector();
jsonRes(res, { jsonRes(res, {
message: `共插入 ${insertRes.rowCount} 条数据`, data: await pushDataToKb({
data: insertRes.rowCount kbId,
data,
userId
})
}); });
} catch (err) { } catch (err) {
jsonRes(res, { jsonRes(res, {
@@ -88,6 +40,32 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
} }
}); });
export async function pushDataToKb({ userId, kbId, data }: { userId: string } & Props) {
await authKb({
userId,
kbId
});
if (data.length === 0) {
return {
trainingId: ''
};
}
// 插入记录
const { _id } = await TrainingData.create({
userId,
kbId,
vectorList: data
});
generateVector(_id);
return {
trainingId: _id
};
}
export const config = { export const config = {
api: { api: {
bodyParser: { bodyParser: {

View File

@@ -5,10 +5,11 @@ import { ModelDataStatusEnum } from '@/constants/model';
import { generateVector } from '@/service/events/generateVector'; import { generateVector } from '@/service/events/generateVector';
import { PgClient } from '@/service/pg'; import { PgClient } from '@/service/pg';
import { withNextCors } from '@/service/utils/tools'; import { withNextCors } from '@/service/utils/tools';
import { openaiEmbedding } from '../plugin/openaiEmbedding';
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
const { dataId, a, q } = req.body as { dataId: string; a: string; q?: string }; const { dataId, a = '', q = '' } = req.body as { dataId: string; a?: string; q?: string };
if (!dataId) { if (!dataId) {
throw new Error('缺少参数'); throw new Error('缺少参数');
@@ -17,22 +18,24 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
// 凭证校验 // 凭证校验
const { userId } = await authUser({ req }); const { userId } = await authUser({ req });
// get vector
const vector = await (async () => {
if (q) {
return openaiEmbedding({
userId,
input: [q],
type: 'chat'
});
}
return [];
})();
// 更新 pg 内容.仅修改a不需要更新向量。 // 更新 pg 内容.仅修改a不需要更新向量。
await PgClient.update('modelData', { await PgClient.update('modelData', {
where: [['id', dataId], 'AND', ['user_id', userId]], where: [['id', dataId], 'AND', ['user_id', userId]],
values: [ values: [{ key: 'a', value: a }, ...(q ? [{ key: 'q', value: `${vector[0]}` }] : [])]
{ key: 'a', value: a },
...(q
? [
{ key: 'q', value: q },
{ key: 'status', value: ModelDataStatusEnum.waiting }
]
: [])
]
}); });
q && generateVector();
jsonRes(res); jsonRes(res);
} catch (err) { } catch (err) {
jsonRes(res, { jsonRes(res, {

View File

@@ -1,30 +1,31 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { authUser } from '@/service/utils/auth'; import { authUser } from '@/service/utils/auth';
import { PgClient } from '@/service/pg';
import { withNextCors } from '@/service/utils/tools'; import { withNextCors } from '@/service/utils/tools';
import { getApiKey } from '@/service/utils/auth'; import { getApiKey } from '@/service/utils/auth';
import { getOpenAIApi } from '@/service/utils/chat/openai'; import { getOpenAIApi } from '@/service/utils/chat/openai';
import { embeddingModel } from '@/constants/model'; import { embeddingModel } from '@/constants/model';
import { axiosConfig } from '@/service/utils/tools'; import { axiosConfig } from '@/service/utils/tools';
import { pushGenerateVectorBill } from '@/service/events/pushBill'; import { pushGenerateVectorBill } from '@/service/events/pushBill';
import { ApiKeyType } from '@/service/utils/auth';
type Props = { type Props = {
input: string[]; input: string[];
type?: ApiKeyType;
}; };
type Response = number[][]; type Response = number[][];
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
const { userId } = await authUser({ req }); const { userId } = await authUser({ req });
let { input } = req.query as Props; let { input, type } = req.query as Props;
if (!Array.isArray(input)) { if (!Array.isArray(input)) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
jsonRes<Response>(res, { jsonRes<Response>(res, {
data: await openaiEmbedding({ userId, input, mustPay: true }) data: await openaiEmbedding({ userId, input, mustPay: true, type })
}); });
} catch (err) { } catch (err) {
console.log(err); console.log(err);
@@ -38,12 +39,14 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
export async function openaiEmbedding({ export async function openaiEmbedding({
userId, userId,
input, input,
mustPay = false mustPay = false,
type = 'chat'
}: { userId: string; mustPay?: boolean } & Props) { }: { userId: string; mustPay?: boolean } & Props) {
const { userOpenAiKey, systemAuthKey } = await getApiKey({ const { userOpenAiKey, systemAuthKey } = await getApiKey({
model: 'gpt-3.5-turbo', model: 'gpt-3.5-turbo',
userId, userId,
mustPay mustPay,
type
}); });
// 获取 chatAPI // 获取 chatAPI

View File

@@ -1,19 +0,0 @@
// Next.js API route support: https://nextjs.org/docs/api-routes/introduction
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { generateQA } from '@/service/events/generateQA';
import { generateVector } from '@/service/events/generateVector';
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
generateQA();
generateVector();
jsonRes(res);
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@@ -17,7 +17,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { input } = req.body as TextPluginRequestParams; const { input } = req.body as TextPluginRequestParams;
const response = await axios({ const response = await axios({
...axiosConfig(getSystemOpenAiKey()), ...axiosConfig(getSystemOpenAiKey('chat')),
method: 'POST', method: 'POST',
url: `/moderations`, url: `/moderations`,
data: { data: {

View File

@@ -1,12 +1,11 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, SplitData } from '@/service/mongo'; import { connectToDatabase, TrainingData } from '@/service/mongo';
import { authKb, authUser } from '@/service/utils/auth'; import { authKb, authUser } from '@/service/utils/auth';
import { generateVector } from '@/service/events/generateVector';
import { generateQA } from '@/service/events/generateQA'; import { generateQA } from '@/service/events/generateQA';
import { insertKbItem } from '@/service/pg'; import { TrainingTypeEnum } from '@/constants/plugin';
import { SplitTextTypEnum } from '@/constants/plugin';
import { withNextCors } from '@/service/utils/tools'; import { withNextCors } from '@/service/utils/tools';
import { pushDataToKb } from '../kb/pushData';
/* split text */ /* split text */
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -15,7 +14,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
kbId: string; kbId: string;
chunks: string[]; chunks: string[];
prompt: string; prompt: string;
mode: `${SplitTextTypEnum}`; mode: `${TrainingTypeEnum}`;
}; };
if (!chunks || !kbId || !prompt) { if (!chunks || !kbId || !prompt) {
throw new Error('参数错误'); throw new Error('参数错误');
@@ -30,29 +29,26 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
userId userId
}); });
if (mode === SplitTextTypEnum.qa) { if (mode === TrainingTypeEnum.qa) {
// 批量QA拆分插入数据 // 批量QA拆分插入数据
await SplitData.create({ const { _id } = await TrainingData.create({
userId, userId,
kbId, kbId,
textList: chunks, qaList: chunks,
prompt prompt
}); });
generateQA(_id);
generateQA(); } else if (mode === TrainingTypeEnum.subsection) {
} else if (mode === SplitTextTypEnum.subsection) { // 分段导入,直接插入向量队列
// 待优化,直接调用另一个接口 const response = await pushDataToKb({
// 插入记录
await insertKbItem({
userId,
kbId, kbId,
data: chunks.map((item) => ({ data: chunks.map((item) => ({ q: item, a: '' })),
q: item, userId
a: ''
}))
}); });
generateVector(); return jsonRes(res, {
data: response
});
} }
jsonRes(res); jsonRes(res);

View File

@@ -1,14 +1,15 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, SplitData, Model } from '@/service/mongo'; import { connectToDatabase, TrainingData } from '@/service/mongo';
import { authUser } from '@/service/utils/auth'; import { authUser } from '@/service/utils/auth';
import { ModelDataStatusEnum } from '@/constants/model'; import { Types } from 'mongoose';
import { PgClient } from '@/service/pg'; import { generateQA } from '@/service/events/generateQA';
import { generateVector } from '@/service/events/generateVector';
/* 拆分数据成QA */ /* 拆分数据成QA */
export default async function handler(req: NextApiRequest, res: NextApiResponse) { export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try { try {
const { kbId } = req.query as { kbId: string }; const { kbId, init = false } = req.body as { kbId: string; init: boolean };
if (!kbId) { if (!kbId) {
throw new Error('参数错误'); throw new Error('参数错误');
} }
@@ -17,29 +18,43 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { userId } = await authUser({ req, authToken: true }); const { userId } = await authUser({ req, authToken: true });
// split queue data // split queue data
const data = await SplitData.find({ const result = await TrainingData.aggregate([
userId, { $match: { userId: new Types.ObjectId(userId), kbId: new Types.ObjectId(kbId) } },
kbId, {
textList: { $exists: true, $not: { $size: 0 } } $project: {
}); qaListLength: { $size: { $ifNull: ['$qaList', []] } },
vectorListLength: { $size: { $ifNull: ['$vectorList', []] } }
// embedding queue data }
const embeddingData = await PgClient.count('modelData', { },
where: [ {
['user_id', userId], $group: {
'AND', _id: null,
['kb_id', kbId], totalQaListLength: { $sum: '$qaListLength' },
'AND', totalVectorListLength: { $sum: '$vectorListLength' }
['status', ModelDataStatusEnum.waiting] }
] }
}); ]);
jsonRes(res, { jsonRes(res, {
data: { data: {
splitDataQueue: data.map((item) => item.textList).flat().length, qaListLen: result[0]?.totalQaListLength || 0,
embeddingQueue: embeddingData vectorListLen: result[0]?.totalVectorListLength || 0
} }
}); });
if (init) {
const list = await TrainingData.find(
{
userId,
kbId
},
'_id'
);
list.forEach((item) => {
generateQA(item._id);
generateVector(item._id);
});
}
} catch (err) { } catch (err) {
jsonRes(res, { jsonRes(res, {
code: 500, code: 500,

View File

@@ -91,9 +91,9 @@ const DataCard = ({ kbId }: { kbId: string }) => {
onClose: onCloseSelectCsvModal onClose: onCloseSelectCsvModal
} = useDisclosure(); } = useDisclosure();
const { data: { splitDataQueue = 0, embeddingQueue = 0 } = {}, refetch } = useQuery( const { data: { qaListLen = 0, vectorListLen = 0 } = {}, refetch } = useQuery(
['getModelSplitDataList'], ['getModelSplitDataList'],
() => getTrainingData(kbId), () => getTrainingData({ kbId, init: false }),
{ {
onError(err) { onError(err) {
console.log(err); console.log(err);
@@ -113,7 +113,7 @@ const DataCard = ({ kbId }: { kbId: string }) => {
// interval get data // interval get data
useQuery(['refetchData'], () => refetchData(pageNum), { useQuery(['refetchData'], () => refetchData(pageNum), {
refetchInterval: 5000, refetchInterval: 5000,
enabled: splitDataQueue > 0 || embeddingQueue > 0 enabled: qaListLen > 0 || vectorListLen > 0
}); });
// get all data and export csv // get all data and export csv
@@ -161,7 +161,10 @@ const DataCard = ({ kbId }: { kbId: string }) => {
variant={'outline'} variant={'outline'}
mr={[2, 4]} mr={[2, 4]}
size={'sm'} size={'sm'}
onClick={() => refetchData(pageNum)} onClick={() => {
refetchData(pageNum);
getTrainingData({ kbId, init: true });
}}
/> />
<Button <Button
variant={'outline'} variant={'outline'}
@@ -194,10 +197,10 @@ const DataCard = ({ kbId }: { kbId: string }) => {
</Menu> </Menu>
</Flex> </Flex>
<Flex mt={4}> <Flex mt={4}>
{(splitDataQueue > 0 || embeddingQueue > 0) && ( {(qaListLen > 0 || vectorListLen > 0) && (
<Box fontSize={'xs'}> <Box fontSize={'xs'}>
{splitDataQueue > 0 ? `${splitDataQueue}条数据正在拆分,` : ''} {qaListLen > 0 ? `${qaListLen}条数据正在拆分,` : ''}
{embeddingQueue > 0 ? `${embeddingQueue}条数据正在生成索引,` : ''} {vectorListLen > 0 ? `${vectorListLen}条数据正在生成索引,` : ''}
... ...
</Box> </Box>
)} )}

View File

@@ -20,7 +20,8 @@ import { useMutation } from '@tanstack/react-query';
import { postSplitData } from '@/api/plugins/kb'; import { postSplitData } from '@/api/plugins/kb';
import Radio from '@/components/Radio'; import Radio from '@/components/Radio';
import { splitText_token } from '@/utils/file'; import { splitText_token } from '@/utils/file';
import { SplitTextTypEnum } from '@/constants/plugin'; import { TrainingTypeEnum } from '@/constants/plugin';
import { getErrText } from '@/utils/tools';
const fileExtension = '.txt,.doc,.docx,.pdf,.md'; const fileExtension = '.txt,.doc,.docx,.pdf,.md';
@@ -52,7 +53,7 @@ const SelectFileModal = ({
const { toast } = useToast(); const { toast } = useToast();
const [prompt, setPrompt] = useState(''); const [prompt, setPrompt] = useState('');
const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true }); const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
const [mode, setMode] = useState<`${SplitTextTypEnum}`>(SplitTextTypEnum.subsection); const [mode, setMode] = useState<`${TrainingTypeEnum}`>(TrainingTypeEnum.subsection);
const [fileTextArr, setFileTextArr] = useState<string[]>(['']); const [fileTextArr, setFileTextArr] = useState<string[]>(['']);
const [splitRes, setSplitRes] = useState<{ tokens: number; chunks: string[] }>({ const [splitRes, setSplitRes] = useState<{ tokens: number; chunks: string[] }>({
tokens: 0, tokens: 0,
@@ -113,8 +114,9 @@ const SelectFileModal = ({
prompt: `下面是"${prompt || '一段长文本'}"`, prompt: `下面是"${prompt || '一段长文本'}"`,
mode mode
}); });
toast({ toast({
title: '导入数据成功,需要一段拆解和训练', title: '导入数据成功,需要一段拆解和训练. 重复数据会自动删除',
status: 'success' status: 'success'
}); });
onClose(); onClose();
@@ -130,27 +132,35 @@ const SelectFileModal = ({
const onclickImport = useCallback(async () => { const onclickImport = useCallback(async () => {
setBtnLoading(true); setBtnLoading(true);
let promise = Promise.resolve(); try {
let promise = Promise.resolve();
const splitRes = fileTextArr const splitRes = await Promise.all(
.filter((item) => item) fileTextArr
.map((item) => .filter((item) => item)
splitText_token({ .map((item) =>
text: item, splitText_token({
...modeMap[mode] text: item,
}) ...modeMap[mode]
})
)
); );
setSplitRes({ setSplitRes({
tokens: splitRes.reduce((sum, item) => sum + item.tokens, 0), tokens: splitRes.reduce((sum, item) => sum + item.tokens, 0),
chunks: splitRes.map((item) => item.chunks).flat() chunks: splitRes.map((item) => item.chunks).flat()
}); });
await promise;
openConfirm(mutate)();
} catch (error) {
toast({
status: 'warning',
title: getErrText(error, '拆分文本异常')
});
}
setBtnLoading(false); setBtnLoading(false);
}, [fileTextArr, mode, mutate, openConfirm, toast]);
await promise;
openConfirm(mutate)();
}, [fileTextArr, mode, mutate, openConfirm]);
return ( return (
<Modal isOpen={true} onClose={onClose} isCentered> <Modal isOpen={true} onClose={onClose} isCentered>

View File

@@ -53,10 +53,11 @@ function responseError(err: any) {
} }
/* 创建请求实例 */ /* 创建请求实例 */
const instance = axios.create({ export const instance = axios.create({
timeout: 60000, // 超时时间 timeout: 60000, // 超时时间
baseURL: `http://localhost:${process.env.PORT || 3000}/api`,
headers: { headers: {
'content-type': 'application/json' rootkey: process.env.ROOT_KEY
} }
}); });
@@ -75,7 +76,6 @@ function request(url: string, data: any, config: ConfigType, method: Method): an
return instance return instance
.request({ .request({
baseURL: `http://localhost:${process.env.PORT || 3000}/api`,
url, url,
method, method,
data: method === 'GET' ? null : data, data: method === 'GET' ? null : data,
@@ -93,18 +93,30 @@ function request(url: string, data: any, config: ConfigType, method: Method): an
* @param {Object} config * @param {Object} config
* @returns * @returns
*/ */
export function GET<T>(url: string, params = {}, config: ConfigType = {}): Promise<T> { export function GET<T = { data: any }>(
url: string,
params = {},
config: ConfigType = {}
): Promise<T> {
return request(url, params, config, 'GET'); return request(url, params, config, 'GET');
} }
export function POST<T>(url: string, data = {}, config: ConfigType = {}): Promise<T> { export function POST<T = { data: any }>(
url: string,
data = {},
config: ConfigType = {}
): Promise<T> {
return request(url, data, config, 'POST'); return request(url, data, config, 'POST');
} }
export function PUT<T>(url: string, data = {}, config: ConfigType = {}): Promise<T> { export function PUT<T = { data: any }>(
url: string,
data = {},
config: ConfigType = {}
): Promise<T> {
return request(url, data, config, 'PUT'); return request(url, data, config, 'PUT');
} }
export function DELETE<T>(url: string, config: ConfigType = {}): Promise<T> { export function DELETE<T = { data: any }>(url: string, config: ConfigType = {}): Promise<T> {
return request(url, {}, config, 'DELETE'); return request(url, {}, config, 'DELETE');
} }

View File

@@ -1,75 +1,55 @@
import { SplitData } from '@/service/mongo'; import { TrainingData } from '@/service/mongo';
import { getApiKey } from '../utils/auth'; import { getApiKey } from '../utils/auth';
import { OpenAiChatEnum } from '@/constants/model'; import { OpenAiChatEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill'; import { pushSplitDataBill } from '@/service/events/pushBill';
import { generateVector } from './generateVector';
import { openaiError2 } from '../errorCode'; import { openaiError2 } from '../errorCode';
import { insertKbItem } from '@/service/pg';
import { SplitDataSchema } from '@/types/mongoSchema';
import { modelServiceToolMap } from '../utils/chat'; import { modelServiceToolMap } from '../utils/chat';
import { ChatRoleEnum } from '@/constants/chat'; import { ChatRoleEnum } from '@/constants/chat';
import { getErrText } from '@/utils/tools';
import { BillTypeEnum } from '@/constants/user'; import { BillTypeEnum } from '@/constants/user';
import { pushDataToKb } from '@/pages/api/openapi/kb/pushData';
import { ERROR_ENUM } from '../errorCode';
export async function generateQA(next = false): Promise<any> { // 每次最多选 1 组
if (process.env.queueTask !== '1') { const listLen = 1;
try {
fetch(process.env.parentUrl || '');
} catch (error) {
console.log('parentUrl fetch error', error);
}
return;
}
if (global.generatingQA === true && !next) return;
global.generatingQA = true;
let dataId = null;
export async function generateQA(trainingId: string): Promise<any> {
try { try {
// 找出一个需要生成的 dataItem // 找出一个需要生成的 dataItem (4分钟锁)
const data = await SplitData.aggregate([ const data = await TrainingData.findOneAndUpdate(
{ $match: { textList: { $exists: true, $ne: [] } } }, {
{ $sample: { size: 1 } } _id: trainingId,
]); lockTime: { $lte: Date.now() - 4 * 60 * 1000 }
},
{
lockTime: new Date()
}
);
const dataItem: SplitDataSchema = data[0]; if (!data || data.qaList.length === 0) {
await TrainingData.findOneAndDelete({
if (!dataItem) { _id: trainingId,
console.log('没有需要生成 QA 的数据'); qaList: [],
global.generatingQA = false; vectorList: []
return;
}
dataId = dataItem._id;
// 获取 5 个源文本
const textList: string[] = dataItem.textList.slice(-5);
// 获取 openapi Key
let userOpenAiKey = '',
systemAuthKey = '';
try {
const key = await getApiKey({ model: OpenAiChatEnum.GPT35, userId: dataItem.userId });
userOpenAiKey = key.userOpenAiKey;
systemAuthKey = key.systemAuthKey;
} catch (err: any) {
// 余额不够了, 清空该记录
await SplitData.findByIdAndUpdate(dataItem._id, {
textList: [],
errorText: getErrText(err, '获取 OpenAi Key 失败')
}); });
generateQA(true);
return; return;
} }
console.log(`正在生成一组QA, 包含 ${textList.length} 组文本。ID: ${dataItem._id}`); const qaList: string[] = data.qaList.slice(-listLen);
// 余额校验并获取 openapi Key
const { userOpenAiKey, systemAuthKey } = await getApiKey({
model: OpenAiChatEnum.GPT35,
userId: data.userId,
type: 'training'
});
console.log(`正在生成一组QA, 包含 ${qaList.length} 组文本。ID: ${data._id}`);
const startTime = Date.now(); const startTime = Date.now();
// 请求 chatgpt 获取回答 // 请求 chatgpt 获取回答
const response = await Promise.allSettled( const response = await Promise.all(
textList.map((text) => qaList.map((text) =>
modelServiceToolMap[OpenAiChatEnum.GPT35] modelServiceToolMap[OpenAiChatEnum.GPT35]
.chatCompletion({ .chatCompletion({
apiKey: userOpenAiKey || systemAuthKey, apiKey: userOpenAiKey || systemAuthKey,
@@ -78,7 +58,7 @@ export async function generateQA(next = false): Promise<any> {
{ {
obj: ChatRoleEnum.System, obj: ChatRoleEnum.System,
value: `你是出题人 value: `你是出题人
${dataItem.prompt || '下面是"一段长文本"'} ${data.prompt || '下面是"一段长文本"'}
从中选出5至20个题目和答案.答案详细.按格式返回: Q1: 从中选出5至20个题目和答案.答案详细.按格式返回: Q1:
A1: A1:
Q2: Q2:
@@ -98,7 +78,7 @@ A2:
// 计费 // 计费
pushSplitDataBill({ pushSplitDataBill({
isPay: !userOpenAiKey && result.length > 0, isPay: !userOpenAiKey && result.length > 0,
userId: dataItem.userId, userId: data.userId,
type: BillTypeEnum.QA, type: BillTypeEnum.QA,
textLen: responseMessages.map((item) => item.value).join('').length, textLen: responseMessages.map((item) => item.value).join('').length,
totalTokens totalTokens
@@ -116,57 +96,59 @@ A2:
) )
); );
// 获取成功的回答 const responseList = response.map((item) => item.result).flat();
const successResponse: {
rawContent: string;
result: {
q: string;
a: string;
}[];
}[] = response.filter((item) => item.status === 'fulfilled').map((item: any) => item.value);
const resultList = successResponse.map((item) => item.result).flat(); // 创建 向量生成 队列
pushDataToKb({
kbId: data.kbId,
data: responseList,
userId: data.userId
});
await Promise.allSettled([ // 删除 QA 队列。如果小于 n 条,整个数据删掉。 如果大于 n 条,仅删数组后 n 个
// 删掉后5个数据 if (data.vectorList.length <= listLen) {
SplitData.findByIdAndUpdate(dataItem._id, { await TrainingData.findByIdAndDelete(data._id);
textList: dataItem.textList.slice(0, -5)
}),
// 生成的内容插入 pg
insertKbItem({
userId: dataItem.userId,
kbId: dataItem.kbId,
data: resultList
})
]);
console.log('生成QA成功time:', `${(Date.now() - startTime) / 1000}s`);
generateQA(true);
generateVector();
} catch (error: any) {
// log
if (error?.response) {
console.log('openai error: 生成QA错误');
console.log(error.response?.status, error.response?.statusText, error.response?.data);
} else { } else {
console.log('生成QA错误:', error); await TrainingData.findByIdAndUpdate(data._id, {
qaList: data.qaList.slice(0, -listLen),
lockTime: new Date('2000/1/1')
});
} }
// 没有余额或者凭证错误时,拒绝任务 console.log('生成QA成功time:', `${(Date.now() - startTime) / 1000}s`);
if (dataId && openaiError2[error?.response?.data?.error?.type]) {
console.log(openaiError2[error?.response?.data?.error?.type], '删除QA任务');
await SplitData.findByIdAndUpdate(dataId, { generateQA(trainingId);
textList: [], } catch (err: any) {
errorText: 'api 余额不足' // log
}); if (err?.response) {
console.log('openai error: 生成QA错误');
console.log(err.response?.status, err.response?.statusText, err.response?.data);
} else {
console.log('生成QA错误:', err);
}
generateQA(true); // openai 账号异常或者账号余额不足,删除任务
if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) {
console.log('余额不足,删除向量生成任务');
await TrainingData.findByIdAndDelete(trainingId);
return; return;
} }
// unlock
await TrainingData.findByIdAndUpdate(trainingId, {
lockTime: new Date('2000/1/1')
});
// 频率限制
if (err?.response?.statusText === 'Too Many Requests') {
console.log('生成向量次数限制30s后尝试');
return setTimeout(() => {
generateQA(trainingId);
}, 30000);
}
setTimeout(() => { setTimeout(() => {
generateQA(true); generateQA(trainingId);
}, 1000); }, 1000);
} }
} }

View File

@@ -1,107 +1,137 @@
import { getApiKey } from '../utils/auth';
import { openaiError2 } from '../errorCode'; import { openaiError2 } from '../errorCode';
import { PgClient } from '@/service/pg'; import { insertKbItem, PgClient } from '@/service/pg';
import { getErrText } from '@/utils/tools';
import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding'; import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding';
import { TrainingData } from '../models/trainingData';
import { ERROR_ENUM } from '../errorCode';
export async function generateVector(next = false): Promise<any> { // 每次最多选 5 组
if (process.env.queueTask !== '1') { const listLen = 5;
try {
fetch(process.env.parentUrl || '');
} catch (error) {
console.log('parentUrl fetch error', error);
}
return;
}
if (global.generatingVector && !next) return;
global.generatingVector = true;
let dataId = null;
/* 索引生成队列。每导入一次,就是一个单独的线程 */
export async function generateVector(trainingId: string): Promise<any> {
try { try {
// 找出一个 status = waiting 的数据 // 找出一个需要生成的 dataItem (2分钟锁)
const searchRes = await PgClient.select('modelData', { const data = await TrainingData.findOneAndUpdate(
fields: ['id', 'q', 'user_id'], {
where: [['status', 'waiting']], _id: trainingId,
limit: 1 lockTime: { $lte: Date.now() - 2 * 60 * 1000 }
}); },
{
if (searchRes.rowCount === 0) { lockTime: new Date()
console.log('没有需要生成 【向量】 的数据');
global.generatingVector = false;
return;
}
const dataItem: { id: string; q: string; userId: string } = {
id: searchRes.rows[0].id,
q: searchRes.rows[0].q,
userId: searchRes.rows[0].user_id
};
dataId = dataItem.id;
// 获取 openapi Key
try {
await getApiKey({ model: 'gpt-3.5-turbo', userId: dataItem.userId });
} catch (err: any) {
await PgClient.delete('modelData', {
where: [['id', dataId]]
});
getErrText(err, '获取 OpenAi Key 失败');
return generateVector(true);
}
// 生成词向量
const vectors = await openaiEmbedding({
input: [dataItem.q],
userId: dataItem.userId
});
// 更新 pg 向量和状态数据
await PgClient.update('modelData', {
values: [
{ key: 'vector', value: `[${vectors[0]}]` },
{ key: 'status', value: `ready` }
],
where: [['id', dataId]]
});
console.log(`生成向量成功: ${dataItem.id}`);
generateVector(true);
} catch (error: any) {
// log
if (error?.response) {
console.log('openai error: 生成向量错误');
console.log(error.response?.status, error.response?.statusText, error.response?.data);
} else {
console.log('生成向量错误:', error);
}
// 没有余额或者凭证错误时,拒绝任务
if (dataId && openaiError2[error?.response?.data?.error?.type]) {
console.log('删除向量生成任务记录');
try {
await PgClient.delete('modelData', {
where: [['id', dataId]]
});
} catch (error) {
error;
} }
generateVector(true); );
if (!data) {
await TrainingData.findOneAndDelete({
_id: trainingId,
qaList: [],
vectorList: []
});
return; return;
} }
if (error?.response?.statusText === 'Too Many Requests') {
console.log('生成向量次数限制1分钟后尝试'); const userId = String(data.userId);
// 限制次数1分钟后再试 const kbId = String(data.kbId);
setTimeout(() => {
generateVector(true); const dataItems: { q: string; a: string }[] = data.vectorList.slice(-listLen).map((item) => ({
}, 60000); q: item.q,
a: item.a
}));
// 过滤重复的 qa 内容
const searchRes = await Promise.allSettled(
dataItems.map(async ({ q, a = '' }) => {
if (!q) {
return Promise.reject('q为空');
}
q = q.replace(/\\n/g, '\n');
a = a.replace(/\\n/g, '\n');
// Exactly the same data, not push
try {
const count = await PgClient.count('modelData', {
where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]]
});
if (count > 0) {
return Promise.reject('已经存在');
}
} catch (error) {
error;
}
return Promise.resolve({
q,
a
});
})
);
const filterData = searchRes
.filter((item) => item.status === 'fulfilled')
.map<{ q: string; a: string }>((item: any) => item.value);
if (filterData.length > 0) {
// 生成词向量
const vectors = await openaiEmbedding({
input: filterData.map((item) => item.q),
userId,
type: 'training'
});
// 生成结果插入到 pg
await insertKbItem({
userId,
kbId,
data: vectors.map((vector, i) => ({
q: filterData[i].q,
a: filterData[i].a,
vector
}))
});
}
// 删除 mongo 训练队列. 如果小于 n 条,整个数据删掉。 如果大于 n 条,仅删数组后 n 个
if (data.vectorList.length <= listLen) {
await TrainingData.findByIdAndDelete(trainingId);
console.log(`全部向量生成完毕: ${trainingId}`);
} else {
await TrainingData.findByIdAndUpdate(trainingId, {
vectorList: data.vectorList.slice(0, -listLen),
lockTime: new Date('2000/1/1')
});
console.log(`生成向量成功: ${trainingId}`);
generateVector(trainingId);
}
} catch (err: any) {
// log
if (err?.response) {
console.log('openai error: 生成向量错误');
console.log(err.response?.status, err.response?.statusText, err.response?.data);
} else {
console.log('生成向量错误:', err);
}
// openai 账号异常或者账号余额不足,删除任务
if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) {
console.log('余额不足,删除向量生成任务');
await TrainingData.findByIdAndDelete(trainingId);
return; return;
} }
// unlock
await TrainingData.findByIdAndUpdate(trainingId, {
lockTime: new Date('2000/1/1')
});
// 频率限制
if (err?.response?.statusText === 'Too Many Requests') {
console.log('生成向量次数限制30s后尝试');
return setTimeout(() => {
generateVector(trainingId);
}, 30000);
}
setTimeout(() => { setTimeout(() => {
generateVector(true); generateVector(trainingId);
}, 1000); }, 1000);
} }
} }

View File

@@ -2,7 +2,7 @@ import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { ModelSchema as ModelType } from '@/types/mongoSchema'; import { ModelSchema as ModelType } from '@/types/mongoSchema';
import { import {
ModelVectorSearchModeMap, ModelVectorSearchModeMap,
ModelVectorSearchModeEnum, appVectorSearchModeEnum,
ChatModelMap, ChatModelMap,
OpenAiChatEnum OpenAiChatEnum
} from '@/constants/model'; } from '@/constants/model';
@@ -40,7 +40,7 @@ const ModelSchema = new Schema({
// knowledge base search mode // knowledge base search mode
type: String, type: String,
enum: Object.keys(ModelVectorSearchModeMap), enum: Object.keys(ModelVectorSearchModeMap),
default: ModelVectorSearchModeEnum.hightSimilarity default: appVectorSearchModeEnum.hightSimilarity
}, },
systemPrompt: { systemPrompt: {
// 系统提示词 // 系统提示词

View File

@@ -1,32 +0,0 @@
/* 模型的知识库 */
import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { SplitDataSchema as SplitDataType } from '@/types/mongoSchema';
// Mongoose schema for a pending text-splitting (QA generation) task.
const SplitDataSchema = new Schema({
// Owner of the task; references the 'user' collection.
userId: {
type: Schema.Types.ObjectId,
ref: 'user',
required: true
},
prompt: {
// Prompt used when splitting the text
type: String,
required: true
},
// Target knowledge base; references the 'kb' collection.
kbId: {
type: Schema.Types.ObjectId,
ref: 'kb',
required: true
},
// Source texts still waiting to be processed; empty by default.
textList: {
type: [String],
default: []
},
// Free-form error description; empty string by default.
errorText: {
type: String,
default: ''
}
});
// Reuse the model already registered under 'splitData' when present; otherwise register it from the schema.
export const SplitData: MongoModel<SplitDataType> =
models['splitData'] || model('splitData', SplitDataSchema);

View File

@@ -0,0 +1,38 @@
/* Training queue: texts waiting for QA generation and q/a pairs waiting for embedding */
import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { TrainingDataSchema as TrainingDateType } from '@/types/mongoSchema';
// Mongoose schema for one queued training task.
// qaList and vectorList: only one of them will work for a given document.
// NOTE(review): the original comment said "pgList", which is not a field of this schema;
// qaList is presumably what was meant — confirm.
const TrainingDataSchema = new Schema({
// Owner of the task; references the 'user' collection.
userId: {
type: Schema.Types.ObjectId,
ref: 'user',
required: true
},
// Target knowledge base; references the 'kb' collection.
kbId: {
type: Schema.Types.ObjectId,
ref: 'kb',
required: true
},
// Lock timestamp. Defaults to 2000/1/1, i.e. far enough in the past to count as "unlocked"
// for any query that selects documents whose lockTime is older than a few minutes.
lockTime: {
type: Date,
default: () => new Date('2000/1/1')
},
// Question/answer pairs queued on this task; empty by default.
vectorList: {
type: [{ q: String, a: String }],
default: []
},
prompt: {
// Prompt used when splitting the text
type: String,
default: ''
},
// Plain-text entries queued on this task; empty by default.
qaList: {
type: [String],
default: []
}
});
// Reuse the model already registered under 'trainingData' when present; otherwise register it from the schema.
export const TrainingData: MongoModel<TrainingDateType> =
models['trainingData'] || model('trainingData', TrainingDataSchema);

View File

@@ -2,6 +2,7 @@ import mongoose from 'mongoose';
import { generateQA } from './events/generateQA'; import { generateQA } from './events/generateQA';
import { generateVector } from './events/generateVector'; import { generateVector } from './events/generateVector';
import tunnel from 'tunnel'; import tunnel from 'tunnel';
import { TrainingData } from './mongo';
/** /**
* 连接 MongoDB 数据库 * 连接 MongoDB 数据库
@@ -27,9 +28,6 @@ export async function connectToDatabase(): Promise<void> {
global.mongodb = null; global.mongodb = null;
} }
generateQA();
generateVector();
// 创建代理对象 // 创建代理对象
if (process.env.AXIOS_PROXY_HOST && process.env.AXIOS_PROXY_PORT) { if (process.env.AXIOS_PROXY_HOST && process.env.AXIOS_PROXY_PORT) {
global.httpsAgent = tunnel.httpsOverHttp({ global.httpsAgent = tunnel.httpsOverHttp({
@@ -39,6 +37,34 @@ export async function connectToDatabase(): Promise<void> {
} }
}); });
} }
startTrain();
// 5 分钟后解锁不正常的数据,并触发开始训练
setTimeout(async () => {
await TrainingData.updateMany(
{
lockTime: { $lte: Date.now() - 5 * 60 * 1000 }
},
{
lockTime: new Date('2000/1/1')
}
);
startTrain();
}, 5 * 60 * 1000);
}
/**
 * Start one worker per pending training document.
 *
 * Looks up every TrainingData document with a non-empty `qaList` and calls
 * `generateQA` for it, then does the same for non-empty `vectorList` with
 * `generateVector`. Workers are started without being awaited, so this
 * function resolves as soon as they have all been launched.
 *
 * Callers invoke this without awaiting or catching, so every failure is
 * logged here instead of surfacing as an unhandled promise rejection.
 */
async function startTrain() {
  try {
    const qa = await TrainingData.find({
      qaList: { $exists: true, $ne: [] }
    });
    for (const item of qa) {
      // Fire-and-forget: each worker keeps re-scheduling itself until its queue is empty.
      generateQA(String(item._id)).catch((err: unknown) => {
        console.log('generateQA worker error:', err);
      });
    }

    const vector = await TrainingData.find({
      vectorList: { $exists: true, $ne: [] }
    });
    for (const item of vector) {
      generateVector(String(item._id)).catch((err: unknown) => {
        console.log('generateVector worker error:', err);
      });
    }
  } catch (err: unknown) {
    console.log('startTrain error:', err);
  }
}
export * from './models/authCode'; export * from './models/authCode';
@@ -47,7 +73,7 @@ export * from './models/model';
export * from './models/user'; export * from './models/user';
export * from './models/bill'; export * from './models/bill';
export * from './models/pay'; export * from './models/pay';
export * from './models/splitData'; export * from './models/trainingData';
export * from './models/openapi'; export * from './models/openapi';
export * from './models/promotionRecord'; export * from './models/promotionRecord';
export * from './models/collection'; export * from './models/collection';

View File

@@ -1,5 +1,6 @@
import { Pool } from 'pg'; import { Pool } from 'pg';
import type { QueryResultRow } from 'pg'; import type { QueryResultRow } from 'pg';
import { ModelDataStatusEnum } from '@/constants/model';
export const connectPg = async () => { export const connectPg = async () => {
if (global.pgClient) { if (global.pgClient) {
@@ -168,6 +169,7 @@ export const insertKbItem = ({
userId: string; userId: string;
kbId: string; kbId: string;
data: { data: {
vector: number[];
q: string; q: string;
a: string; a: string;
}[]; }[];
@@ -178,7 +180,8 @@ export const insertKbItem = ({
{ key: 'kb_id', value: kbId }, { key: 'kb_id', value: kbId },
{ key: 'q', value: item.q }, { key: 'q', value: item.q },
{ key: 'a', value: item.a }, { key: 'a', value: item.a },
{ key: 'status', value: 'waiting' } { key: 'vector', value: `[${item.vector}]` },
{ key: 'status', value: ModelDataStatusEnum.ready }
]) ])
}); });
}; };

View File

@@ -5,12 +5,14 @@ import { Chat, Model, OpenApi, User, ShareChat, KB } from '../mongo';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import type { ChatItemSimpleType } from '@/types/chat'; import type { ChatItemSimpleType } from '@/types/chat';
import mongoose from 'mongoose'; import mongoose from 'mongoose';
import { ClaudeEnum, defaultModel } from '@/constants/model'; import { ClaudeEnum, defaultModel, embeddingModel, EmbeddingModelType } from '@/constants/model';
import { formatPrice } from '@/utils/user'; import { formatPrice } from '@/utils/user';
import { ERROR_ENUM } from '../errorCode'; import { ERROR_ENUM } from '../errorCode';
import { ChatModelType, OpenAiChatEnum } from '@/constants/model'; import { ChatModelType, OpenAiChatEnum } from '@/constants/model';
import { hashPassword } from '@/service/utils/tools'; import { hashPassword } from '@/service/utils/tools';
export type ApiKeyType = 'training' | 'chat';
export const parseCookie = (cookie?: string): Promise<string> => { export const parseCookie = (cookie?: string): Promise<string> => {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
// 获取 cookie // 获取 cookie
@@ -118,9 +120,15 @@ export const authUser = async ({
}; };
/**
 * Pick one system-level OpenAI API key at random.
 *
 * Keys are read from comma-separated environment variables:
 * `OPENAI_TRAINING_KEY` when `type` is 'training', `OPENAIKEY` otherwise.
 * Entries are trimmed and empty entries (e.g. from a trailing comma) are dropped.
 * When no training key is configured, the chat key list is used instead, so the
 * result is always a single key rather than the raw comma-separated value.
 *
 * @param type - which key pool to draw from; defaults to 'chat'
 * @returns one key; if no key is configured at all, the raw `OPENAIKEY` value
 *          (possibly undefined at runtime) is returned as before
 */
export const getSystemOpenAiKey = (type: ApiKeyType = 'chat'): string => {
  const parseKeys = (raw?: string): string[] =>
    (raw ?? '')
      .split(',')
      .map((key) => key.trim())
      .filter((key) => key.length > 0);

  const chatKeys = parseKeys(process.env.OPENAIKEY);
  const trainingKeys = type === 'training' ? parseKeys(process.env.OPENAI_TRAINING_KEY) : [];
  // Fall back to the chat pool when no dedicated training key exists.
  const keys = trainingKeys.length > 0 ? trainingKeys : chatKeys;

  const i = Math.floor(Math.random() * keys.length);
  return keys[i] || (process.env.OPENAIKEY as string);
};
@@ -129,11 +137,13 @@ export const getSystemOpenAiKey = () => {
export const getApiKey = async ({ export const getApiKey = async ({
model, model,
userId, userId,
mustPay = false mustPay = false,
type = 'chat'
}: { }: {
model: ChatModelType; model: ChatModelType;
userId: string; userId: string;
mustPay?: boolean; mustPay?: boolean;
type?: ApiKeyType;
}) => { }) => {
const user = await User.findById(userId); const user = await User.findById(userId);
if (!user) { if (!user) {
@@ -143,7 +153,7 @@ export const getApiKey = async ({
const keyMap = { const keyMap = {
[OpenAiChatEnum.GPT35]: { [OpenAiChatEnum.GPT35]: {
userOpenAiKey: user.openaiKey || '', userOpenAiKey: user.openaiKey || '',
systemAuthKey: getSystemOpenAiKey() as string systemAuthKey: getSystemOpenAiKey(type) as string
}, },
[OpenAiChatEnum.GPT4]: { [OpenAiChatEnum.GPT4]: {
userOpenAiKey: user.openaiKey || '', userOpenAiKey: user.openaiKey || '',

View File

@@ -5,8 +5,6 @@ import type { Pool } from 'pg';
declare global { declare global {
var mongodb: Mongoose | string | null; var mongodb: Mongoose | string | null;
var pgClient: Pool | null; var pgClient: Pool | null;
var generatingQA: boolean;
var generatingVector: boolean;
var httpsAgent: Agent; var httpsAgent: Agent;
var particlesJS: any; var particlesJS: any;
var grecaptcha: any; var grecaptcha: any;

View File

@@ -1,6 +1,6 @@
import { ModelStatusEnum } from '@/constants/model'; import { ModelStatusEnum } from '@/constants/model';
import type { ModelSchema, kbSchema } from './mongoSchema'; import type { ModelSchema, kbSchema } from './mongoSchema';
import { ChatModelType, ModelVectorSearchModeEnum } from '@/constants/model'; import { ChatModelType, appVectorSearchModeEnum } from '@/constants/model';
export type ModelListItemType = { export type ModelListItemType = {
_id: string; _id: string;

View File

@@ -2,12 +2,13 @@ import type { ChatItemType } from './chat';
import { import {
ModelStatusEnum, ModelStatusEnum,
ModelNameEnum, ModelNameEnum,
ModelVectorSearchModeEnum, appVectorSearchModeEnum,
ChatModelType, ChatModelType,
EmbeddingModelType EmbeddingModelType
} from '@/constants/model'; } from '@/constants/model';
import type { DataType } from './data'; import type { DataType } from './data';
import { BillTypeEnum } from '@/constants/user'; import { BillTypeEnum } from '@/constants/user';
import { TrainingTypeEnum } from '@/constants/plugin';
export interface UserModelSchema { export interface UserModelSchema {
_id: string; _id: string;
@@ -44,7 +45,7 @@ export interface ModelSchema {
updateTime: number; updateTime: number;
chat: { chat: {
relatedKbs: string[]; relatedKbs: string[];
searchMode: `${ModelVectorSearchModeEnum}`; searchMode: `${appVectorSearchModeEnum}`;
systemPrompt: string; systemPrompt: string;
temperature: number; temperature: number;
chatModel: ChatModelType; // 聊天时用的模型,训练后就是训练的模型 chatModel: ChatModelType; // 聊天时用的模型,训练后就是训练的模型
@@ -68,13 +69,14 @@ export interface CollectionSchema {
export type ModelDataType = 0 | 1; export type ModelDataType = 0 | 1;
export interface SplitDataSchema { export interface TrainingDataSchema {
_id: string; _id: string;
userId: string; userId: string;
kbId: string; kbId: string;
lockTime: Date;
vectorList: { q: string; a: string }[];
prompt: string; prompt: string;
errorText: string; qaList: string[];
textList: string[];
} }
export interface ChatSchema { export interface ChatSchema {

15
src/types/plugin.d.ts vendored
View File

@@ -1,5 +1,4 @@
import type { kbSchema } from './mongoSchema'; import type { kbSchema } from './mongoSchema';
import { PluginTypeEnum } from '@/constants/plugin';
/* kb type */ /* kb type */
export interface KbItemType extends kbSchema { export interface KbItemType extends kbSchema {
@@ -16,20 +15,6 @@ export interface KbDataItemType {
userId: string; userId: string;
} }
/* plugin */
export interface PluginConfig {
name: string;
desc: string;
url: string;
category: `${PluginTypeEnum}`;
uniPrice: 22; // 1k token
params: [
{
type: '';
}
];
}
export type TextPluginRequestParams = { export type TextPluginRequestParams = {
input: string; input: string;
}; };

View File

@@ -145,7 +145,7 @@ export const fileDownload = ({
* slideLen - The size of the before and after Text * slideLen - The size of the before and after Text
* maxLen > slideLen * maxLen > slideLen
*/ */
export const splitText_token = ({ export const splitText_token = async ({
text, text,
maxLen, maxLen,
slideLen slideLen
@@ -154,32 +154,39 @@ export const splitText_token = ({
maxLen: number; maxLen: number;
slideLen: number; slideLen: number;
}) => { }) => {
const enc = getOpenAiEncMap()['gpt-3.5-turbo']; try {
// filter empty text. encode sentence const enc = getOpenAiEncMap()['gpt-3.5-turbo'];
const encodeText = enc.encode(text); // filter empty text. encode sentence
const encodeText = enc.encode(text);
const chunks: string[] = []; const chunks: string[] = [];
let tokens = 0; let tokens = 0;
let startIndex = 0; let startIndex = 0;
let endIndex = Math.min(startIndex + maxLen, encodeText.length); let endIndex = Math.min(startIndex + maxLen, encodeText.length);
let chunkEncodeArr = encodeText.slice(startIndex, endIndex); let chunkEncodeArr = encodeText.slice(startIndex, endIndex);
const decoder = new TextDecoder(); const decoder = new TextDecoder();
while (startIndex < encodeText.length) { while (startIndex < encodeText.length) {
tokens += chunkEncodeArr.length; tokens += chunkEncodeArr.length;
chunks.push(decoder.decode(enc.decode(chunkEncodeArr))); chunks.push(decoder.decode(enc.decode(chunkEncodeArr)));
startIndex += maxLen - slideLen; startIndex += maxLen - slideLen;
endIndex = Math.min(startIndex + maxLen, encodeText.length); endIndex = Math.min(startIndex + maxLen, encodeText.length);
chunkEncodeArr = encodeText.slice(Math.min(encodeText.length - slideLen, startIndex), endIndex); chunkEncodeArr = encodeText.slice(
Math.min(encodeText.length - slideLen, startIndex),
endIndex
);
}
return {
chunks,
tokens
};
} catch (error) {
return Promise.reject(error);
} }
return {
chunks,
tokens
};
}; };
export const fileToBase64 = (file: File) => { export const fileToBase64 = (file: File) => {

View File

@@ -26,7 +26,7 @@ export const authGoogleToken = async (data: {
const res = await axios.post<{ score?: number }>( const res = await axios.post<{ score?: number }>(
`https://www.recaptcha.net/recaptcha/api/siteverify?${Obj2Query(data)}` `https://www.recaptcha.net/recaptcha/api/siteverify?${Obj2Query(data)}`
); );
if (res.data.score && res.data.score >= 0.9) { if (res.data.score && res.data.score >= 0.5) {
return Promise.resolve(''); return Promise.resolve('');
} }
return Promise.reject('非法环境'); return Promise.reject('非法环境');