perf: kb-add last question to search

This commit is contained in:
archer
2023-05-03 18:38:59 +08:00
parent e384893ae0
commit 17a42ac0cc
7 changed files with 81 additions and 42 deletions

View File

@@ -1,5 +1,6 @@
### Fast GPT V3.1 ### Fast GPT V3.1
- 优化 - 知识库搜索,会将上一个问题并入搜索范围。
- 优化 - 模型结构设计,不再区分知识库和对话模型,而是通过开关的形式,手动选择手否需要进行知识库搜索。 - 优化 - 模型结构设计,不再区分知识库和对话模型,而是通过开关的形式,手动选择手否需要进行知识库搜索。
- 新增 - 模型共享市场,可以使用其他用户分享的模型。 - 新增 - 模型共享市场,可以使用其他用户分享的模型。
- 新增 - 邀请好友注册功能。 - 新增 - 邀请好友注册功能。

View File

@@ -58,7 +58,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { code, searchPrompt } = await searchKb({ const { code, searchPrompt } = await searchKb({
userApiKey, userApiKey,
systemApiKey, systemApiKey,
text: prompt.value, prompts,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity, similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
model, model,
userId userId

View File

@@ -66,7 +66,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { code, searchPrompt } = await searchKb({ const { code, searchPrompt } = await searchKb({
systemApiKey: apiKey, systemApiKey: apiKey,
text: prompts[prompts.length - 1].value, prompts,
similarity, similarity,
model, model,
userId userId

View File

@@ -118,7 +118,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { searchPrompt } = await searchKb({ const { searchPrompt } = await searchKb({
systemApiKey: apiKey, systemApiKey: apiKey,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity, similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
text: prompt.value, prompts,
model, model,
userId userId
}); });

View File

@@ -60,8 +60,8 @@ export async function generateVector(next = false): Promise<any> {
} }
// 生成词向量 // 生成词向量
const { vector } = await openaiCreateEmbedding({ const { vectors } = await openaiCreateEmbedding({
text: dataItem.q, textArr: [dataItem.q],
userId: dataItem.userId, userId: dataItem.userId,
userApiKey, userApiKey,
systemApiKey systemApiKey
@@ -70,7 +70,7 @@ export async function generateVector(next = false): Promise<any> {
// 更新 pg 向量和状态数据 // 更新 pg 向量和状态数据
await PgClient.update('modelData', { await PgClient.update('modelData', {
values: [ values: [
{ key: 'vector', value: `[${vector}]` }, { key: 'vector', value: `[${vectors[0]}]` },
{ key: 'status', value: `ready` } { key: 'status', value: `ready` }
], ],
where: [['id', dataId]] where: [['id', dataId]]

View File

@@ -4,6 +4,7 @@ import { ModelSchema } from '@/types/mongoSchema';
import { openaiCreateEmbedding } from '../utils/chat/openai'; import { openaiCreateEmbedding } from '../utils/chat/openai';
import { ChatRoleEnum } from '@/constants/chat'; import { ChatRoleEnum } from '@/constants/chat';
import { sliceTextByToken } from '@/utils/chat'; import { sliceTextByToken } from '@/utils/chat';
import { ChatItemSimpleType } from '@/types/chat';
/** /**
* use openai embedding search kb * use openai embedding search kb
@@ -11,14 +12,14 @@ import { sliceTextByToken } from '@/utils/chat';
export const searchKb = async ({ export const searchKb = async ({
userApiKey, userApiKey,
systemApiKey, systemApiKey,
text, prompts,
similarity = 0.2, similarity = 0.2,
model, model,
userId userId
}: { }: {
userApiKey?: string; userApiKey?: string;
systemApiKey: string; systemApiKey: string;
text: string; prompts: ChatItemSimpleType[];
model: ModelSchema; model: ModelSchema;
userId: string; userId: string;
similarity?: number; similarity?: number;
@@ -29,30 +30,56 @@ export const searchKb = async ({
value: string; value: string;
}; };
}> => { }> => {
async function search(textArr: string[] = []) {
// 获取提示词的向量
const { vectors: promptVectors } = await openaiCreateEmbedding({
userApiKey,
systemApiKey,
userId,
textArr
});
const searchRes = await Promise.all(
promptVectors.map((promptVector) =>
PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
where: [
['status', ModelDataStatusEnum.ready],
'AND',
['model_id', model._id],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
limit: 20
}).then((res) => res.rows)
)
);
// Remove repeat record
const idSet = new Set<string>();
const filterSearch = searchRes.map((search) =>
search.filter((item) => {
if (idSet.has(item.id)) {
return false;
}
idSet.add(item.id);
return true;
})
);
return filterSearch.map((item) => item.map((item) => `${item.q}\n${item.a}`).join('\n'));
}
const modelConstantsData = ChatModelMap[model.chat.chatModel]; const modelConstantsData = ChatModelMap[model.chat.chatModel];
// 获取提示词的向量 // search three times
const { vector: promptVector } = await openaiCreateEmbedding({ const userPrompts = prompts.filter((item) => item.obj === 'Human');
userApiKey,
systemApiKey,
userId,
text
});
const vectorSearch = await PgClient.select<{ q: string; a: string }>('modelData', { const searchArr: string[] = [
fields: ['q', 'a'], userPrompts[userPrompts.length - 1].value,
where: [ userPrompts[userPrompts.length - 2]?.value
['status', ModelDataStatusEnum.ready], ].filter((item) => item);
'AND', const systemPrompts = await search(searchArr);
['model_id', model._id],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
limit: 20
});
const systemPrompts: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
// filter system prompt // filter system prompt
if ( if (
@@ -80,13 +107,24 @@ export const searchKb = async ({
}; };
} }
// 有匹配情况下system 添加知识库内容。 /* 有匹配情况下system 添加知识库内容。 */
// 系统提示词过滤,最多 65% tokens
const filterSystemPrompt = sliceTextByToken({ // filter system prompts. max 70% tokens
model: model.chat.chatModel, const filterRateMap: Record<number, number[]> = {
text: systemPrompts.join('\n'), 1: [0.7],
length: Math.floor(modelConstantsData.contextMaxToken * 0.65) 2: [0.5, 0.2]
}); };
const filterRate = filterRateMap[systemPrompts.length] || filterRateMap[0];
const filterSystemPrompt = filterRate
.map((rate, i) =>
sliceTextByToken({
model: model.chat.chatModel,
text: systemPrompts[i],
length: Math.floor(modelConstantsData.contextMaxToken * rate)
})
)
.join('\n');
return { return {
code: 200, code: 200,

View File

@@ -22,12 +22,12 @@ export const openaiCreateEmbedding = async ({
userApiKey, userApiKey,
systemApiKey, systemApiKey,
userId, userId,
text textArr
}: { }: {
userApiKey?: string; userApiKey?: string;
systemApiKey: string; systemApiKey: string;
userId: string; userId: string;
text: string; textArr: string[];
}) => { }) => {
// 获取 chatAPI // 获取 chatAPI
const chatAPI = getOpenAIApi(userApiKey || systemApiKey); const chatAPI = getOpenAIApi(userApiKey || systemApiKey);
@@ -37,7 +37,7 @@ export const openaiCreateEmbedding = async ({
.createEmbedding( .createEmbedding(
{ {
model: embeddingModel, model: embeddingModel,
input: text input: textArr
}, },
{ {
timeout: 60000, timeout: 60000,
@@ -46,18 +46,18 @@ export const openaiCreateEmbedding = async ({
) )
.then((res) => ({ .then((res) => ({
tokenLen: res.data.usage.total_tokens || 0, tokenLen: res.data.usage.total_tokens || 0,
vector: res.data.data?.[0]?.embedding || [] vectors: res.data.data.map((item) => item.embedding)
})); }));
pushGenerateVectorBill({ pushGenerateVectorBill({
isPay: !userApiKey, isPay: !userApiKey,
userId, userId,
text, text: textArr.join(''),
tokenLen: res.tokenLen tokenLen: res.tokenLen
}); });
return { return {
vector: res.vector, vectors: res.vectors,
chatAPI chatAPI
}; };
}; };