@@ -9,155 +9,170 @@ import { llmCompletionsBodyFormat, formatLLMResponse } from '../utils';
import { addLog } from '../../../common/system/log';
import { filterGPTMessageByMaxContext } from '../../chat/utils';
import json5 from 'json5';
import type { EmbeddingModelItemType } from '@fastgpt/global/core/ai/model.d';
/*
  Query Extension - Semantic Search Enhancement
  This module eliminates referential ambiguity and expands queries based on context to improve retrieval.
  Submodular Optimization Mode: generate multiple candidate queries, then use a submodular algorithm to select the optimal query combination.
  @https://github.com/jina-ai/submodular-optimization/blob/main/submodular_optimization.js
*/
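/*
  Flow overview:
    1. An LLM generates candidate extension queries from the chat context.
    2. The original query and all candidates are embedded in one batch.
    3. A lazy greedy submodular pass keeps a relevant yet diverse subset.
*/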
async function queriesFilter({
  queries,
  embeddingModelData
}: {
  queries: string[];
  embeddingModelData: EmbeddingModelItemType;
}): Promise<{
  tokens: number;
  queries: string[];
}> {
  // Too few queries to be worth filtering: return them unchanged
  if (queries.length < 5) {
    return {
      queries,
      tokens: 0
    };
  }
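  /*
    Selection objective (a sketch of the intent, not stated in the source):
    given the original query q and the already-selected set S, each step adds
    the candidate c maximizing
      gain(c) = alpha * sim(q, c) + (1 - max_{s in S} sim(s, c))
    i.e. stay relevant to q while penalizing near-duplicates of what has
    already been selected. Because this gain can only shrink as S grows,
    the lazy (CELF-style) greedy evaluation below is valid.
  */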
  // Priority Queue implementation for submodular optimization.
  // Sorting on every enqueue is O(n log n), which is acceptable for the
  // small candidate sets handled here.
  class PriorityQueue<T> {
    private heap: Array<{ item: T; priority: number }> = [];

    enqueue(item: T, priority: number): void {
      this.heap.push({ item, priority });
      this.heap.sort((a, b) => b.priority - a.priority);
    }

    dequeue(): T | undefined {
      return this.heap.shift()?.item;
    }

    isEmpty(): boolean {
      return this.heap.length === 0;
    }

    size(): number {
      return this.heap.length;
    }
  }

  // Calculate cosine similarity
  function cosineSimilarity(a: number[], b: number[]): number {
    if (a.length !== b.length) {
      throw new Error('Vectors must have the same length');
    }
    let dotProduct = 0;
    let normA = 0;
    let normB = 0;
    for (let i = 0; i < a.length; i++) {
      dotProduct += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    if (normA === 0 || normB === 0) return 0;
    return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
  }

  // Calculate marginal gain: relevance to the original query plus a diversity
  // bonus that shrinks as the candidate resembles already-selected queries
  function computeMarginalGain(
    candidateEmbedding: number[],
    selectedEmbeddings: number[][],
    originalEmbedding: number[],
    alpha: number = 0.3
  ): number {
    if (selectedEmbeddings.length === 0) {
      // Nothing selected yet: max similarity to the selected set is 0, so the
      // diversity term contributes a constant 1. Keeping it makes the initial
      // gains valid upper bounds for the lazy re-evaluation below.
      return alpha * cosineSimilarity(originalEmbedding, candidateEmbedding) + 1;
    }
    let maxSimilarity = 0;
    for (const selectedEmbedding of selectedEmbeddings) {
      const similarity = cosineSimilarity(candidateEmbedding, selectedEmbedding);
      maxSimilarity = Math.max(maxSimilarity, similarity);
    }
    const relevance = alpha * cosineSimilarity(originalEmbedding, candidateEmbedding);
    const diversity = 1 - maxSimilarity;
    return relevance + diversity;
  }

  // Lazy greedy query selection algorithm (CELF-style: stale gains are only
  // re-evaluated when a candidate reaches the top of the queue)
  function lazyGreedyQuerySelection(
    candidates: string[],
    embeddings: number[][],
    originalEmbedding: number[],
    k: number,
    alpha: number = 0.3
  ): string[] {
    const n = candidates.length;
    const selected: string[] = [];
    const selectedEmbeddings: number[][] = [];

    // Initialize priority queue
    const pq = new PriorityQueue<{ index: number; gain: number }>();

    // Calculate initial marginal gain for all candidates
    for (let i = 0; i < n; i++) {
      const gain = computeMarginalGain(embeddings[i], selectedEmbeddings, originalEmbedding, alpha);
      pq.enqueue({ index: i, gain }, gain);
    }

    // Greedy selection
    for (let iteration = 0; iteration < k; iteration++) {
      if (pq.isEmpty()) break;

      let bestCandidate: { index: number; gain: number } | undefined;

      // Find the candidate with maximum marginal gain
      while (!pq.isEmpty()) {
        const candidate = pq.dequeue()!;
        const currentGain = computeMarginalGain(
          embeddings[candidate.index],
          selectedEmbeddings,
          originalEmbedding,
          alpha
        );
        // Stale priorities are upper bounds on true gains, so if this
        // candidate's refreshed gain still meets its stale value it must
        // beat every other candidate left in the queue
        if (currentGain >= candidate.gain) {
          bestCandidate = { index: candidate.index, gain: currentGain };
          break;
        } else {
          pq.enqueue(candidate, currentGain);
        }
      }

      if (bestCandidate) {
        selected.push(candidates[bestCandidate.index]);
        selectedEmbeddings.push(embeddings[bestCandidate.index]);
      }
    }

    return selected;
  }
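  /*
    Worked toy example (illustrative numbers, not from the source):
    alpha = 0.3, q = [1, 0], candidates c1 = [1, 0], c2 = [0.99, 0.14], c3 = [0, 1].
      Round 1: initial gains 1.30, ~1.297, 1.0 -> pick c1.
      Round 2: c2 re-evaluates to ~0.297 + (1 - 0.99) ~= 0.31 and is re-queued;
               c3 re-evaluates to 0 + (1 - 0) = 1.0 and is accepted,
    so the near-duplicate c2 loses to the diverse c3.
  */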
  // Embed the original query (index 0) and every candidate in one batch call
  const { vectors, tokens } = await getVectorsByText({
    model: embeddingModelData,
    input: queries,
    type: 'query'
  });

  const originalEmbedding = vectors[0];
  const candidateEmbeddings = vectors.slice(1);

  // Select optimal queries using the lazy greedy algorithm.
  // Note: candidates must align one-to-one with candidateEmbeddings, so the
  // original query (index 0) is excluded from the candidate list.
  const selectedQueries = lazyGreedyQuerySelection(
    queries.slice(1),
    candidateEmbeddings,
    originalEmbedding,
    Math.min(5, queries.length), // Select top 5 queries or fewer
    0.3 // alpha parameter for balancing relevance and diversity
  );

  return {
    queries: selectedQueries,
    tokens
  };
}
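/*
  Usage sketch (hypothetical variable names; mirrors the call site further
  below): the original query is expected as element 0, followed by the
  LLM-generated candidates.
    const { queries: kept, tokens } = await queriesFilter({
      queries: [rawQuery, ...candidates].filter(Boolean),
      embeddingModelData: getEmbeddingModel(embeddingModel)
    });
  `kept` holds the surviving queries (at most 5 when filtering runs).
*/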
const title = global.feConfigs?.systemTitle || 'FastAI';
@@ -254,7 +269,7 @@ assistant: Laf 是一个云函数开发平台。
1. 输出格式为 JSON 数组,数组中每个元素为字符串。无需对输出进行任何解释。
2. 输出语言与原问题相同。原问题为中文则输出中文;原问题为英文则输出英文。
3. 确保生成恰好 10 个检索词。
## 开始任务
@@ -269,20 +284,22 @@ export const queryExtension = async ({
  chatBg,
  query,
  histories = [],
  llmModel,
  embeddingModel
}: {
  chatBg?: string;
  query: string;
  histories: ChatItemType[];
  llmModel: string;
  embeddingModel: string;
}): Promise<{
  rawQuery: string;
  extensionQueries: string[];
  llmModel: string;
  inputTokens: number;
  outputTokens: number;
  embeddingTokens: number;
  embeddingModel: string;
}> => {
  const systemFewShot = chatBg
    ? `user: 对话背景。
@@ -290,10 +307,12 @@ assistant: ${chatBg}
`
    : '';

  const llmModelData = getLLMModel(llmModel);
  const embeddingModelData = getEmbeddingModel(embeddingModel);
  const filterHistories = await filterGPTMessageByMaxContext({
    messages: chats2GPTMessages({ messages: histories, reserveId: false }),
    maxContext: llmModelData.maxContext - 1000
  });
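  // The 1000-token margin presumably reserves room for the prompt template
  // and the model's completion (an assumption; the source does not say).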
const historyFewShot = filterHistories
@@ -317,8 +336,7 @@ assistant: ${chatBg}
      role: 'user',
      content: replaceVariable(defaultPrompt, {
        query: `${query}`,
        histories: concatFewShot || 'null'
      })
    }
  ] as any;
@@ -327,11 +345,11 @@ assistant: ${chatBg}
    body: llmCompletionsBodyFormat(
      {
        stream: true,
        model: llmModelData.model,
        temperature: 0.1,
        messages
      },
      llmModelData
    )
  });
  const { text: answer, usage } = await formatLLMResponse(response);
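  // The prompt requests a JSON array of strings; the answer is presumably
  // parsed with json5 (imported above) so minor formatting drift is tolerated.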
@@ -342,9 +360,11 @@ assistant: ${chatBg}
    return {
      rawQuery: query,
      extensionQueries: [],
      llmModel: llmModelData.model,
      inputTokens: inputTokens,
      outputTokens: outputTokens,
      embeddingModel: embeddingModelData.model,
      embeddingTokens: 0
    };
  }
@@ -357,9 +377,11 @@ assistant: ${chatBg}
    return {
      rawQuery: query,
      extensionQueries: [],
      llmModel: llmModelData.model,
      inputTokens: inputTokens,
      outputTokens: outputTokens,
      embeddingModel: embeddingModelData.model,
      embeddingTokens: 0
    };
  }
@@ -376,44 +398,41 @@ assistant: ${chatBg}
    return {
      rawQuery: query,
      extensionQueries: [],
      llmModel: llmModelData.model,
      inputTokens,
      outputTokens,
      embeddingModel: embeddingModelData.model,
      embeddingTokens: 0
    };
  }
  // Filter query: embed [query, ...queries] and keep a relevant, diverse subset
  const { queries: filteredQueries, tokens: embeddingTokens } = await queriesFilter({
    queries: [query, ...queries].filter(Boolean),
    embeddingModelData
  });
  return {
    rawQuery: query,
    extensionQueries: filteredQueries,
    llmModel: llmModelData.model,
    inputTokens,
    outputTokens,
    embeddingModel: embeddingModelData.model,
    embeddingTokens
  };
  } catch (error) {
    addLog.error(`Query extension failed, answer: ${answer}`, error);
    return {
      rawQuery: query,
      extensionQueries: [],
      llmModel: llmModelData.model,
      inputTokens,
      outputTokens,
      embeddingModel: embeddingModelData.model,
      embeddingTokens: 0
    };
}
};