new framwork

This commit is contained in:
archer
2023-06-09 12:57:42 +08:00
parent d9450bd7ee
commit ba9d9c3d5f
263 changed files with 12269 additions and 11599 deletions

View File

@@ -0,0 +1,209 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { authUser } from '@/service/utils/auth';
import { PgClient } from '@/service/pg';
import { withNextCors } from '@/service/utils/tools';
import type { ChatItemSimpleType } from '@/types/chat';
import type { ModelSchema } from '@/types/mongoSchema';
import { appVectorSearchModeEnum } from '@/constants/model';
import { authModel } from '@/service/utils/auth';
import { ChatModelMap } from '@/constants/model';
import { ChatRoleEnum } from '@/constants/chat';
import { openaiEmbedding } from '../plugin/openaiEmbedding';
import { modelToolMap } from '@/utils/plugin';
export type QuoteItemType = {
id: string;
q: string;
a: string;
source?: string;
};
type Props = {
prompts: ChatItemSimpleType[];
similarity: number;
appId: string;
};
type Response = {
code: 200 | 201;
rawSearch: QuoteItemType[];
guidePrompt: string;
searchPrompts: {
obj: ChatRoleEnum;
value: string;
}[];
};
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { userId } = await authUser({ req });
if (!userId) {
throw new Error('userId is empty');
}
const { prompts, similarity, appId } = req.body as Props;
if (!similarity || !Array.isArray(prompts) || !appId) {
throw new Error('params is error');
}
// auth model
const { model } = await authModel({
modelId: appId,
userId
});
const result = await appKbSearch({
model,
userId,
fixedQuote: [],
prompt: prompts[prompts.length - 1],
similarity
});
jsonRes<Response>(res, {
data: result
});
} catch (err) {
console.log(err);
jsonRes(res, {
code: 500,
error: err
});
}
});
export async function appKbSearch({
model,
userId,
fixedQuote,
prompt,
similarity
}: {
model: ModelSchema;
userId: string;
fixedQuote: QuoteItemType[];
prompt: ChatItemSimpleType;
similarity: number;
}): Promise<Response> {
const modelConstantsData = ChatModelMap[model.chat.chatModel];
// get vector
const promptVector = await openaiEmbedding({
userId,
input: [prompt.value],
type: 'chat'
});
// search kb
const res: any = await PgClient.query(
`BEGIN;
select id,q,a,source from modelData where kb_id IN (${model.chat.relatedKbs
.map((item) => `'${item}'`)
.join(',')}) AND vector <#> '[${promptVector[0]}]' < -${similarity} order by vector <#> '[${
promptVector[0]
}]' limit 8;
COMMIT;`
);
const searchRes: QuoteItemType[] = res?.[1]?.rows || [];
// filter same search result
const idSet = new Set<string>();
const filterSearch = [
...searchRes.slice(0, 3),
...fixedQuote.slice(0, 2),
...searchRes.slice(3),
...fixedQuote.slice(2, 5)
].filter((item) => {
if (idSet.has(item.id)) {
return false;
}
idSet.add(item.id);
return true;
});
// 计算固定提示词的 token 数量
const guidePrompt = model.chat.systemPrompt // user system prompt
? {
obj: ChatRoleEnum.System,
value: model.chat.systemPrompt
}
: model.chat.searchMode === appVectorSearchModeEnum.noContext
? {
obj: ChatRoleEnum.System,
value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.`
}
: {
obj: ChatRoleEnum.System,
value: `玩一个问答游戏,规则为:
1.你完全忘记你已有的知识
2.你只回答关于"${model.name}"的问题
3.你只从知识库中选择内容进行回答
4.如果问题不在知识库中,你会回答:"我不知道。"
请务必遵守规则`
};
const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({
messages: [guidePrompt]
});
const sliceResult = modelToolMap[model.chat.chatModel]
.tokenSlice({
maxToken: modelConstantsData.systemMaxToken - fixedSystemTokens,
messages: filterSearch.map((item) => ({
obj: ChatRoleEnum.System,
value: `${item.q}\n${item.a}`
}))
})
.map((item) => item.value);
// slice filterSearch
const rawSearch = filterSearch.slice(0, sliceResult.length);
// system prompt
const systemPrompt = sliceResult.join('\n').trim();
/* 高相似度+不回复 */
if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.hightSimilarity) {
return {
code: 201,
rawSearch: [],
guidePrompt: '',
searchPrompts: [
{
obj: ChatRoleEnum.System,
value: '对不起,你的问题不在知识库中。'
}
]
};
}
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.noContext) {
return {
code: 200,
rawSearch: [],
guidePrompt: model.chat.systemPrompt || '',
searchPrompts: model.chat.systemPrompt
? [
{
obj: ChatRoleEnum.System,
value: model.chat.systemPrompt
}
]
: []
};
}
return {
code: 200,
rawSearch,
guidePrompt: guidePrompt.value || '',
searchPrompts: [
{
obj: ChatRoleEnum.System,
value: `知识库:${systemPrompt}`
},
guidePrompt
]
};
}

View File

@@ -0,0 +1,32 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { authUser } from '@/service/utils/auth';
import { PgClient } from '@/service/pg';
import { withNextCors } from '@/service/utils/tools';
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
let { dataId } = req.query as {
dataId: string;
};
if (!dataId) {
throw new Error('缺少参数');
}
// 凭证校验
const { userId } = await authUser({ req });
await PgClient.delete('modelData', {
where: [['user_id', userId], 'AND', ['id', dataId]]
});
jsonRes(res);
} catch (err) {
console.log(err);
jsonRes(res, {
code: 500,
error: err
});
}
});

View File

@@ -0,0 +1,149 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, TrainingData } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { authKb } from '@/service/utils/auth';
import { withNextCors } from '@/service/utils/tools';
import { TrainingModeEnum } from '@/constants/plugin';
import { startQueue } from '@/service/utils/tools';
import { PgClient } from '@/service/pg';
type DateItemType = { a: string; q: string; source?: string };
export type Props = {
kbId: string;
data: DateItemType[];
mode: `${TrainingModeEnum}`;
prompt?: string;
};
export type Response = {
insertLen: number;
};
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { kbId, data, mode, prompt } = req.body as Props;
if (!kbId || !Array.isArray(data)) {
throw new Error('缺少参数');
}
await connectToDatabase();
// 凭证校验
const { userId } = await authUser({ req });
jsonRes<Response>(res, {
data: await pushDataToKb({
kbId,
data,
userId,
mode,
prompt
})
});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
});
export async function pushDataToKb({
userId,
kbId,
data,
mode,
prompt
}: { userId: string } & Props): Promise<Response> {
await authKb({
userId,
kbId
});
// 过滤重复的 qa 内容
const set = new Set();
const filterData: DateItemType[] = [];
data.forEach((item) => {
const text = item.q + item.a;
if (!set.has(text)) {
filterData.push(item);
set.add(text);
}
});
// 数据库去重
const insertData = (
await Promise.allSettled(
filterData.map(async ({ q, a = '', source }) => {
if (mode !== TrainingModeEnum.index) {
return Promise.resolve({
q,
a,
source
});
}
if (!q) {
return Promise.reject('q为空');
}
q = q.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
a = a.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
// Exactly the same data, not push
try {
const { rows } = await PgClient.query(`
SELECT COUNT(*) > 0 AS exists
FROM modelData
WHERE md5(q)=md5('${q}') AND md5(a)=md5('${a}') AND user_id='${userId}' AND kb_id='${kbId}'
`);
const exists = rows[0]?.exists || false;
if (exists) {
return Promise.reject('已经存在');
}
} catch (error) {
console.log(error);
error;
}
return Promise.resolve({
q,
a,
source
});
})
)
)
.filter((item) => item.status === 'fulfilled')
.map<DateItemType>((item: any) => item.value);
// 插入记录
await TrainingData.insertMany(
insertData.map((item) => ({
q: item.q,
a: item.a,
source: item.source,
userId,
kbId,
mode,
prompt
}))
);
insertData.length > 0 && startQueue();
return {
insertLen: insertData.length
};
}
export const config = {
api: {
bodyParser: {
sizeLimit: '20mb'
}
}
};

View File

@@ -0,0 +1,53 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { authUser } from '@/service/utils/auth';
import { PgClient } from '@/service/pg';
import { withNextCors } from '@/service/utils/tools';
import { openaiEmbedding } from '../plugin/openaiEmbedding';
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { dataId, a = '', q = '' } = req.body as { dataId: string; a?: string; q?: string };
if (!dataId) {
throw new Error('缺少参数');
}
// 凭证校验
const { userId } = await authUser({ req });
// get vector
const vector = await (async () => {
if (q) {
return openaiEmbedding({
userId,
input: [q],
type: 'chat'
});
}
return [];
})();
// 更新 pg 内容.仅修改a不需要更新向量。
await PgClient.update('modelData', {
where: [['id', dataId], 'AND', ['user_id', userId]],
values: [
{ key: 'source', value: '手动修改' },
{ key: 'a', value: a.replace(/'/g, '"') },
...(q
? [
{ key: 'q', value: q.replace(/'/g, '"') },
{ key: 'vector', value: `[${vector[0]}]` }
]
: [])
]
});
jsonRes(res);
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
});