diff --git a/src/pages/api/chat/chat.ts b/src/pages/api/chat/chat.ts index 9855803d8..dffee774b 100644 --- a/src/pages/api/chat/chat.ts +++ b/src/pages/api/chat/chat.ts @@ -57,61 +57,28 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // 使用了知识库搜索 if (model.chat.useKb) { - const { systemPrompts } = await searchKb_openai({ + const { code, searchPrompt } = await searchKb_openai({ apiKey: userApiKey || systemKey, isPay: !userApiKey, text: prompt.value, similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22, - modelId, + model, userId }); - // filter system prompt - if ( - systemPrompts.length === 0 && - model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity - ) { - return res.send('对不起,你的问题不在知识库中。'); + // search result is empty + if (code === 201) { + return res.send(searchPrompt?.value); } - /* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */ - if ( - systemPrompts.length === 0 && - model.chat.searchMode === ModelVectorSearchModeEnum.noContext - ) { - prompts.unshift({ - obj: 'SYSTEM', - value: model.chat.systemPrompt - }); - } else { - // 有匹配情况下,system 添加知识库内容。 - // 系统提示词过滤,最多 2500 tokens - const filterSystemPrompt = systemPromptFilter({ - model: model.chat.chatModel, - prompts: systemPrompts, - maxTokens: 2500 - }); - prompts.unshift({ - obj: 'SYSTEM', - value: ` - ${model.chat.systemPrompt} - ${ - model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity - ? `不回答知识库外的内容.` - : '' - } - 知识库内容为: ${filterSystemPrompt}' - ` - }); - } + searchPrompt && prompts.unshift(searchPrompt); } else { // 没有用知识库搜索,仅用系统提示词 - if (model.chat.systemPrompt) { + model.chat.systemPrompt && prompts.unshift({ obj: 'SYSTEM', value: model.chat.systemPrompt }); - } } // 控制总 tokens 数量,防止超出 diff --git a/src/pages/api/openapi/chat/chat.ts b/src/pages/api/openapi/chat/chat.ts index 38c3f770f..4c17c8434 100644 --- a/src/pages/api/openapi/chat/chat.ts +++ b/src/pages/api/openapi/chat/chat.ts @@ -67,57 +67,21 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) if (model.chat.useKb) { const similarity = ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22; - const { systemPrompts } = await searchKb_openai({ + const { code, searchPrompt } = await searchKb_openai({ apiKey, isPay: true, text: prompts[prompts.length - 1].value, similarity, - modelId, + model, userId }); - // filter system prompt - if ( - systemPrompts.length === 0 && - model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity - ) { - return jsonRes(res, { - code: 500, - message: '对不起,你的问题不在知识库中。', - data: '对不起,你的问题不在知识库中。' - }); + // search result is empty + if (code === 201) { + return res.send(searchPrompt?.value); } - /* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */ - if ( - systemPrompts.length === 0 && - model.chat.searchMode === ModelVectorSearchModeEnum.noContext - ) { - prompts.unshift({ - obj: 'SYSTEM', - value: model.chat.systemPrompt - }); - } else { - // 有匹配情况下,system 添加知识库内容。 - // 系统提示词过滤,最多 2500 tokens - const filterSystemPrompt = systemPromptFilter({ - model: model.chat.chatModel, - prompts: systemPrompts, - maxTokens: 2500 - }); - prompts.unshift({ - obj: 'SYSTEM', - value: ` - ${model.chat.systemPrompt} - ${ - model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity - ? `不回答知识库外的内容.` - : '' - } - 知识库内容为: ${filterSystemPrompt}' - ` - }); - } + searchPrompt && prompts.unshift(searchPrompt); } else { // 没有用知识库搜索,仅用系统提示词 if (model.chat.systemPrompt) { diff --git a/src/pages/api/openapi/chat/lafGpt.ts b/src/pages/api/openapi/chat/lafGpt.ts index 00d6b1a3d..711b2e7fc 100644 --- a/src/pages/api/openapi/chat/lafGpt.ts +++ b/src/pages/api/openapi/chat/lafGpt.ts @@ -131,26 +131,16 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const prompts = [prompt]; // 获取向量匹配到的提示词 - const { systemPrompts } = await searchKb_openai({ + const { searchPrompt } = await searchKb_openai({ isPay: true, apiKey, similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22, text: prompt.value, - modelId, + model, userId }); - // system 筛选,最多 2500 tokens - const filterSystemPrompt = systemPromptFilter({ - model: model.chat.chatModel, - prompts: systemPrompts, - maxTokens: 2500 - }); - - prompts.unshift({ - obj: 'SYSTEM', - value: `${model.chat.systemPrompt} 知识库是最新的,下面是知识库内容:${filterSystemPrompt}` - }); + searchPrompt && prompts.unshift(searchPrompt); // 控制上下文 tokens 数量,防止超出 const filterPrompts = openaiChatFilter({ @@ -181,8 +171,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } ); - console.log('code response. time:', `${(Date.now() - startTime) / 1000}s`); - let responseContent = ''; if (isStream) { diff --git a/src/pages/api/openapi/chat/vectorGpt.ts b/src/pages/api/openapi/chat/vectorGpt.ts index 46f3e86d1..20cdd4983 100644 --- a/src/pages/api/openapi/chat/vectorGpt.ts +++ b/src/pages/api/openapi/chat/vectorGpt.ts @@ -68,60 +68,21 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } // 获取向量匹配到的提示词 - const { systemPrompts } = await searchKb_openai({ + const { code, searchPrompt } = await searchKb_openai({ isPay: true, apiKey, similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22, text: prompts[prompts.length - 1].value, - modelId, + model, userId }); - // system 合并 - if (prompts[0].obj === 'SYSTEM') { - systemPrompts.unshift(prompts.shift()?.value || ''); + // search result is empty + if (code === 201) { + return res.send(searchPrompt?.value); } - /* 高相似度+退出,无法匹配时直接退出 */ - if ( - systemPrompts.length === 0 && - model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity - ) { - return jsonRes(res, { - code: 500, - message: '对不起,你的问题不在知识库中。', - data: '对不起,你的问题不在知识库中。' - }); - } - /* 高相似度+无上下文,不添加额外知识 */ - if ( - systemPrompts.length === 0 && - model.chat.searchMode === ModelVectorSearchModeEnum.noContext - ) { - prompts.unshift({ - obj: 'SYSTEM', - value: model.chat.systemPrompt - }); - } else { - // 有匹配或者低匹配度模式情况下,添加知识库内容。 - // 系统提示词过滤,最多 2500 tokens - const systemPrompt = systemPromptFilter({ - model: model.chat.chatModel, - prompts: systemPrompts, - maxTokens: 2500 - }); - - prompts.unshift({ - obj: 'SYSTEM', - value: ` -${model.chat.systemPrompt} -${ - model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity ? `不回答知识库外的内容.` : '' -} -知识库内容为: ${systemPrompt}' -` - }); - } + searchPrompt && prompts.unshift(searchPrompt); // 控制在 tokens 数量,防止超出 const filterPrompts = openaiChatFilter({ diff --git a/src/service/tools/searchKb.ts b/src/service/tools/searchKb.ts index d513de0a6..464fb1642 100644 --- a/src/service/tools/searchKb.ts +++ b/src/service/tools/searchKb.ts @@ -1,6 +1,8 @@ import { openaiCreateEmbedding } from '../utils/openai'; import { PgClient } from '@/service/pg'; -import { ModelDataStatusEnum } from '@/constants/model'; +import { ModelDataStatusEnum, ModelVectorSearchModeEnum } from '@/constants/model'; +import { ModelSchema } from '@/types/mongoSchema'; +import { systemPromptFilter } from '../utils/tools'; /** * use openai embedding search kb @@ -10,16 +12,22 @@ export const searchKb_openai = async ({ isPay, text, similarity, - modelId, + model, userId }: { apiKey: string; isPay: boolean; text: string; - modelId: string; + model: ModelSchema; userId: string; similarity: number; -}) => { +}): Promise<{ + code: 200 | 201; + searchPrompt?: { + obj: 'Human' | 'AI' | 'SYSTEM'; + value: string; + }; +}> => { // 获取提示词的向量 const { vector: promptVector } = await openaiCreateEmbedding({ isPay, @@ -28,12 +36,12 @@ export const searchKb_openai = async ({ text }); - const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', { - fields: ['id', 'q', 'a'], + const vectorSearch = await PgClient.select<{ q: string; a: string }>('modelData', { + fields: ['q', 'a'], where: [ ['status', ModelDataStatusEnum.ready], 'AND', - ['model_id', modelId], + ['model_id', model._id], 'AND', `vector <=> '[${promptVector}]' < ${similarity}` ], @@ -43,5 +51,51 @@ export const searchKb_openai = async ({ const systemPrompts: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`); - return { systemPrompts }; + // filter system prompt + if ( + systemPrompts.length === 0 && + model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity + ) { + return { + code: 201, + searchPrompt: { + obj: 'AI', + value: '对不起,你的问题不在知识库中。' + } + }; + } + /* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */ + if (systemPrompts.length === 0 && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) { + return { + code: 200, + searchPrompt: model.chat.systemPrompt + ? { + obj: 'SYSTEM', + value: model.chat.systemPrompt + } + : undefined + }; + } + + // 有匹配情况下,system 添加知识库内容。 + // 系统提示词过滤,最多 2500 tokens + const filterSystemPrompt = systemPromptFilter({ + model: model.chat.chatModel, + prompts: systemPrompts, + maxTokens: 2500 + }); + + return { + code: 200, + searchPrompt: { + obj: 'SYSTEM', + value: ` +${model.chat.systemPrompt} +${ + model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity ? `不回答知识库外的内容.` : '' +} +知识库内容为: ${filterSystemPrompt}' +` + } + }; };