import {
  DatasetSearchModeEnum,
  DatasetSearchModeMap,
  SearchScoreTypeEnum
} from '@fastgpt/global/core/dataset/constants';
import { recallFromVectorStore } from '../../../common/vectorStore/controller';
import { getVectorsByText } from '../../ai/embedding';
import { getVectorModel } from '../../ai/model';
import { MongoDatasetData } from '../data/schema';
import {
  DatasetDataSchemaType,
  DatasetDataWithCollectionType,
  SearchDataResponseItemType
} from '@fastgpt/global/core/dataset/type';
import { DatasetColCollectionName, MongoDatasetCollection } from '../collection/schema';
import { reRankRecall } from '../../../core/ai/rerank';
import { countPromptTokens } from '../../../common/string/tiktoken/index';
import { datasetSearchResultConcat } from '@fastgpt/global/core/dataset/search/utils';
import { hashStr } from '@fastgpt/global/common/string/tools';
import { jiebaSplit } from '../../../common/string/jieba';
import { getCollectionSourceData } from '@fastgpt/global/core/dataset/collection/utils';
import { Types } from '../../../common/mongo';

type SearchDatasetDataProps = {
  teamId: string;
  model: string;
  similarity?: number; // min distance
  limit: number; // max Token limit
  datasetIds: string[];
  searchMode?: `${DatasetSearchModeEnum}`;
  usingReRank?: boolean;
  reRankQuery: string;
  queries: string[];
};

/**
 * Search dataset data with embedding recall, full-text recall, or both, then
 * optionally rerank, merge the ranked lists, de-duplicate, filter by score and
 * truncate to a token budget.
 *
 * @param props.teamId - Team that owns the datasets; used in every query filter.
 * @param props.model - Vector model key passed to `getVectorModel`.
 * @param props.similarity - Minimum score for the similarity filter (default 0).
 * @param props.limit - Token budget for the returned items.
 * @param props.datasetIds - Datasets to search (default `[]`).
 * @param props.searchMode - Falls back to `embedding` when the value is not a
 *   key of `DatasetSearchModeMap`.
 * @param props.usingReRank - Request reranking; only honoured when
 *   `global.reRankModels` is non-empty.
 * @param props.reRankQuery - Query text sent to the reranker.
 * @param props.queries - Query texts; each one is recalled independently.
 * @returns The filtered results plus the effective settings (`searchMode`,
 *   `usingReRank`, `usingSimilarityFilter`), the token budget, the similarity
 *   threshold and the embedding tokens consumed.
 */
