diff --git a/public/docs/versionIntro.md b/public/docs/versionIntro.md
index cb0efa751..e513a4854 100644
--- a/public/docs/versionIntro.md
+++ b/public/docs/versionIntro.md
@@ -1,5 +1,6 @@
 ### Fast GPT V3.1
+- Improved - Knowledge base search now also includes the previous question in the search scope.
 - Improved - Model structure redesign: knowledge base models and chat models are no longer separate types; a switch lets you manually choose whether to run a knowledge base search.
 - New - Model sharing market, where you can use models shared by other users.
 - New - Invite-friends registration feature.
diff --git a/src/pages/api/chat/chat.ts b/src/pages/api/chat/chat.ts
index 4094872bc..66bf6a32e 100644
--- a/src/pages/api/chat/chat.ts
+++ b/src/pages/api/chat/chat.ts
@@ -58,7 +58,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     const { code, searchPrompt } = await searchKb({
       userApiKey,
       systemApiKey,
-      text: prompt.value,
+      prompts,
       similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
       model,
       userId
diff --git a/src/pages/api/openapi/chat/chat.ts b/src/pages/api/openapi/chat/chat.ts
index 5ca7fb0b0..5bbadf836 100644
--- a/src/pages/api/openapi/chat/chat.ts
+++ b/src/pages/api/openapi/chat/chat.ts
@@ -66,7 +66,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     const { code, searchPrompt } = await searchKb({
       systemApiKey: apiKey,
-      text: prompts[prompts.length - 1].value,
+      prompts,
       similarity,
       model,
       userId
diff --git a/src/pages/api/openapi/chat/lafGpt.ts b/src/pages/api/openapi/chat/lafGpt.ts
index 40fbd2822..a58b4518a 100644
--- a/src/pages/api/openapi/chat/lafGpt.ts
+++ b/src/pages/api/openapi/chat/lafGpt.ts
@@ -118,7 +118,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     const { searchPrompt } = await searchKb({
       systemApiKey: apiKey,
       similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
-      text: prompt.value,
+      prompts,
       model,
       userId
     });
diff --git a/src/service/events/generateVector.ts b/src/service/events/generateVector.ts
index e225eaff5..1c661c6b6 100644
--- a/src/service/events/generateVector.ts
+++ b/src/service/events/generateVector.ts
@@ -60,8 +60,8 @@ export async function generateVector(next = false): Promise {
   }
 
   // generate the embedding vector
-  const { vector } = await openaiCreateEmbedding({
-    text: dataItem.q,
+  const { vectors } = await openaiCreateEmbedding({
+    textArr: [dataItem.q],
     userId: dataItem.userId,
     userApiKey,
     systemApiKey
@@ -70,7 +70,7 @@ export async function generateVector(next = false): Promise {
   // update the vector and status data in pg
   await PgClient.update('modelData', {
     values: [
-      { key: 'vector', value: `[${vectors[0]}]` },
+      { key: 'vector', value: `[${vectors[0]}]` },
       { key: 'status', value: `ready` }
     ],
     where: [['id', dataId]]
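Note (not part of the patch): every embedding call site now passes an array and indexes into the returned `vectors`. Below is a minimal sketch of the single-item call shape that `generateVector.ts` now uses, assuming the updated `openaiCreateEmbedding` helper takes `textArr: string[]` and returns `vectors`, as the call above implies; the `embedOne` wrapper and its parameters are illustrative only.

```ts
// Illustrative only, not part of this diff: a hypothetical wrapper showing the
// new single-item call shape used by generateVector.ts.
import { openaiCreateEmbedding } from '@/service/utils/chat/openai';

export async function embedOne(q: string, userId: string, systemApiKey: string) {
  // the old `text: string` field becomes a one-element `textArr`
  const { vectors } = await openaiCreateEmbedding({
    systemApiKey,
    userId,
    textArr: [q]
  });

  // the old single `vector` result is now `vectors[0]`
  return vectors[0];
}
```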
diff --git a/src/service/plugins/searchKb.ts b/src/service/plugins/searchKb.ts
index 1682fab0c..48974ccaf 100644
--- a/src/service/plugins/searchKb.ts
+++ b/src/service/plugins/searchKb.ts
@@ -4,6 +4,7 @@ import { ModelSchema } from '@/types/mongoSchema';
 import { openaiCreateEmbedding } from '../utils/chat/openai';
 import { ChatRoleEnum } from '@/constants/chat';
 import { sliceTextByToken } from '@/utils/chat';
+import { ChatItemSimpleType } from '@/types/chat';
 
 /**
  * use openai embedding search kb
  */
@@ -11,14 +12,14 @@
 export const searchKb = async ({
   userApiKey,
   systemApiKey,
-  text,
+  prompts,
   similarity = 0.2,
   model,
   userId
 }: {
   userApiKey?: string;
   systemApiKey: string;
-  text: string;
+  prompts: ChatItemSimpleType[];
   model: ModelSchema;
   userId: string;
   similarity?: number;
@@ -29,30 +30,56 @@ export const searchKb = async ({
 }): Promise<{
   code: 200 | 201;
   searchPrompt?: {
     obj: ChatRoleEnum;
     value: string;
   };
 }> => {
+  async function search(textArr: string[] = []) {
+    // get the embedding vectors of the prompts
+    const { vectors: promptVectors } = await openaiCreateEmbedding({
+      userApiKey,
+      systemApiKey,
+      userId,
+      textArr
+    });
+
+    const searchRes = await Promise.all(
+      promptVectors.map((promptVector) =>
+        PgClient.select<{ id: string; q: string; a: string }>('modelData', {
+          fields: ['id', 'q', 'a'],
+          where: [
+            ['status', ModelDataStatusEnum.ready],
+            'AND',
+            ['model_id', model._id],
+            'AND',
+            `vector <=> '[${promptVector}]' < ${similarity}`
+          ],
+          order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
+          limit: 20
+        }).then((res) => res.rows)
+      )
+    );
+
+    // Remove repeat record
+    const idSet = new Set();
+    const filterSearch = searchRes.map((search) =>
+      search.filter((item) => {
+        if (idSet.has(item.id)) {
+          return false;
+        }
+        idSet.add(item.id);
+        return true;
+      })
+    );
+
+    return filterSearch.map((item) => item.map((item) => `${item.q}\n${item.a}`).join('\n'));
+  }
   const modelConstantsData = ChatModelMap[model.chat.chatModel];
 
-  // get the embedding vector of the prompt
-  const { vector: promptVector } = await openaiCreateEmbedding({
-    userApiKey,
-    systemApiKey,
-    userId,
-    text
-  });
+  // search three times
+  const userPrompts = prompts.filter((item) => item.obj === 'Human');
 
-  const vectorSearch = await PgClient.select<{ q: string; a: string }>('modelData', {
-    fields: ['q', 'a'],
-    where: [
-      ['status', ModelDataStatusEnum.ready],
-      'AND',
-      ['model_id', model._id],
-      'AND',
-      `vector <=> '[${promptVector}]' < ${similarity}`
-    ],
-    order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
-    limit: 20
-  });
-
-  const systemPrompts: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
+  const searchArr: string[] = [
+    userPrompts[userPrompts.length - 1].value,
+    userPrompts[userPrompts.length - 2]?.value
+  ].filter((item) => item);
+  const systemPrompts = await search(searchArr);
 
   // filter system prompt
   if (
@@ -80,13 +107,24 @@ export const searchKb = async ({
     };
   }
 
-  // when there are matches, add the kb content to the system prompt
-  // filter the system prompt, max 65% tokens
-  const filterSystemPrompt = sliceTextByToken({
-    model: model.chat.chatModel,
-    text: systemPrompts.join('\n'),
-    length: Math.floor(modelConstantsData.contextMaxToken * 0.65)
-  });
+  /* when there are matches, add the kb content to the system prompt */
+
+  // filter system prompts. max 70% tokens
+  const filterRateMap: Record<number, number[]> = {
+    1: [0.7],
+    2: [0.5, 0.2]
+  };
+  const filterRate = filterRateMap[systemPrompts.length] || filterRateMap[0];
+
+  const filterSystemPrompt = filterRate
+    .map((rate, i) =>
+      sliceTextByToken({
+        model: model.chat.chatModel,
+        text: systemPrompts[i],
+        length: Math.floor(modelConstantsData.contextMaxToken * rate)
+      })
+    )
+    .join('\n');
 
   return {
     code: 200,
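Note (not part of the patch): the new `search` helper embeds the last one or two user questions, runs one pgvector query per embedding, drops hits whose `id` already appeared in an earlier result set, and then gives each result block its own share of the context window (70% for a single block, 50% and 20% when there are two). The sketch below isolates that dedupe-and-budget idea with plain types and a plain token budget in place of `sliceTextByToken`; `KbHit`, `dedupeAcrossSearches`, and `blockTokenBudgets` are hypothetical names, not project code.

```ts
// Illustrative only, not part of this diff.
type KbHit = { id: string; q: string; a: string };

// Keep each record only the first time its id appears across the per-question result sets.
export function dedupeAcrossSearches(resultSets: KbHit[][]): KbHit[][] {
  const seen = new Set<string>();
  return resultSets.map((hits) =>
    hits.filter((hit) => {
      if (seen.has(hit.id)) return false;
      seen.add(hit.id);
      return true;
    })
  );
}

// One block keeps up to 70% of the context tokens; two blocks keep 50% and 20%.
export function blockTokenBudgets(blockCount: number, contextMaxToken: number): number[] {
  const rates: Record<number, number[]> = { 1: [0.7], 2: [0.5, 0.2] };
  return (rates[blockCount] || []).map((rate) => Math.floor(contextMaxToken * rate));
}
```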
diff --git a/src/service/utils/chat/openai.ts b/src/service/utils/chat/openai.ts
index db8ad15dc..4ae804505 100644
--- a/src/service/utils/chat/openai.ts
+++ b/src/service/utils/chat/openai.ts
@@ -22,12 +22,12 @@ export const openaiCreateEmbedding = async ({
   userApiKey,
   systemApiKey,
   userId,
-  text
+  textArr
 }: {
   userApiKey?: string;
   systemApiKey: string;
   userId: string;
-  text: string;
+  textArr: string[];
 }) => {
   // get chatAPI
   const chatAPI = getOpenAIApi(userApiKey || systemApiKey);
@@ -37,7 +37,7 @@ export const openaiCreateEmbedding = async ({
     .createEmbedding(
       {
         model: embeddingModel,
-        input: text
+        input: textArr
       },
       {
         timeout: 60000,
@@ -46,18 +46,18 @@
     )
     .then((res) => ({
       tokenLen: res.data.usage.total_tokens || 0,
-      vector: res.data.data?.[0]?.embedding || []
+      vectors: res.data.data.map((item) => item.embedding)
     }));
 
   pushGenerateVectorBill({
     isPay: !userApiKey,
     userId,
-    text,
+    text: textArr.join(''),
     tokenLen: res.tokenLen
   });
 
   return {
-    vector: res.vector,
+    vectors: res.vectors,
     chatAPI
   };
 };
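Note (not part of the patch): `openaiCreateEmbedding` now forwards an array as `input`, and the embeddings response returns one embedding per input element, so both recent questions can be embedded in a single round trip. A sketch of how a caller could consume that, assuming the updated signature above; `embedQuestions` is an illustrative wrapper, not project code.

```ts
// Illustrative only, not part of this diff: embedding several questions in one
// call and pairing each question with its vector by index.
import { openaiCreateEmbedding } from '@/service/utils/chat/openai';

export async function embedQuestions(
  questions: string[], // e.g. the last one or two 'Human' prompts
  userId: string,
  systemApiKey: string
) {
  const { vectors } = await openaiCreateEmbedding({
    systemApiKey,
    userId,
    textArr: questions
  });

  // vectors[i] corresponds to questions[i]
  return questions.map((question, i) => ({ question, vector: vectors[i] }));
}
```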