diff --git a/packages/service/common/vectorDB/controller.ts b/packages/service/common/vectorDB/controller.ts
index 6c25ea3cc..044551746 100644
--- a/packages/service/common/vectorDB/controller.ts
+++ b/packages/service/common/vectorDB/controller.ts
@@ -67,7 +67,7 @@ export const insertDatasetDataVector = async ({
   return retryFn(async () => {
     const { vectors, tokens } = await getVectorsByText({
       model,
-      input: query,
+      input: [query],
       type: 'db'
     });
    const { insertId } = await Vector.insert({
diff --git a/packages/service/core/ai/embedding/index.ts b/packages/service/core/ai/embedding/index.ts
index 45599c173..80e095c9f 100644
--- a/packages/service/core/ai/embedding/index.ts
+++ b/packages/service/core/ai/embedding/index.ts
@@ -6,14 +6,14 @@ import { addLog } from '../../../common/system/log';
 
 type GetVectorProps = {
   model: EmbeddingModelItemType;
-  input: string;
+  input: string[];
   type?: `${EmbeddingTypeEnm}`;
   headers?: Record<string, string>;
 };
 
 // text to vector
 export async function getVectorsByText({ model, input, type, headers }: GetVectorProps) {
-  if (!input) {
+  if (!input || input.length === 0) {
     return Promise.reject({
       code: 500,
       message: 'input is empty'
@@ -31,7 +31,7 @@ export async function getVectorsByText({ model, input, type, headers }: GetVecto
         ...(type === EmbeddingTypeEnm.db && model.dbConfig),
         ...(type === EmbeddingTypeEnm.query && model.queryConfig),
         model: model.model,
-        input: [input]
+        input
       },
       model.requestUrl
         ? {
@@ -55,7 +55,12 @@ export async function getVectorsByText({ model, input, type, headers }: GetVecto
     }
 
     const [tokens, vectors] = await Promise.all([
-      countPromptTokens(input),
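+      // Prefer the token usage reported by the embedding API; fall back to local counting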
+      (async () => {
+        if (res.usage) return res.usage.total_tokens;
+
+        const tokens = await Promise.all(input.map((item) => countPromptTokens(item)));
+        return tokens.reduce((sum, item) => sum + item, 0);
+      })(),
       Promise.all(
         res.data
           .map((item) => unityDimensional(item.embedding))
diff --git a/packages/service/core/ai/functions/queryExtension.ts b/packages/service/core/ai/functions/queryExtension.ts
index 812b31cca..7263c53a9 100644
--- a/packages/service/core/ai/functions/queryExtension.ts
+++ b/packages/service/core/ai/functions/queryExtension.ts
@@ -9,155 +9,170 @@ import { llmCompletionsBodyFormat, formatLLMResponse } from '../utils';
 import { addLog } from '../../../common/system/log';
 import { filterGPTMessageByMaxContext } from '../../chat/utils';
 import json5 from 'json5';
+import type { EmbeddingModelItemType } from '@fastgpt/global/core/ai/model.d';
 
 /*
-  Query Extension - Semantic Search Enhancement
-
-  This module can eliminate referential ambiguity and expand queries based on context to improve retrieval.
-
-  Submodular Optimization Mode: Generate multiple candidate queries, then use submodular algorithm to select the optimal query combination
+  @https://github.com/jina-ai/submodular-optimization/blob/main/submodular_optimization.js
+  Query Extension - Semantic Search Enhancement
+  This module can eliminate referential ambiguity and expand queries based on context to improve retrieval.
+  Submodular Optimization Mode: Generate multiple candidate queries, then use submodular algorithm to select the optimal query combination
 */
+async function queriesFilter({
+  queries,
+  embeddingModelData
+}: {
+  queries: string[];
+  embeddingModelData: EmbeddingModelItemType;
+}): Promise<{
+  tokens: number;
+  queries: string[];
+}> {
+  if (queries.length < 5) {
+    return {
+      queries,
+      tokens: 0
+    };
+  }
+  // Priority Queue implementation for submodular optimization
+  class PriorityQueue<T> {
+    private heap: Array<{ item: T; priority: number }> = [];
-// Priority Queue implementation for submodular optimization
-class PriorityQueue<T> {
-  private heap: Array<{ item: T; priority: number }> = [];
+    enqueue(item: T, priority: number): void {
+      this.heap.push({ item, priority });
+      this.heap.sort((a, b) => b.priority - a.priority);
+    }
-  enqueue(item: T, priority: number): void {
-    this.heap.push({ item, priority });
-    this.heap.sort((a, b) => b.priority - a.priority);
+    dequeue(): T | undefined {
+      return this.heap.shift()?.item;
+    }
+
+    isEmpty(): boolean {
+      return this.heap.length === 0;
+    }
+
+    size(): number {
+      return this.heap.length;
+    }
+  }
-  dequeue(): T | undefined {
-    return this.heap.shift()?.item;
+  // Calculate cosine similarity
+  function cosineSimilarity(a: number[], b: number[]): number {
+    if (a.length !== b.length) {
+      throw new Error('Vectors must have the same length');
+    }
+
+    let dotProduct = 0;
+    let normA = 0;
+    let normB = 0;
+
+    for (let i = 0; i < a.length; i++) {
+      dotProduct += a[i] * b[i];
+      normA += a[i] * a[i];
+      normB += b[i] * b[i];
+    }
+
+    if (normA === 0 || normB === 0) return 0;
+    return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
+  }
-  isEmpty(): boolean {
-    return this.heap.length === 0;
+  // Calculate marginal gain
+  function computeMarginalGain(
+    candidateEmbedding: number[],
+    selectedEmbeddings: number[][],
+    originalEmbedding: number[],
+    alpha: number = 0.3
+  ): number {
+    if (selectedEmbeddings.length === 0) {
+      return alpha * cosineSimilarity(originalEmbedding, candidateEmbedding);
+    }
+
+    let maxSimilarity = 0;
+    for (const selectedEmbedding of selectedEmbeddings) {
+      const similarity = cosineSimilarity(candidateEmbedding, selectedEmbedding);
+      maxSimilarity = Math.max(maxSimilarity, similarity);
+    }
+
+    const relevance = alpha * cosineSimilarity(originalEmbedding, candidateEmbedding);
+    const diversity = 1 - maxSimilarity;
+
+    return relevance + diversity;
+  }
-  size(): number {
-    return this.heap.length;
-  }
-}
+  // Lazy greedy query selection algorithm
+  function lazyGreedyQuerySelection(
+    candidates: string[],
+    embeddings: number[][],
+    originalEmbedding: number[],
+    k: number,
+    alpha: number = 0.3
+  ): string[] {
+    const n = candidates.length;
+    const selected: string[] = [];
+    const selectedEmbeddings: number[][] = [];
-// Calculate cosine similarity
-function cosineSimilarity(a: number[], b: number[]): number {
-  if (a.length !== b.length) {
-    throw new Error('Vectors must have the same length');
-  }
+    // Initialize priority queue
+    const pq = new PriorityQueue<{ index: number; gain: number }>();
-  let dotProduct = 0;
-  let normA = 0;
-  let normB = 0;
+    // Calculate initial marginal gain for all candidates
+    for (let i = 0; i < n; i++) {
+      const gain = computeMarginalGain(embeddings[i], selectedEmbeddings, originalEmbedding, alpha);
+      pq.enqueue({ index: i, gain }, gain);
+    }
-  for (let i = 0; i < a.length; i++) {
-    dotProduct += a[i] * b[i];
-    normA += a[i] * a[i];
-    normB += b[i] * b[i];
-  }
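+    // Lazy re-evaluation: each popped entry's cached gain is recomputed against the
+    // current selection; if it still holds up it is accepted, otherwise it is re-queued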
+    // Greedy selection
+    for (let iteration = 0; iteration < k; iteration++) {
+      if (pq.isEmpty()) break;
-  if (normA === 0 || normB === 0) return 0;
-  return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
-}
+      let bestCandidate: { index: number; gain: number } | undefined;
-// Calculate marginal gain
-function computeMarginalGain(
-  candidateEmbedding: number[],
-  selectedEmbeddings: number[][],
-  originalEmbedding: number[],
-  alpha: number = 0.3
-): number {
-  if (selectedEmbeddings.length === 0) {
-    return alpha * cosineSimilarity(originalEmbedding, candidateEmbedding);
-  }
+      // Find candidate with maximum marginal gain
+      while (!pq.isEmpty()) {
+        const candidate = pq.dequeue()!;
+        const currentGain = computeMarginalGain(
+          embeddings[candidate.index],
+          selectedEmbeddings,
+          originalEmbedding,
+          alpha
+        );
-  let maxSimilarity = 0;
-  for (const selectedEmbedding of selectedEmbeddings) {
-    const similarity = cosineSimilarity(candidateEmbedding, selectedEmbedding);
-    maxSimilarity = Math.max(maxSimilarity, similarity);
-  }
+        if (currentGain >= candidate.gain) {
+          bestCandidate = { index: candidate.index, gain: currentGain };
+          break;
+        } else {
+          pq.enqueue(candidate, currentGain);
+        }
+      }
-  const relevance = alpha * cosineSimilarity(originalEmbedding, candidateEmbedding);
-  const diversity = 1 - maxSimilarity;
-
-  return relevance + diversity;
-}
-
-// Lazy greedy query selection algorithm
-function lazyGreedyQuerySelection(
-  candidates: string[],
-  embeddings: number[][],
-  originalEmbedding: number[],
-  k: number,
-  alpha: number = 0.3
-): string[] {
-  const n = candidates.length;
-  const selected: string[] = [];
-  const selectedEmbeddings: number[][] = [];
-
-  // Initialize priority queue
-  const pq = new PriorityQueue<{ index: number; gain: number }>();
-
-  // Calculate initial marginal gain for all candidates
-  for (let i = 0; i < n; i++) {
-    const gain = computeMarginalGain(embeddings[i], selectedEmbeddings, originalEmbedding, alpha);
-    pq.enqueue({ index: i, gain }, gain);
-  }
-
-  // Greedy selection
-  for (let iteration = 0; iteration < k; iteration++) {
-    if (pq.isEmpty()) break;
-
-    let bestCandidate: { index: number; gain: number } | undefined;
-
-    // Find candidate with maximum marginal gain
-    while (!pq.isEmpty()) {
-      const candidate = pq.dequeue()!;
-      const currentGain = computeMarginalGain(
-        embeddings[candidate.index],
-        selectedEmbeddings,
-        originalEmbedding,
-        alpha
-      );
-
-      if (currentGain >= candidate.gain) {
-        bestCandidate = { index: candidate.index, gain: currentGain };
-        break;
-      } else {
-        pq.enqueue(candidate, currentGain);
+      if (bestCandidate) {
+        selected.push(candidates[bestCandidate.index]);
+        selectedEmbeddings.push(embeddings[bestCandidate.index]);
       }
     }
-    if (bestCandidate) {
-      selected.push(candidates[bestCandidate.index]);
-      selectedEmbeddings.push(embeddings[bestCandidate.index]);
-    }
+    return selected;
   }
-  return selected;
-}
+  const { vectors, tokens } = await getVectorsByText({
+    model: embeddingModelData,
+    input: queries,
+    type: 'query'
+  });
-// Generate embeddings for input texts
-async function generateEmbeddings(texts: string[], model: string): Promise<number[][]> {
-  try {
-    const vectorModel = getEmbeddingModel(model);
-    const embeddings: number[][] = [];
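+  // vectors[0] is the embedding of the original query (queries[0] at the call site)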
+  const originalEmbedding = vectors[0];
+  // Select optimal queries using lazy greedy algorithm
+  // (embeddings must stay index-aligned with `queries`, so the full vector list is passed)
+  const selectedQueries = lazyGreedyQuerySelection(
+    queries,
+    vectors,
+    originalEmbedding,
+    Math.min(5, queries.length), // Select top 5 queries or less
+    0.3 // alpha parameter for balancing relevance and diversity
+  );
-    for (const text of texts) {
-      // Use vector model's createEmbedding method
-      const embedding = await getVectorsByText({
-        model: vectorModel,
-        input: text,
-        type: 'query'
-      });
-      embeddings.push(embedding.vectors[0]);
-    }
-
-    return embeddings;
-  } catch (error) {
-    addLog.warn('Failed to generate embeddings', { error, model });
-    throw error;
-  }
+  return {
+    queries: selectedQueries,
+    tokens
+  };
 }
 
 const title = global.feConfigs?.systemTitle || 'FastAI';
@@ -254,7 +269,7 @@ assistant: Laf 是一个云函数开发平台。
 1. 输出格式为 JSON 数组,数组中每个元素为字符串。无需对输出进行任何解释。
 2. 输出语言与原问题相同。原问题为中文则输出中文;原问题为英文则输出英文。
-3. 确保生成恰好 {{count}} 个检索词。
+3. 确保生成恰好 10 个检索词。
 
 ## 开始任务
@@ -269,20 +284,22 @@ export const queryExtension = async ({
   chatBg,
   query,
   histories = [],
-  model,
-  generateCount = 10 // 添加生成数量参数,默认为10个
+  llmModel,
+  embeddingModel
 }: {
   chatBg?: string;
   query: string;
   histories: ChatItemType[];
-  model: string;
-  generateCount?: number;
+  llmModel: string;
+  embeddingModel: string;
 }): Promise<{
   rawQuery: string;
   extensionQueries: string[];
-  model: string;
+  llmModel: string;
   inputTokens: number;
   outputTokens: number;
+  embeddingTokens: number;
+  embeddingModel: string;
 }> => {
   const systemFewShot = chatBg
     ? `user: 对话背景。
@@ -290,10 +307,12 @@ assistant: ${chatBg}
 `
     : '';
 
-  const modelData = getLLMModel(model);
+  const llmModelData = getLLMModel(llmModel);
+  const embeddingModelData = getEmbeddingModel(embeddingModel);
+
   const filterHistories = await filterGPTMessageByMaxContext({
     messages: chats2GPTMessages({ messages: histories, reserveId: false }),
-    maxContext: modelData.maxContext - 1000
+    maxContext: llmModelData.maxContext - 1000
   });
 
   const historyFewShot = filterHistories
@@ -317,8 +336,7 @@ assistant: ${chatBg}
       role: 'user',
       content: replaceVariable(defaultPrompt, {
         query: `${query}`,
-        histories: concatFewShot || 'null',
-        count: generateCount.toString()
+        histories: concatFewShot || 'null'
       })
     }
   ] as any;
@@ -327,11 +345,11 @@ assistant: ${chatBg}
     body: llmCompletionsBodyFormat(
       {
         stream: true,
-        model: modelData.model,
+        model: llmModelData.model,
         temperature: 0.1,
         messages
       },
-      modelData
+      llmModelData
     )
   });
   const { text: answer, usage } = await formatLLMResponse(response);
@@ -342,9 +360,11 @@ assistant: ${chatBg}
     return {
       rawQuery: query,
       extensionQueries: [],
-      model,
+      llmModel: llmModelData.model,
       inputTokens: inputTokens,
-      outputTokens: outputTokens
+      outputTokens: outputTokens,
+      embeddingModel: embeddingModelData.model,
+      embeddingTokens: 0
     };
   }
@@ -357,9 +377,11 @@ assistant: ${chatBg}
     return {
       rawQuery: query,
       extensionQueries: [],
-      model,
+      llmModel: llmModelData.model,
       inputTokens: inputTokens,
-      outputTokens: outputTokens
+      outputTokens: outputTokens,
+      embeddingModel: embeddingModelData.model,
+      embeddingTokens: 0
     };
   }
@@ -376,44 +398,41 @@ assistant: ${chatBg}
     return {
       rawQuery: query,
       extensionQueries: [],
-      model,
+      llmModel: llmModelData.model,
       inputTokens,
-      outputTokens
+      outputTokens,
+      embeddingModel: embeddingModelData.model,
+      embeddingTokens: 0
     };
   }
 
     // Generate embeddings for original query and candidate queries
-    const allQueries = [query, ...queries];
-    const embeddings = await generateEmbeddings(allQueries, model);
-    const originalEmbedding = embeddings[0];
-    const candidateEmbeddings = embeddings.slice(1);
-    // Select optimal queries using lazy greedy algorithm
-    const selectedQueries = lazyGreedyQuerySelection(
-      queries,
-      candidateEmbeddings,
-      originalEmbedding,
-      Math.min(5, queries.length), // Select top 5 queries or less
-      0.3 // alpha parameter for balancing relevance and diversity
-    );
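+    // The raw query is prepended as the relevance anchor; filter(Boolean) drops empty strings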
+    // Filter query
+    const { queries: filteredQueries, tokens: embeddingTokens } = await queriesFilter({
+      queries: [query, ...queries].filter(Boolean),
+      embeddingModelData
+    });
 
     return {
       rawQuery: query,
-      extensionQueries: selectedQueries,
-      model,
+      extensionQueries: filteredQueries,
+      llmModel: llmModelData.model,
       inputTokens,
-      outputTokens
+      outputTokens,
+      embeddingModel: embeddingModelData.model,
+      embeddingTokens
     };
   } catch (error) {
-    addLog.warn('Query extension failed', {
-      error,
-      answer
-    });
+    addLog.error(`Query extension failed, answer: ${answer}`, error);
     return {
       rawQuery: query,
       extensionQueries: [],
-      model,
+      llmModel: llmModelData.model,
       inputTokens,
-      outputTokens
+      outputTokens,
+      embeddingModel: embeddingModelData.model,
+      embeddingTokens: 0
     };
   }
 };
diff --git a/packages/service/core/dataset/search/controller.ts b/packages/service/core/dataset/search/controller.ts
index 87ae0747c..297ec4c15 100644
--- a/packages/service/core/dataset/search/controller.ts
+++ b/packages/service/core/dataset/search/controller.ts
@@ -447,7 +447,7 @@ export async function searchDatasetData(
   }) => {
     const { vectors, tokens } = await getVectorsByText({
       model: getEmbeddingModel(model),
-      input: query,
+      input: [query],
       type: 'query'
     });
 
@@ -885,6 +885,7 @@ export const defaultSearchDatasetData = async ({
   const { concatQueries, extensionQueries, rewriteQuery, aiExtensionResult } =
     await datasetSearchQueryExtension({
       query,
+      embeddingModel: props.model,
       extensionModel,
       extensionBg: datasetSearchExtensionBg,
       histories
     });
@@ -898,9 +899,10 @@ export const defaultSearchDatasetData = async ({
 
   return {
     ...result,
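+    // Include embedding tokens spent by query extension in the reported total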
+    embeddingTokens: result.embeddingTokens + (aiExtensionResult?.embeddingTokens || 0),
     queryExtensionResult: aiExtensionResult
       ? {
-          model: aiExtensionResult.model,
+          model: aiExtensionResult.llmModel,
           inputTokens: aiExtensionResult.inputTokens,
           outputTokens: aiExtensionResult.outputTokens,
           query: extensionQueries.join('\n')
diff --git a/packages/service/core/dataset/search/utils.ts b/packages/service/core/dataset/search/utils.ts
index 410409f3d..f5b3a0e95 100644
--- a/packages/service/core/dataset/search/utils.ts
+++ b/packages/service/core/dataset/search/utils.ts
@@ -6,11 +6,13 @@ import { chatValue2RuntimePrompt } from '@fastgpt/global/core/chat/adapt';
 
 export const datasetSearchQueryExtension = async ({
   query,
+  embeddingModel,
   extensionModel,
   extensionBg = '',
   histories = []
 }: {
   query: string;
+  embeddingModel: string;
   extensionModel?: LLMModelItemType;
   extensionBg?: string;
   histories?: ChatItemType[];
@@ -67,7 +69,8 @@ Human: ${query}
         chatBg: extensionBg,
         query,
         histories,
-        model: extensionModel.model
+        llmModel: extensionModel.model,
+        embeddingModel
       });
       if (result.extensionQueries?.length === 0) return;
       return result;
diff --git a/projects/app/src/pages/api/core/ai/model/test.ts b/projects/app/src/pages/api/core/ai/model/test.ts
index 974319619..67ed9ba1d 100644
--- a/projects/app/src/pages/api/core/ai/model/test.ts
+++ b/projects/app/src/pages/api/core/ai/model/test.ts
@@ -102,7 +102,7 @@ const testEmbeddingModel = async (
   headers: Record<string, string>
 ) => {
   return getVectorsByText({
-    input: 'Hi',
+    input: ['Hi'],
     model,
     headers
   });
diff --git a/projects/app/src/pages/api/v1/embeddings.ts b/projects/app/src/pages/api/v1/embeddings.ts
index a6d42bb51..76ffe7b12 100644
--- a/projects/app/src/pages/api/v1/embeddings.ts
+++ b/projects/app/src/pages/api/v1/embeddings.ts
@@ -35,7 +35,7 @@ async function handler(req: NextApiRequest, res: NextApiResponse) {
     await checkTeamAIPoints(teamId);
 
     const { tokens, vectors } = await getVectorsByText({
-      input: query,
+      input: [query],
       model: getEmbeddingModel(model),
       type
     });