export async function searchDatasetData(props: SearchDatasetDataProps) {
  const {
    teamId,
    reRankQuery,
    queries,
    model,
    similarity = 0,
    limit: maxTokens,
    datasetIds = []
  } = props;
  let { searchMode = DatasetSearchModeEnum.embedding, usingReRank = false } = props;

  /* init params */
  searchMode = DatasetSearchModeMap[searchMode] ? searchMode : DatasetSearchModeEnum.embedding;
  usingReRank = usingReRank && global.reRankModels.length > 0;

  let usingSimilarityFilter = false;

  /* function */

  // Per-query recall size of each recall channel for the active search mode.
  const countRecallLimit = () => {
    if (searchMode === DatasetSearchModeEnum.embedding) {
      return { embeddingLimit: 100, fullTextLimit: 0 };
    }
    if (searchMode === DatasetSearchModeEnum.fullTextRecall) {
      return { embeddingLimit: 0, fullTextLimit: 100 };
    }
    return { embeddingLimit: 80, fullTextLimit: 60 };
  };

  // Ids of the collections flagged `forbid: true` inside the searched datasets.
  const getForbidData = async () => {
    const collections = await MongoDatasetCollection.find(
      {
        teamId,
        datasetId: { $in: datasetIds },
        forbid: true
      },
      '_id'
    );

    return {
      forbidCollectionIdList: collections.map((item) => String(item._id))
    };
  };

  // Drop items whose q+a text is identical once everything except letters and
  // digits has been stripped. The first occurrence wins.
  const removeSameQA = (list: SearchDataResponseItemType[]) => {
    const seen = new Set<string>();
    return list.filter((item) => {
      // Strip all punctuation, whitespace, etc. and compare the text only
      const key = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
      if (seen.has(key)) return false;
      seen.add(key);
      return true;
    });
  };

  // Vector recall for a single query.
  const embeddingRecall = async ({
    query,
    limit,
    forbidCollectionIdList
  }: {
    query: string;
    limit: number;
    forbidCollectionIdList: string[];
  }) => {
    // Mirror fullTextRecall: when this channel is disabled, skip the embedding
    // request and the vector-store query entirely.
    if (limit === 0) {
      const emptyResults: SearchDataResponseItemType[] = [];
      return { embeddingRecallResults: emptyResults, tokens: 0 };
    }

    const { vectors, tokens } = await getVectorsByText({
      model: getVectorModel(model),
      input: query,
      type: 'query'
    });

    const { results } = await recallFromVectorStore({
      teamId,
      datasetIds,
      vector: vectors[0],
      limit,
      forbidCollectionIdList
    });

    // get q and a
    const dataList = (await MongoDatasetData.find(
      {
        teamId,
        datasetId: { $in: datasetIds },
        collectionId: { $in: Array.from(new Set(results.map((item) => item.collectionId))) },
        'indexes.dataId': { $in: results.map((item) => item.id?.trim()) }
      },
      'datasetId collectionId q a chunkIndex indexes'
    )
      .populate('collectionId', 'name fileId rawLink externalFileId externalFileUrl')
      .lean()) as DatasetDataWithCollectionType[];

    // Add score to data. `results` is already sorted, so the first hit that
    // matches one of the row's indexes is its best score.
    const concatResults = dataList.map((data) => {
      const dataIdList = data.indexes.map((item) => item.dataId);

      const maxScoreResult = results.find((item) => {
        // Compare the trimmed id, exactly as the Mongo query above does;
        // otherwise an id with stray whitespace matches the query but not the score.
        const vectorId = item.id?.trim();
        return vectorId !== undefined && dataIdList.includes(vectorId);
      });

      return {
        ...data,
        score: maxScoreResult?.score || 0
      };
    });

    concatResults.sort((a, b) => b.score - a.score);

    const formatResult = concatResults.map((data, index) => {
      if (!data.collectionId) {
        console.log('Collection is not found', data);
      }

      const result: SearchDataResponseItemType = {
        id: String(data._id),
        q: data.q,
        a: data.a,
        chunkIndex: data.chunkIndex,
        datasetId: String(data.datasetId),
        collectionId: String(data.collectionId?._id),
        ...getCollectionSourceData(data.collectionId),
        score: [{ type: SearchScoreTypeEnum.embedding, value: data.score, index }]
      };

      return result;
    });

    return {
      embeddingRecallResults: formatResult,
      tokens
    };
  };

  // Mongo `$text` recall for a single query across every dataset.
  const fullTextRecall = async ({
    query,
    limit
  }: {
    query: string;
    limit: number;
  }): Promise<{
    fullTextRecallResults: SearchDataResponseItemType[];
    tokenLen: number;
  }> => {
    if (limit === 0) {
      return {
        fullTextRecallResults: [],
        tokenLen: 0
      };
    }

    const perDatasetResults = (
      await Promise.all(
        datasetIds.map(async (id) => {
          return MongoDatasetData.aggregate([
            {
              $match: {
                teamId: new Types.ObjectId(teamId),
                datasetId: new Types.ObjectId(id),
                $text: { $search: jiebaSplit({ text: query }) }
              }
            },
            {
              $addFields: {
                score: { $meta: 'textScore' }
              }
            },
            {
              $sort: {
                score: { $meta: 'textScore' }
              }
            },
            {
              $limit: limit
            },
            {
              $lookup: {
                from: DatasetColCollectionName,
                let: { collectionId: '$collectionId' },
                pipeline: [
                  {
                    $match: {
                      $expr: { $eq: ['$_id', '$$collectionId'] },
                      forbid: { $eq: true } // match forbidden collections
                    }
                  },
                  {
                    $project: {
                      _id: 1 // only _id is needed to confirm the match
                    }
                  }
                ],
                as: 'collection'
              }
            },
            {
              $match: {
                collection: { $eq: [] } // keep rows without a forbid=true collection
              }
            },
            {
              $project: {
                _id: 1,
                datasetId: 1,
                collectionId: 1,
                q: 1,
                a: 1,
                chunkIndex: 1,
                score: 1
              }
            }
          ]);
        })
      )
    ).flat() as (DatasetDataSchemaType & { score: number })[];

    // Re-sort the merged per-dataset lists and keep the global top `limit`.
    // `slice` returns a new array, so its result has to be captured.
    const searchResults = perDatasetResults.sort((a, b) => b.score - a.score).slice(0, limit);

    // Same projection as the embedding path so getCollectionSourceData sees the
    // same source fields for both recall channels.
    const collections = await MongoDatasetCollection.find(
      {
        _id: { $in: searchResults.map((item) => item.collectionId) }
      },
      '_id name fileId rawLink externalFileId externalFileUrl'
    );

    return {
      fullTextRecallResults: searchResults.map((item, index) => {
        const collection = collections.find((col) => String(col._id) === String(item.collectionId));
        return {
          id: String(item._id),
          datasetId: String(item.datasetId),
          collectionId: String(item.collectionId),
          ...getCollectionSourceData(collection),
          q: item.q,
          a: item.a,
          chunkIndex: item.chunkIndex,
          indexes: item.indexes,
          score: [{ type: SearchScoreTypeEnum.fullText, value: item.score, index }]
        };
      }),
      tokenLen: 0
    };
  };

  // Rerank the candidates. On an empty answer or an error, reranking is
  // switched off for the remainder of this search and `[]` is returned.
  const reRankSearchResult = async ({
    data,
    query
  }: {
    data: SearchDataResponseItemType[];
    query: string;
  }): Promise<SearchDataResponseItemType[]> => {
    try {
      const results = await reRankRecall({
        query,
        documents: data.map((item) => ({
          id: item.id,
          text: `${item.q}\n${item.a}`
        }))
      });

      if (results.length === 0) {
        usingReRank = false;
        return [];
      }

      // add new score to data
      const mergeResult = results
        .map((item, index) => {
          const target = data.find((dataItem) => dataItem.id === item.id);
          if (!target) return null;
          const score = item.score || 0;

          return {
            ...target,
            score: [{ type: SearchScoreTypeEnum.reRank, value: score, index }]
          };
        })
        .filter(Boolean) as SearchDataResponseItemType[];

      return mergeResult;
    } catch (error) {
      // Best effort: fall back to the un-reranked lists, but leave a trace.
      console.log('ReRank failed', error);
      usingReRank = false;
      return [];
    }
  };

  // Run both recall channels for every query and merge each channel with RRF.
  const multiQueryRecall = async ({
    embeddingLimit,
    fullTextLimit
  }: {
    embeddingLimit: number;
    fullTextLimit: number;
  }) => {
    const { forbidCollectionIdList } = await getForbidData();

    // Promise.all keeps the order of `queries`, so the lists handed to the RRF
    // merge do not depend on which request happens to finish first.
    const perQueryResults = await Promise.all(
      queries.map(async (query) => {
        const [{ tokens, embeddingRecallResults }, { fullTextRecallResults }] = await Promise.all([
          embeddingRecall({
            query,
            limit: embeddingLimit,
            forbidCollectionIdList
          }),
          fullTextRecall({
            query,
            limit: fullTextLimit
          })
        ]);

        return { tokens, embeddingRecallResults, fullTextRecallResults };
      })
    );

    const totalTokens = perQueryResults.reduce((sum, item) => sum + item.tokens, 0);

    // rrf concat
    const rrfEmbRecall = datasetSearchResultConcat(
      perQueryResults.map((item) => ({ k: 60, list: item.embeddingRecallResults }))
    ).slice(0, embeddingLimit);
    const rrfFTRecall = datasetSearchResultConcat(
      perQueryResults.map((item) => ({ k: 60, list: item.fullTextRecallResults }))
    ).slice(0, fullTextLimit);

    return {
      tokens: totalTokens,
      embeddingRecallResults: rrfEmbRecall,
      fullTextRecallResults: rrfFTRecall
    };
  };

  /* main step */
  // count limit
  const { embeddingLimit, fullTextLimit } = countRecallLimit();

  // recall
  const { embeddingRecallResults, fullTextRecallResults, tokens } = await multiQueryRecall({
    embeddingLimit,
    fullTextLimit
  });

  // ReRank results
  const reRankResults = await (async () => {
    if (!usingReRank) return [];

    // Union of both channels; full-text items already recalled by embedding are skipped.
    const embeddingIds = new Set(embeddingRecallResults.map((item) => item.id));
    const concatRecallResults = embeddingRecallResults.concat(
      fullTextRecallResults.filter((item) => !embeddingIds.has(item.id))
    );

    // remove same q and a data
    return reRankSearchResult({
      query: reRankQuery,
      data: removeSameQA(concatRecallResults)
    });
  })();

  // embedding recall and fullText recall rrf concat
  const rrfConcatResults = datasetSearchResultConcat([
    { k: 60, list: embeddingRecallResults },
    { k: 60, list: fullTextRecallResults },
    { k: 58, list: reRankResults }
  ]);

  // remove same q and a data
  const filterSameDataResults = removeSameQA(rrfConcatResults);

  // Score filter. Read `usingReRank` only here: reRankSearchResult may have
  // switched it off after a failed or empty rerank.
  const scoreFilter = (() => {
    if (usingReRank) {
      usingSimilarityFilter = true;

      return filterSameDataResults.filter((item) => {
        const reRankScore = item.score.find((score) => score.type === SearchScoreTypeEnum.reRank);
        if (reRankScore && reRankScore.value < similarity) return false;
        return true;
      });
    }
    if (searchMode === DatasetSearchModeEnum.embedding) {
      usingSimilarityFilter = true;

      return filterSameDataResults.filter((item) => {
        const embeddingScore = item.score.find(
          (score) => score.type === SearchScoreTypeEnum.embedding
        );
        if (embeddingScore && embeddingScore.value < similarity) return false;
        return true;
      });
    }
    return filterSameDataResults;
  })();

  // Token filter: accumulate items until the budget is reached. The item that
  // crosses `maxTokens` is still kept unless it overshoots by more than 500.
  const filterMaxTokensResult = await (async () => {
    const results: SearchDataResponseItemType[] = [];
    let totalTokens = 0;

    for (const item of scoreFilter) {
      totalTokens += await countPromptTokens(item.q + item.a);

      if (totalTokens > maxTokens + 500) {
        break;
      }
      results.push(item);
      if (totalTokens > maxTokens) {
        break;
      }
    }

    // Return the best item, if there is one, rather than nothing when even the
    // first one blows the budget.
    return results.length === 0 ? scoreFilter.slice(0, 1) : results;
  })();

  return {
    searchRes: filterMaxTokensResult,
    tokens,
    searchMode,
    limit: maxTokens,
    similarity,
    usingReRank,
    usingSimilarityFilter
  };
}