feat: config vector model and qa model

This commit is contained in:
archer
2023-08-25 15:00:51 +08:00
parent a9970dd694
commit 6d93059e25
35 changed files with 337 additions and 196 deletions

View File

@@ -1,6 +1,6 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, TrainingData } from '@/service/mongo';
import { connectToDatabase, TrainingData, KB } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { authKb } from '@/service/utils/auth';
import { withNextCors } from '@/service/utils/tools';
@@ -14,7 +14,6 @@ export type DateItemType = { a: string; q: string; source?: string };
export type Props = {
kbId: string;
data: DateItemType[];
model: string;
mode: `${TrainingModeEnum}`;
prompt?: string;
};
@@ -30,23 +29,12 @@ const modeMaxToken = {
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { kbId, data, mode, prompt, model } = req.body as Props;
const { kbId, data, mode, prompt } = req.body as Props;
if (!kbId || !Array.isArray(data) || !model) {
if (!kbId || !Array.isArray(data)) {
throw new Error('缺少参数');
}
// auth model
if (mode === TrainingModeEnum.qa && !global.qaModels.find((item) => item.model === model)) {
throw new Error('不支持的 QA 拆分模型');
}
if (
mode === TrainingModeEnum.index &&
!global.vectorModels.find((item) => item.model === model)
) {
throw new Error('不支持的向量生成模型');
}
await connectToDatabase();
// 凭证校验
@@ -58,8 +46,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
data,
userId,
mode,
prompt,
model
prompt
})
});
} catch (err) {
@@ -75,8 +62,7 @@ export async function pushDataToKb({
kbId,
data,
mode,
prompt,
model
prompt
}: { userId: string } & Props): Promise<Response> {
await authKb({
userId,
@@ -152,17 +138,24 @@ export async function pushDataToKb({
.filter((item) => item.status === 'fulfilled')
.map<DateItemType>((item: any) => item.value);
const vectorModel = await (async () => {
if (mode === TrainingModeEnum.index) {
return (await KB.findById(kbId, 'vectorModel'))?.vectorModel || global.vectorModels[0].model;
}
return global.vectorModels[0].model;
})();
// 插入记录
await TrainingData.insertMany(
insertData.map((item) => ({
q: item.q,
a: item.a,
model,
source: item.source,
userId,
kbId,
mode,
prompt
prompt,
vectorModel
}))
);

View File

@@ -2,15 +2,13 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, KB } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import type { CreateKbParams } from '@/api/request/kb';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { name, tags } = req.body as {
name: string;
tags: string[];
};
const { name, tags, avatar, vectorModel } = req.body as CreateKbParams;
if (!name) {
if (!name || !vectorModel) {
throw new Error('缺少参数');
}
@@ -22,7 +20,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
const { _id } = await KB.create({
name,
userId,
tags
tags,
vectorModel,
avatar
});
jsonRes(res, { data: _id });

View File

@@ -2,6 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, KB } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { getModel } from '@/service/utils/data';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -33,7 +34,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
avatar: data.avatar,
name: data.name,
userId: data.userId,
model: data.model,
vectorModelName: getModel(data.vectorModel)?.name || 'Unknown',
tags: data.tags.join(' ')
}
});

View File

@@ -3,6 +3,7 @@ import { jsonRes } from '@/service/response';
import { connectToDatabase, KB } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { KbListItemType } from '@/types/plugin';
import { getModel } from '@/service/utils/data';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -15,7 +16,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
{
userId
},
'_id avatar name tags'
'_id avatar name tags vectorModel'
).sort({ updateTime: -1 });
const data = await Promise.all(
@@ -23,7 +24,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
_id: item._id,
avatar: item.avatar,
name: item.name,
tags: item.tags
tags: item.tags,
vectorModelName: getModel(item.vectorModel)?.name || 'UnKnow'
}))
);

View File

@@ -10,7 +10,7 @@ import {
export type InitDateResponse = {
chatModels: ChatModelItemType[];
qaModels: QAModelItemType[];
qaModel: QAModelItemType;
vectorModels: VectorModelItemType[];
feConfigs: FeConfigsType;
};
@@ -23,7 +23,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
data: {
feConfigs: global.feConfigs,
chatModels: global.chatModels,
qaModels: global.qaModels,
qaModel: global.qaModel,
vectorModels: global.vectorModels
}
});
@@ -69,14 +69,13 @@ const defaultChatModels = [
price: 0
}
];
const defaultQAModels = [
{
model: 'gpt-3.5-turbo-16k',
name: 'GPT35-16k',
maxToken: 16000,
price: 0
}
];
const defaultQAModel = {
model: 'gpt-3.5-turbo-16k',
name: 'GPT35-16k',
maxToken: 16000,
price: 0
};
const defaultVectorModels = [
{
model: 'text-embedding-ada-002',
@@ -95,7 +94,7 @@ export async function getInitConfig() {
global.systemEnv = res.SystemParams || defaultSystemEnv;
global.feConfigs = res.FeConfig || defaultFeConfigs;
global.chatModels = res.ChatModels || defaultChatModels;
global.qaModels = res.QAModels || defaultQAModels;
global.qaModel = res.QAModel || defaultQAModel;
global.vectorModels = res.VectorModels || defaultVectorModels;
} catch (error) {
setDefaultData();
@@ -107,6 +106,6 @@ export function setDefaultData() {
global.systemEnv = defaultSystemEnv;
global.feConfigs = defaultFeConfigs;
global.chatModels = defaultChatModels;
global.qaModels = defaultQAModels;
global.qaModel = defaultQAModel;
global.vectorModels = defaultVectorModels;
}

View File

@@ -100,9 +100,8 @@ export async function registerUser({
username,
avatar,
password: nanoid(),
inviterId
inviterId: inviterId ? inviterId : undefined
});
console.log(response, '-=-=-=');
// 根据 id 获取用户信息
const user = await User.findById(response._id);