From e4756c76dded3fc810a626063d520aacc44d098c Mon Sep 17 00:00:00 2001 From: YeYuheng <57035043+YYH211@users.noreply.github.com> Date: Fri, 29 Aug 2025 00:54:29 +0800 Subject: [PATCH] rrf_weight (#5551) Co-authored-by: xxYyh --- .gitignore | 4 +++- packages/global/core/dataset/search/utils.ts | 7 +++---- .../service/core/dataset/search/controller.ts | 20 +++++++++---------- .../core/workflow/dispatch/dataset/concat.ts | 2 +- 4 files changed, 16 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 0275e3f35..844017d54 100644 --- a/.gitignore +++ b/.gitignore @@ -37,4 +37,6 @@ files/helm/fastgpt/charts/*.tgz tmp/ coverage -document/.source \ No newline at end of file +document/.source + +bun.lock \ No newline at end of file diff --git a/packages/global/core/dataset/search/utils.ts b/packages/global/core/dataset/search/utils.ts index 27b7aae61..52318c837 100644 --- a/packages/global/core/dataset/search/utils.ts +++ b/packages/global/core/dataset/search/utils.ts @@ -3,7 +3,7 @@ import { type SearchDataResponseItemType } from '../type'; /* dataset search result concat */ export const datasetSearchResultConcat = ( - arr: { k: number; list: SearchDataResponseItemType[] }[] + arr: { weight: number; list: SearchDataResponseItemType[] }[] ): SearchDataResponseItemType[] => { arr = arr.filter((item) => item.list.length > 0); @@ -14,12 +14,11 @@ export const datasetSearchResultConcat = ( // rrf arr.forEach((item) => { - const k = item.k; + const weight = item.weight; item.list.forEach((data, index) => { const rank = index + 1; - const score = 1 / (k + rank); - + const score = (weight * 1) / (60 + rank); const record = map.get(data.id); if (record) { // 合并两个score,有相同type的score,取最大值 diff --git a/packages/service/core/dataset/search/controller.ts b/packages/service/core/dataset/search/controller.ts index 080e6fe2e..8273bafb2 100644 --- a/packages/service/core/dataset/search/controller.ts +++ b/packages/service/core/dataset/search/controller.ts @@ -784,10 +784,10 @@ export async function searchDatasetData( // rrf concat const rrfEmbRecall = datasetSearchResultConcat( - embeddingRecallResults.map((list) => ({ k: 60, list })) + embeddingRecallResults.map((list) => ({ weight: 1, list })) ).slice(0, embeddingLimit); const rrfFTRecall = datasetSearchResultConcat( - fullTextRecallResults.map((list) => ({ k: 60, list })) + fullTextRecallResults.map((list) => ({ weight: 1, list })) ).slice(0, fullTextLimit); return { @@ -850,24 +850,22 @@ export async function searchDatasetData( })(); // embedding recall and fullText recall rrf concat - const baseK = 120; - const embK = Math.round(baseK * (1 - embeddingWeight)); // 搜索结果的 k 值 - const fullTextK = Math.round(baseK * embeddingWeight); // rerank 结果的 k 值 + const embWeight = embeddingWeight; // 向量索引的 weight 大小 + const fullTextWeight = 1 - embeddingWeight; // 全文索引的 weight 大小 const rrfSearchResult = datasetSearchResultConcat([ - { k: embK, list: embeddingRecallResults }, - { k: fullTextK, list: fullTextRecallResults } + { weight: embWeight, list: embeddingRecallResults }, + { weight: fullTextWeight, list: fullTextRecallResults } ]); const rrfConcatResults = (() => { if (reRankResults.length === 0) return rrfSearchResult; if (rerankWeight === 1) return reRankResults; - const searchK = Math.round(baseK * rerankWeight); // 搜索结果的 k 值 - const rerankK = Math.round(baseK * (1 - rerankWeight)); // rerank 结果的 k 值 + const searchWeight = 1 - rerankWeight; // 搜索结果的 weight 大小 return datasetSearchResultConcat([ - { k: searchK, list: rrfSearchResult }, - { k: rerankK, list: reRankResults } + { weight: searchWeight, list: rrfSearchResult }, + { weight: rerankWeight, list: reRankResults } ]); })(); diff --git a/packages/service/core/workflow/dispatch/dataset/concat.ts b/packages/service/core/workflow/dispatch/dataset/concat.ts index cbbe4fa6c..796897f57 100644 --- a/packages/service/core/workflow/dispatch/dataset/concat.ts +++ b/packages/service/core/workflow/dispatch/dataset/concat.ts @@ -29,7 +29,7 @@ export async function dispatchDatasetConcat( const rrfConcatResults = datasetSearchResultConcat( quoteList.map((list) => ({ - k: 60, + weight: 1, list })) );