perf: query extension code

Author: archer
Date: 2025-07-22 17:58:04 +08:00
Parent: 6d66d0626d
Commit: 68c2fcd713
7 changed files with 202 additions and 173 deletions

View File

@@ -67,7 +67,7 @@ export const insertDatasetDataVector = async ({
   return retryFn(async () => {
     const { vectors, tokens } = await getVectorsByText({
       model,
-      input: query,
+      input: [query],
       type: 'db'
     });
     const { insertId } = await Vector.insert({

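This hunk shows the commit's recurring mechanical change: getVectorsByText now takes a string[] instead of a single string, so single-text callers wrap their input in a one-element array. A minimal sketch of the new call shape (model and query are assumed to come from the surrounding scope):

    // Embedding one text under the new batched signature.
    const { vectors, tokens } = await getVectorsByText({
      model, // the dataset's EmbeddingModelItemType
      input: [query], // was `input: query` before this commit
      type: 'db'
    });
    const [vector] = vectors; // one embedding per input string
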
View File

@@ -6,14 +6,14 @@ import { addLog } from '../../../common/system/log';
 type GetVectorProps = {
   model: EmbeddingModelItemType;
-  input: string;
+  input: string[];
   type?: `${EmbeddingTypeEnm}`;
   headers?: Record<string, string>;
 };
 // text to vector
 export async function getVectorsByText({ model, input, type, headers }: GetVectorProps) {
-  if (!input) {
+  if (!input || input.length === 0) {
     return Promise.reject({
       code: 500,
       message: 'input is empty'
@@ -31,7 +31,7 @@ export async function getVectorsByText({ model, input, type, headers }: GetVecto
       ...(type === EmbeddingTypeEnm.db && model.dbConfig),
       ...(type === EmbeddingTypeEnm.query && model.queryConfig),
       model: model.model,
-      input: [input]
+      input
     },
     model.requestUrl
       ? {
@@ -55,7 +55,12 @@ export async function getVectorsByText({ model, input, type, headers }: GetVecto
   }
   const [tokens, vectors] = await Promise.all([
-    countPromptTokens(input),
+    (async () => {
+      if (res.usage) return res.usage.total_tokens;
+      const tokens = await Promise.all(input.map((item) => countPromptTokens(item)));
+      return tokens.reduce((sum, item) => sum + item, 0);
+    })(),
     Promise.all(
       res.data
         .map((item) => unityDimensional(item.embedding))

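The new token accounting prefers the usage figure reported by the embedding API and only falls back to local counting. A condensed sketch of the same fallback (assuming countPromptTokens returns a Promise<number>):

    // Prefer provider-reported usage; otherwise count each input and sum.
    const totalTokens = res.usage
      ? res.usage.total_tokens
      : (await Promise.all(input.map((item) => countPromptTokens(item)))).reduce(
          (sum, n) => sum + n,
          0
        );
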
View File

@@ -9,155 +9,170 @@ import { llmCompletionsBodyFormat, formatLLMResponse } from '../utils';
 import { addLog } from '../../../common/system/log';
 import { filterGPTMessageByMaxContext } from '../../chat/utils';
 import json5 from 'json5';
+import type { EmbeddingModelItemType } from '@fastgpt/global/core/ai/model.d';
 /*
-  Query Extension - Semantic Search Enhancement
-  This module can eliminate referential ambiguity and expand queries based on context to improve retrieval.
-  Submodular Optimization Mode: Generate multiple candidate queries, then use submodular algorithm to select the optimal query combination
-  @https://github.com/jina-ai/submodular-optimization/blob/main/submodular_optimization.js
+  Query Extension - Semantic Search Enhancement
+  This module can eliminate referential ambiguity and expand queries based on context to improve retrieval.
+  Submodular Optimization Mode: Generate multiple candidate queries, then use submodular algorithm to select the optimal query combination
 */
+async function queriesFilter({
+  queries,
+  embeddingModelData
+}: {
+  queries: string[];
+  embeddingModelData: EmbeddingModelItemType;
+}): Promise<{
+  tokens: number;
+  queries: string[];
+}> {
+  if (queries.length < 5) {
+    return {
+      queries,
+      tokens: 0
+    };
+  }
-// Priority Queue implementation for submodular optimization
-class PriorityQueue<T> {
-  private heap: Array<{ item: T; priority: number }> = [];
-  enqueue(item: T, priority: number): void {
-    this.heap.push({ item, priority });
-    this.heap.sort((a, b) => b.priority - a.priority);
-  }
-  dequeue(): T | undefined {
-    return this.heap.shift()?.item;
-  }
-  isEmpty(): boolean {
-    return this.heap.length === 0;
-  }
-  size(): number {
-    return this.heap.length;
-  }
-}
+  // Priority Queue implementation for submodular optimization
+  class PriorityQueue<T> {
+    private heap: Array<{ item: T; priority: number }> = [];
+    enqueue(item: T, priority: number): void {
+      this.heap.push({ item, priority });
+      this.heap.sort((a, b) => b.priority - a.priority);
+    }
+    dequeue(): T | undefined {
+      return this.heap.shift()?.item;
+    }
+    isEmpty(): boolean {
+      return this.heap.length === 0;
+    }
+    size(): number {
+      return this.heap.length;
+    }
+  }
-// Calculate cosine similarity
-function cosineSimilarity(a: number[], b: number[]): number {
-  if (a.length !== b.length) {
-    throw new Error('Vectors must have the same length');
-  }
-  let dotProduct = 0;
-  let normA = 0;
-  let normB = 0;
-  for (let i = 0; i < a.length; i++) {
-    dotProduct += a[i] * b[i];
-    normA += a[i] * a[i];
-    normB += b[i] * b[i];
-  }
-  if (normA === 0 || normB === 0) return 0;
-  return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
-}
+  // Calculate cosine similarity
+  function cosineSimilarity(a: number[], b: number[]): number {
+    if (a.length !== b.length) {
+      throw new Error('Vectors must have the same length');
+    }
+    let dotProduct = 0;
+    let normA = 0;
+    let normB = 0;
+    for (let i = 0; i < a.length; i++) {
+      dotProduct += a[i] * b[i];
+      normA += a[i] * a[i];
+      normB += b[i] * b[i];
+    }
+    if (normA === 0 || normB === 0) return 0;
+    return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
+  }
-// Calculate marginal gain
-function computeMarginalGain(
-  candidateEmbedding: number[],
-  selectedEmbeddings: number[][],
-  originalEmbedding: number[],
-  alpha: number = 0.3
-): number {
-  if (selectedEmbeddings.length === 0) {
-    return alpha * cosineSimilarity(originalEmbedding, candidateEmbedding);
-  }
-  let maxSimilarity = 0;
-  for (const selectedEmbedding of selectedEmbeddings) {
-    const similarity = cosineSimilarity(candidateEmbedding, selectedEmbedding);
-    maxSimilarity = Math.max(maxSimilarity, similarity);
-  }
-  const relevance = alpha * cosineSimilarity(originalEmbedding, candidateEmbedding);
-  const diversity = 1 - maxSimilarity;
-  return relevance + diversity;
-}
+  // Calculate marginal gain
+  function computeMarginalGain(
+    candidateEmbedding: number[],
+    selectedEmbeddings: number[][],
+    originalEmbedding: number[],
+    alpha: number = 0.3
+  ): number {
+    if (selectedEmbeddings.length === 0) {
+      return alpha * cosineSimilarity(originalEmbedding, candidateEmbedding);
+    }
+    let maxSimilarity = 0;
+    for (const selectedEmbedding of selectedEmbeddings) {
+      const similarity = cosineSimilarity(candidateEmbedding, selectedEmbedding);
+      maxSimilarity = Math.max(maxSimilarity, similarity);
+    }
+    const relevance = alpha * cosineSimilarity(originalEmbedding, candidateEmbedding);
+    const diversity = 1 - maxSimilarity;
+    return relevance + diversity;
+  }
-// Lazy greedy query selection algorithm
-function lazyGreedyQuerySelection(
-  candidates: string[],
-  embeddings: number[][],
-  originalEmbedding: number[],
-  k: number,
-  alpha: number = 0.3
-): string[] {
-  const n = candidates.length;
-  const selected: string[] = [];
-  const selectedEmbeddings: number[][] = [];
-  // Initialize priority queue
-  const pq = new PriorityQueue<{ index: number; gain: number }>();
-  // Calculate initial marginal gain for all candidates
-  for (let i = 0; i < n; i++) {
-    const gain = computeMarginalGain(embeddings[i], selectedEmbeddings, originalEmbedding, alpha);
-    pq.enqueue({ index: i, gain }, gain);
-  }
-  // Greedy selection
-  for (let iteration = 0; iteration < k; iteration++) {
-    if (pq.isEmpty()) break;
-    let bestCandidate: { index: number; gain: number } | undefined;
-    // Find candidate with maximum marginal gain
-    while (!pq.isEmpty()) {
-      const candidate = pq.dequeue()!;
-      const currentGain = computeMarginalGain(
-        embeddings[candidate.index],
-        selectedEmbeddings,
-        originalEmbedding,
-        alpha
-      );
-      if (currentGain >= candidate.gain) {
-        bestCandidate = { index: candidate.index, gain: currentGain };
-        break;
-      } else {
-        pq.enqueue(candidate, currentGain);
-      }
-    }
-    if (bestCandidate) {
-      selected.push(candidates[bestCandidate.index]);
-      selectedEmbeddings.push(embeddings[bestCandidate.index]);
-    }
-  }
-  return selected;
-}
+  // Lazy greedy query selection algorithm
+  function lazyGreedyQuerySelection(
+    candidates: string[],
+    embeddings: number[][],
+    originalEmbedding: number[],
+    k: number,
+    alpha: number = 0.3
+  ): string[] {
+    const n = candidates.length;
+    const selected: string[] = [];
+    const selectedEmbeddings: number[][] = [];
+    // Initialize priority queue
+    const pq = new PriorityQueue<{ index: number; gain: number }>();
+    // Calculate initial marginal gain for all candidates
+    for (let i = 0; i < n; i++) {
+      const gain = computeMarginalGain(embeddings[i], selectedEmbeddings, originalEmbedding, alpha);
+      pq.enqueue({ index: i, gain }, gain);
+    }
+    // Greedy selection
+    for (let iteration = 0; iteration < k; iteration++) {
+      if (pq.isEmpty()) break;
+      let bestCandidate: { index: number; gain: number } | undefined;
+      // Find candidate with maximum marginal gain
+      while (!pq.isEmpty()) {
+        const candidate = pq.dequeue()!;
+        const currentGain = computeMarginalGain(
+          embeddings[candidate.index],
+          selectedEmbeddings,
+          originalEmbedding,
+          alpha
+        );
+        if (currentGain >= candidate.gain) {
+          bestCandidate = { index: candidate.index, gain: currentGain };
+          break;
+        } else {
+          pq.enqueue(candidate, currentGain);
+        }
+      }
+      if (bestCandidate) {
+        selected.push(candidates[bestCandidate.index]);
+        selectedEmbeddings.push(embeddings[bestCandidate.index]);
+      }
+    }
+    return selected;
+  }
-// Generate embeddings for input texts
-async function generateEmbeddings(texts: string[], model: string): Promise<number[][]> {
-  try {
-    const vectorModel = getEmbeddingModel(model);
-    const embeddings: number[][] = [];
-    for (const text of texts) {
-      // Use vector model's createEmbedding method
-      const embedding = await getVectorsByText({
-        model: vectorModel,
-        input: text,
-        type: 'query'
-      });
-      embeddings.push(embedding.vectors[0]);
-    }
-    return embeddings;
-  } catch (error) {
-    addLog.warn('Failed to generate embeddings', { error, model });
-    throw error;
-  }
-}
+  const { vectors, tokens } = await getVectorsByText({
+    model: embeddingModelData,
+    input: queries,
+    type: 'query'
+  });
+  const originalEmbedding = vectors[0];
+  const candidateEmbeddings = vectors.slice(1);
+  // Select optimal queries using lazy greedy algorithm
+  const selectedQueries = lazyGreedyQuerySelection(
+    queries,
+    candidateEmbeddings,
+    originalEmbedding,
+    Math.min(5, queries.length), // Select top 5 queries or less
+    0.3 // alpha parameter for balancing relevance and diversity
+  );
+  return {
+    queries: selectedQueries,
+    tokens
+  };
+}
 const title = global.feConfigs?.systemTitle || 'FastAI';
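The gain rule above weights relevance to the original query by alpha and rewards distance from everything already selected. A worked instance with invented numbers: at alpha = 0.3, a candidate with cosine similarity 0.9 to the original query and at most 0.4 to the selected set scores 0.3 × 0.9 + (1 − 0.4) = 0.87, while a near-duplicate of a selected query (max similarity 0.95) scores only 0.27 + 0.05 = 0.32 and is passed over.

    // Toy numbers only; mirrors computeMarginalGain above.
    const alpha = 0.3;
    const simToOriginal = 0.9; // cosineSimilarity(originalEmbedding, candidateEmbedding)
    const maxSimToSelected = 0.4; // max cosineSimilarity(candidateEmbedding, selected[i])
    const gain = alpha * simToOriginal + (1 - maxSimToSelected); // 0.27 + 0.6 = 0.87
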
@@ -254,7 +269,7 @@ assistant: Laf 是一个云函数开发平台。
 1. 输出格式为 JSON 数组,数组中每个元素为字符串。无需对输出进行任何解释。
 2. 输出语言与原问题相同。原问题为中文则输出中文;原问题为英文则输出英文。
-3. 确保生成恰好 {{count}} 个检索词。
+3. 确保生成恰好 10 个检索词。
 ## 开始任务
@@ -269,20 +284,22 @@ export const queryExtension = async ({
   chatBg,
   query,
   histories = [],
-  model,
-  generateCount = 10 // 添加生成数量参数,默认为10个
+  llmModel,
+  embeddingModel
 }: {
   chatBg?: string;
   query: string;
   histories: ChatItemType[];
-  model: string;
-  generateCount?: number;
+  llmModel: string;
+  embeddingModel: string;
 }): Promise<{
   rawQuery: string;
   extensionQueries: string[];
-  model: string;
+  llmModel: string;
   inputTokens: number;
   outputTokens: number;
+  embeddingTokens: number;
+  embeddingModel: string;
 }> => {
   const systemFewShot = chatBg
     ? `user: 对话背景。
@@ -290,10 +307,12 @@ assistant: ${chatBg}
 `
     : '';
-  const modelData = getLLMModel(model);
+  const llmModelData = getLLMModel(llmModel);
+  const embeddingModelData = getEmbeddingModel(embeddingModel);
   const filterHistories = await filterGPTMessageByMaxContext({
     messages: chats2GPTMessages({ messages: histories, reserveId: false }),
-    maxContext: modelData.maxContext - 1000
+    maxContext: llmModelData.maxContext - 1000
   });
   const historyFewShot = filterHistories
@@ -317,8 +336,7 @@ assistant: ${chatBg}
       role: 'user',
       content: replaceVariable(defaultPrompt, {
         query: `${query}`,
-        histories: concatFewShot || 'null',
-        count: generateCount.toString()
+        histories: concatFewShot || 'null'
       })
     }
   ] as any;
@@ -327,11 +345,11 @@ assistant: ${chatBg}
     body: llmCompletionsBodyFormat(
       {
         stream: true,
-        model: modelData.model,
+        model: llmModelData.model,
         temperature: 0.1,
         messages
       },
-      modelData
+      llmModelData
     )
   });
   const { text: answer, usage } = await formatLLMResponse(response);
@@ -342,9 +360,11 @@ assistant: ${chatBg}
     return {
       rawQuery: query,
       extensionQueries: [],
-      model,
+      llmModel: llmModelData.model,
       inputTokens: inputTokens,
-      outputTokens: outputTokens
+      outputTokens: outputTokens,
+      embeddingModel: embeddingModelData.model,
+      embeddingTokens: 0
     };
   }
@@ -357,9 +377,11 @@ assistant: ${chatBg}
     return {
      rawQuery: query,
      extensionQueries: [],
-      model,
+      llmModel: llmModelData.model,
      inputTokens: inputTokens,
-      outputTokens: outputTokens
+      outputTokens: outputTokens,
+      embeddingModel: embeddingModelData.model,
+      embeddingTokens: 0
    };
  }
@@ -376,44 +398,41 @@ assistant: ${chatBg}
       return {
         rawQuery: query,
         extensionQueries: [],
-        model,
+        llmModel: llmModelData.model,
         inputTokens,
-        outputTokens
+        outputTokens,
+        embeddingModel: embeddingModelData.model,
+        embeddingTokens: 0
       };
     }
-    // Generate embeddings for original query and candidate queries
-    const allQueries = [query, ...queries];
-    const embeddings = await generateEmbeddings(allQueries, model);
-    const originalEmbedding = embeddings[0];
-    const candidateEmbeddings = embeddings.slice(1);
-    // Select optimal queries using lazy greedy algorithm
-    const selectedQueries = lazyGreedyQuerySelection(
-      queries,
-      candidateEmbeddings,
-      originalEmbedding,
-      Math.min(5, queries.length), // Select top 5 queries or less
-      0.3 // alpha parameter for balancing relevance and diversity
-    );
+    // Filter query
+    const { queries: filteredQueries, tokens: embeddingTokens } = await queriesFilter({
+      queries: [query, ...queries].filter(Boolean),
+      embeddingModelData
+    });
+    console.log(filteredQueries, 111);
     return {
       rawQuery: query,
-      extensionQueries: selectedQueries,
-      model,
+      extensionQueries: filteredQueries,
+      llmModel: llmModelData.model,
       inputTokens,
-      outputTokens
+      outputTokens,
+      embeddingModel: embeddingModelData.model,
+      embeddingTokens
     };
   } catch (error) {
-    addLog.warn('Query extension failed', {
-      error,
-      answer
-    });
+    addLog.error(`Query extension failed, answer: ${answer}`, error);
     return {
       rawQuery: query,
       extensionQueries: [],
-      model,
+      llmModel: llmModelData.model,
       inputTokens,
-      outputTokens
+      outputTokens,
+      embeddingModel: embeddingModelData.model,
+      embeddingTokens: 0
     };
   }
 };

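Taken together, a hypothetical call against the new queryExtension signature (the model names are placeholders, not values from this commit):

    const result = await queryExtension({
      query: 'How do I deploy Laf?',
      histories: [],
      llmModel: 'gpt-4o-mini', // generates the candidate queries
      embeddingModel: 'text-embedding-v2' // used by queriesFilter for selection
    });
    // result.extensionQueries holds the filtered queries; the result now also
    // reports embeddingModel and embeddingTokens alongside llmModel,
    // inputTokens and outputTokens.
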
View File

@@ -447,7 +447,7 @@ export async function searchDatasetData(
 }) => {
   const { vectors, tokens } = await getVectorsByText({
     model: getEmbeddingModel(model),
-    input: query,
+    input: [query],
     type: 'query'
   });
@@ -885,6 +885,7 @@ export const defaultSearchDatasetData = async ({
   const { concatQueries, extensionQueries, rewriteQuery, aiExtensionResult } =
     await datasetSearchQueryExtension({
       query,
+      embeddingModel: props.model,
       extensionModel,
      extensionBg: datasetSearchExtensionBg,
      histories
@@ -898,9 +899,10 @@
   return {
     ...result,
+    embeddingTokens: result.embeddingTokens + (aiExtensionResult?.embeddingTokens || 0),
     queryExtensionResult: aiExtensionResult
       ? {
-          model: aiExtensionResult.model,
+          model: aiExtensionResult.llmModel,
           inputTokens: aiExtensionResult.inputTokens,
           outputTokens: aiExtensionResult.outputTokens,
           query: extensionQueries.join('\n')

View File

@@ -6,11 +6,13 @@ import { chatValue2RuntimePrompt } from '@fastgpt/global/core/chat/adapt';
 export const datasetSearchQueryExtension = async ({
   query,
+  embeddingModel,
   extensionModel,
   extensionBg = '',
   histories = []
 }: {
   query: string;
+  embeddingModel: string;
   extensionModel?: LLMModelItemType;
   extensionBg?: string;
   histories?: ChatItemType[];
@@ -67,7 +69,8 @@ Human: ${query}
     chatBg: extensionBg,
     query,
     histories,
-    model: extensionModel.model
+    llmModel: extensionModel.model,
+    embeddingModel
   });
   if (result.extensionQueries?.length === 0) return;
   return result;

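Call sites now thread the embedding model through as well. A sketch of the updated call, matching the search diff above (props.model is the dataset's vector model):

    const extensionResult = await datasetSearchQueryExtension({
      query,
      embeddingModel: props.model,
      extensionModel,
      extensionBg: datasetSearchExtensionBg,
      histories
    });
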
View File

@@ -102,7 +102,7 @@ const testEmbeddingModel = async (
   headers: Record<string, string>
 ) => {
   return getVectorsByText({
-    input: 'Hi',
+    input: ['Hi'],
     model,
     headers
   });

View File

@@ -35,7 +35,7 @@ async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
   await checkTeamAIPoints(teamId);
   const { tokens, vectors } = await getVectorsByText({
-    input: query,
+    input: [query],
     model: getEmbeddingModel(model),
     type
   });