mirror of https://github.com/labring/FastGPT.git

V4.6.6-2 (#673)
@@ -24,13 +24,24 @@ export function getAudioSpeechModel(model?: string) {
   );
 }
 
+export function getWhisperModel(model?: string) {
+  return global.whisperModel;
+}
+
+export function getReRankModel(model?: string) {
+  return global.reRankModels.find((item) => item.model === model);
+}
+
 export enum ModelTypeEnum {
   chat = 'chat',
   qa = 'qa',
   cq = 'cq',
   extract = 'extract',
   qg = 'qg',
-  vector = 'vector'
+  vector = 'vector',
+  audioSpeech = 'audioSpeech',
+  whisper = 'whisper',
+  rerank = 'rerank'
 }
 export const getModelMap = {
   [ModelTypeEnum.chat]: getChatModel,
@@ -38,5 +49,8 @@ export const getModelMap = {
   [ModelTypeEnum.cq]: getCQModel,
   [ModelTypeEnum.extract]: getExtractModel,
   [ModelTypeEnum.qg]: getQGModel,
-  [ModelTypeEnum.vector]: getVectorModel
+  [ModelTypeEnum.vector]: getVectorModel,
+  [ModelTypeEnum.audioSpeech]: getAudioSpeechModel,
+  [ModelTypeEnum.whisper]: getWhisperModel,
+  [ModelTypeEnum.rerank]: getReRankModel
 };
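
With these additions, getModelMap is a complete dispatch table from ModelTypeEnum to the matching lookup function. A minimal caller-side sketch (the getModelByType helper and the model name are hypothetical, not part of this commit):

    // Resolve any model config uniformly through the dispatch table.
    function getModelByType(type: ModelTypeEnum, model?: string) {
      return getModelMap[type]?.(model);
    }

    // e.g. getModelByType(ModelTypeEnum.whisper)
    //      getModelByType(ModelTypeEnum.rerank, 'bge-reranker-base')
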
@@ -20,8 +20,14 @@ export function reRankRecall({ query, inputs }: PostReRankProps) {
         Authorization: `Bearer ${model.requestAuth}`
       }
     }
-  ).then((data) => {
-    console.log('rerank time:', Date.now() - start);
-    return data;
-  });
+  )
+    .then((data) => {
+      console.log('rerank time:', Date.now() - start);
+      return data;
+    })
+    .catch((err) => {
+      console.log(err);
+
+      return [];
+    });
 }
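
The new .catch makes reRankRecall degrade to an empty array instead of throwing, so an unreachable rerank service can no longer fail the whole search request. A caller can then treat an empty result as "skip reranking", e.g. (sketch; the caller-side variable names are assumed):

    // If rerank returned nothing, fall back to the original recall ordering.
    const reRanked = await reRankRecall({ query, inputs });
    const ordered = reRanked.length > 0 ? reRanked : recallResults;
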
@@ -1,73 +0,0 @@
-import { getAIApi } from '@fastgpt/service/core/ai/config';
-
-export type GetVectorProps = {
-  model: string;
-  input: string | string[];
-};
-
-// text to vector
-export async function getVectorsByText({
-  model = 'text-embedding-ada-002',
-  input
-}: GetVectorProps) {
-  try {
-    if (typeof input === 'string' && !input) {
-      return Promise.reject({
-        code: 500,
-        message: 'input is empty'
-      });
-    } else if (Array.isArray(input)) {
-      for (let i = 0; i < input.length; i++) {
-        if (!input[i]) {
-          return Promise.reject({
-            code: 500,
-            message: 'input array is empty'
-          });
-        }
-      }
-    }
-
-    // get the chat API client
-    const ai = getAIApi();
-
-    // convert the input into vectors
-    const result = await ai.embeddings
-      .create({
-        model,
-        input
-      })
-      .then(async (res) => {
-        if (!res.data) {
-          return Promise.reject('Embedding API 404');
-        }
-        if (!res?.data?.[0]?.embedding) {
-          console.log(res?.data);
-          // @ts-ignore
-          return Promise.reject(res.data?.err?.message || 'Embedding API Error');
-        }
-        return {
-          tokenLen: res.usage.total_tokens || 0,
-          vectors: await Promise.all(res.data.map((item) => unityDimensional(item.embedding)))
-        };
-      });
-
-    return result;
-  } catch (error) {
-    console.log(`Embedding Error`, error);
-
-    return Promise.reject(error);
-  }
-}
-
-function unityDimensional(vector: number[]) {
-  if (vector.length > 1536) {
-    console.log(`当前向量维度为: ${vector.length}, 向量维度不能超过 1536, 已自动截取前 1536 维度`);
-    return vector.slice(0, 1536);
-  }
-  let resultVector = vector;
-  const vectorLen = vector.length;
-
-  const zeroVector = new Array(1536 - vectorLen).fill(0);
-
-  return resultVector.concat(zeroVector);
-}
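
The deleted helper's unityDimensional normalized every embedding to exactly 1536 dimensions, presumably so vectors fit a fixed-width vector column: longer vectors were truncated, shorter ones zero-padded. An equivalent compact sketch:

    // Pad or truncate an embedding to a fixed 1536-dim width.
    const VECTOR_WIDTH = 1536;
    function toFixedWidth(vector: number[]): number[] {
      return vector.length >= VECTOR_WIDTH
        ? vector.slice(0, VECTOR_WIDTH)
        : vector.concat(new Array(VECTOR_WIDTH - vector.length).fill(0));
    }
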
@@ -4,12 +4,30 @@ import {
   PatchIndexesProps,
   UpdateDatasetDataProps
 } from '@fastgpt/global/core/dataset/controller';
-import { deletePgDataById } from '@fastgpt/service/core/dataset/data/pg';
-import { insertData2Pg, updatePgDataById } from './pg';
+import {
+  insertDatasetDataVector,
+  recallFromVectorStore,
+  updateDatasetDataVector
+} from '@fastgpt/service/common/vectorStore/controller';
 import { Types } from 'mongoose';
-import { DatasetDataIndexTypeEnum } from '@fastgpt/global/core/dataset/constant';
+import {
+  DatasetDataIndexTypeEnum,
+  DatasetSearchModeEnum,
+  DatasetSearchModeMap,
+  SearchScoreTypeEnum
+} from '@fastgpt/global/core/dataset/constant';
 import { getDefaultIndex } from '@fastgpt/global/core/dataset/utils';
-import { jiebaSplit } from '../utils';
+import { jiebaSplit } from '@/service/common/string/jieba';
 import { deleteDatasetDataVector } from '@fastgpt/service/common/vectorStore/controller';
+import { getVectorsByText } from '@fastgpt/service/core/ai/embedding';
 import { MongoDatasetCollection } from '@fastgpt/service/core/dataset/collection/schema';
+import {
+  DatasetDataSchemaType,
+  SearchDataResponseItemType
+} from '@fastgpt/global/core/dataset/type';
+import { reRankRecall } from '../../ai/rerank';
+import { countPromptTokens } from '@fastgpt/global/common/string/tiktoken';
+import { hashStr } from '@fastgpt/global/common/string/tools';
 
+/* insert data.
+ * 1. create data id
@@ -50,17 +68,17 @@ export async function insertData2Dataset({
       }))
     : [getDefaultIndex({ q, a })];
 
-  // insert to pg
+  // insert to vector store
   const result = await Promise.all(
     indexes.map((item) =>
-      insertData2Pg({
-        mongoDataId: String(id),
-        input: item.text,
+      insertDatasetDataVector({
+        query: item.text,
         model,
         teamId,
        tmbId,
        datasetId,
-        collectionId
+        collectionId,
+        dataId: String(id)
      })
    )
  );
@@ -84,7 +102,7 @@ export async function insertData2Dataset({
 
   return {
     insertId: _id,
-    tokenLen: result.reduce((acc, cur) => acc + cur.tokenLen, 0)
+    tokens: result.reduce((acc, cur) => acc + cur.tokens, 0)
   };
 }
 
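
Each index of a data record now gets its own row in the vector store, and the per-index token usage is summed for billing. The result shape of the new helper, as inferred from how insertId and tokens are consumed in this diff (an assumption, not taken from its source):

    // Assumed return type of insertDatasetDataVector, per embedded index text.
    type InsertVectorResult = {
      insertId: string; // id of the new vector-store row
      tokens: number; // embedding tokens consumed for this text
    };
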
@@ -172,35 +190,40 @@ export async function updateData2Dataset({
   const result = await Promise.all(
     patchResult.map(async (item) => {
       if (item.type === 'create') {
-        const result = await insertData2Pg({
-          mongoDataId: dataId,
-          input: item.index.text,
+        const result = await insertDatasetDataVector({
+          query: item.index.text,
           model,
           teamId: mongoData.teamId,
           tmbId: mongoData.tmbId,
           datasetId: mongoData.datasetId,
-          collectionId: mongoData.collectionId
+          collectionId: mongoData.collectionId,
+          dataId
         });
         item.index.dataId = result.insertId;
         return result;
       }
       if (item.type === 'update' && item.index.dataId) {
-        return updatePgDataById({
+        return updateDatasetDataVector({
           id: item.index.dataId,
-          input: item.index.text,
+          query: item.index.text,
           model
         });
       }
       if (item.type === 'delete' && item.index.dataId) {
-        return deletePgDataById(['id', item.index.dataId]);
+        await deleteDatasetDataVector({
+          id: item.index.dataId
+        });
+        return {
+          tokens: 0
+        };
       }
       return {
-        tokenLen: 0
+        tokens: 0
       };
     })
   );
 
-  const tokenLen = result.reduce((acc, cur) => acc + cur.tokenLen, 0);
+  const tokens = result.reduce((acc, cur) => acc + cur.tokens, 0);
 
   // update mongo
   mongoData.q = q || mongoData.q;
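
updateData2Dataset applies a computed index patch as create/update/delete operations against the vector store, and every branch now returns a uniform token count so usage can be summed afterwards. A condensed sketch of the patch shape this loop consumes (types assumed, for illustration only):

    // One patch entry per changed index; every branch reports its token cost.
    type PatchOp =
      | { type: 'create'; index: { text: string; dataId?: string } }
      | { type: 'update'; index: { text: string; dataId: string } }
      | { type: 'delete'; index: { dataId: string } };
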
@@ -211,6 +234,457 @@ export async function updateData2Dataset({
   await mongoData.save();
 
   return {
-    tokenLen
+    tokens
   };
 }
 
+export async function searchDatasetData(props: {
+  model: string;
+  similarity?: number; // min distance
+  limit: number; // max Token limit
+  datasetIds: string[];
+  searchMode?: `${DatasetSearchModeEnum}`;
+  usingReRank?: boolean;
+  rawQuery: string;
+  queries: string[];
+}) {
+  let {
+    rawQuery,
+    queries,
+    model,
+    similarity = 0,
+    limit: maxTokens,
+    searchMode = DatasetSearchModeEnum.embedding,
+    usingReRank = false,
+    datasetIds = []
+  } = props;
+
+  /* init params */
+  searchMode = DatasetSearchModeMap[searchMode] ? searchMode : DatasetSearchModeEnum.embedding;
+  usingReRank = usingReRank && global.reRankModels.length > 0;
+
+  // Compatible with topk limit
+  if (maxTokens < 50) {
+    maxTokens = 1500;
+  }
+  let set = new Set<string>();
+  let usingSimilarityFilter = false;
+
+  /* function */
+  const countRecallLimit = () => {
+    const oneChunkToken = 50;
+    const estimatedLen = Math.max(20, Math.ceil(maxTokens / oneChunkToken));
+
+    // Increase search range, reduce hnsw loss. 20 ~ 100
+    if (searchMode === DatasetSearchModeEnum.embedding) {
+      return {
+        embeddingLimit: Math.min(estimatedLen, 100),
+        fullTextLimit: 0
+      };
+    }
+    // 50 < 2*limit < value < 100
+    if (searchMode === DatasetSearchModeEnum.fullTextRecall) {
+      return {
+        embeddingLimit: 0,
+        fullTextLimit: Math.min(estimatedLen, 50)
+      };
+    }
+    // mixed
+    // 50 < 2*limit < embedding < 80
+    // 20 < limit < fullTextLimit < 40
+    return {
+      embeddingLimit: Math.min(estimatedLen, 80),
+      fullTextLimit: Math.min(estimatedLen, 40)
+    };
+  };
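
countRecallLimit sizes the candidate pools from the token budget: one chunk is assumed to cost about 50 tokens, so with the default maxTokens = 1500 the estimate is max(20, ceil(1500 / 50)) = 30, giving embeddingLimit = 30 in embedding mode, fullTextLimit = 30 in full-text mode, and 30/30 (capped at 80/40) in mixed mode. Recalling more candidates than will finally fit reduces HNSW approximation loss, as the comment notes.
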
+  const embeddingRecall = async ({ query, limit }: { query: string; limit: number }) => {
+    const { vectors, tokens } = await getVectorsByText({
+      model,
+      input: [query]
+    });
+
+    const { results } = await recallFromVectorStore({
+      vectors,
+      limit,
+      datasetIds
+    });
+
+    // get q and a
+    const [collections, dataList] = await Promise.all([
+      MongoDatasetCollection.find(
+        {
+          _id: { $in: results.map((item) => item.collectionId) }
+        },
+        'name fileId rawLink'
+      ).lean(),
+      MongoDatasetData.find(
+        {
+          _id: { $in: results.map((item) => item.dataId?.trim()) }
+        },
+        'datasetId collectionId q a chunkIndex indexes'
+      ).lean()
+    ]);
+
+    const formatResult = results
+      .map((item, index) => {
+        const collection = collections.find(
+          (collection) => String(collection._id) === item.collectionId
+        );
+        const data = dataList.find((data) => String(data._id) === item.dataId);
+
+        // if collection or data UnExist, the relational mongo data already deleted
+        if (!collection || !data) return null;
+
+        const result: SearchDataResponseItemType = {
+          id: String(data._id),
+          q: data.q,
+          a: data.a,
+          chunkIndex: data.chunkIndex,
+          indexes: data.indexes,
+          datasetId: String(data.datasetId),
+          collectionId: String(data.collectionId),
+          sourceName: collection.name || '',
+          sourceId: collection?.fileId || collection?.rawLink,
+          score: [{ type: SearchScoreTypeEnum.embedding, value: item.score, index }]
+        };
+
+        return result;
+      })
+      .filter((item) => item !== null) as SearchDataResponseItemType[];
+
+    return {
+      embeddingRecallResults: formatResult,
+      tokens
+    };
+  };
+  const fullTextRecall = async ({
+    query,
+    limit
+  }: {
+    query: string;
+    limit: number;
+  }): Promise<{
+    fullTextRecallResults: SearchDataResponseItemType[];
+    tokenLen: number;
+  }> => {
+    if (limit === 0) {
+      return {
+        fullTextRecallResults: [],
+        tokenLen: 0
+      };
+    }
+
+    let searchResults = (
+      await Promise.all(
+        datasetIds.map((id) =>
+          MongoDatasetData.find(
+            {
+              datasetId: id,
+              $text: { $search: jiebaSplit({ text: query }) }
+            },
+            {
+              score: { $meta: 'textScore' },
+              _id: 1,
+              datasetId: 1,
+              collectionId: 1,
+              q: 1,
+              a: 1,
+              indexes: 1,
+              chunkIndex: 1
+            }
+          )
+            .sort({ score: { $meta: 'textScore' } })
+            .limit(limit)
+            .lean()
+        )
+      )
+    ).flat() as (DatasetDataSchemaType & { score: number })[];
+
+    // resort
+    searchResults.sort((a, b) => b.score - a.score);
+    searchResults.slice(0, limit);
+
+    const collections = await MongoDatasetCollection.find(
+      {
+        _id: { $in: searchResults.map((item) => item.collectionId) }
+      },
+      '_id name fileId rawLink'
+    );
+
+    return {
+      fullTextRecallResults: searchResults.map((item, index) => {
+        const collection = collections.find((col) => String(col._id) === String(item.collectionId));
+        return {
+          id: String(item._id),
+          datasetId: String(item.datasetId),
+          collectionId: String(item.collectionId),
+          sourceName: collection?.name || '',
+          sourceId: collection?.fileId || collection?.rawLink,
+          q: item.q,
+          a: item.a,
+          chunkIndex: item.chunkIndex,
+          indexes: item.indexes,
+          score: [{ type: SearchScoreTypeEnum.fullText, value: item.score, index }]
+        };
+      }),
+      tokenLen: 0
+    };
+  };
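
Full-text recall runs one $text query per dataset (jiebaSplit pre-tokenizes the CJK query into space-separated terms for Mongo's text index), then merges the per-dataset hits and re-sorts them by textScore. One caveat worth noting: Array.prototype.slice does not mutate, so the bare `searchResults.slice(0, limit);` above has no effect; truncating the merged list would require reassignment:

    // slice() returns a new array; reassign for the truncation to apply.
    searchResults = searchResults.slice(0, limit);
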
+  const reRankSearchResult = async ({
+    data,
+    query
+  }: {
+    data: SearchDataResponseItemType[];
+    query: string;
+  }): Promise<SearchDataResponseItemType[]> => {
+    try {
+      const results = await reRankRecall({
+        query,
+        inputs: data.map((item) => ({
+          id: item.id,
+          text: `${item.q}\n${item.a}`
+        }))
+      });
+
+      if (!Array.isArray(results)) return [];
+
+      // add new score to data
+      const mergeResult = results
+        .map((item, index) => {
+          const target = data.find((dataItem) => dataItem.id === item.id);
+          if (!target) return null;
+          const score = item.score || 0;
+
+          return {
+            ...target,
+            score: [{ type: SearchScoreTypeEnum.reRank, value: score, index }]
+          };
+        })
+        .filter(Boolean) as SearchDataResponseItemType[];
+
+      return mergeResult;
+    } catch (error) {
+      return [];
+    }
+  };
+  const filterResultsByMaxTokens = (list: SearchDataResponseItemType[], maxTokens: number) => {
+    const results: SearchDataResponseItemType[] = [];
+    let totalTokens = 0;
+
+    for (let i = 0; i < list.length; i++) {
+      const item = list[i];
+      totalTokens += countPromptTokens(item.q + item.a);
+      if (totalTokens > maxTokens + 500) {
+        break;
+      }
+      results.push(item);
+      if (totalTokens > maxTokens) {
+        break;
+      }
+    }
+
+    return results.length === 0 ? list.slice(0, 1) : results;
+  };
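
filterResultsByMaxTokens packs ranked results into the token budget with a 500-token grace margin: a result that pushes the total past maxTokens is still kept unless it overshoots by more than 500, and at least one result is always returned. For example, with maxTokens = 1500 and items costing 900 + 700 tokens, the second item lands at 1600 (within 1500 + 500) and is kept, after which the loop stops.
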
+  const multiQueryRecall = async ({
+    embeddingLimit,
+    fullTextLimit
+  }: {
+    embeddingLimit: number;
+    fullTextLimit: number;
+  }) => {
+    // In a group n recall, as long as one of the data appears minAmount of times, it is retained
+    const getIntersection = (resultList: SearchDataResponseItemType[][], minAmount = 1) => {
+      minAmount = Math.min(resultList.length, minAmount);
+
+      const map: Record<
+        string,
+        {
+          amount: number;
+          data: SearchDataResponseItemType;
+        }
+      > = {};
+
+      for (const list of resultList) {
+        for (const item of list) {
+          map[item.id] = map[item.id]
+            ? {
+                amount: map[item.id].amount + 1,
+                data: item
+              }
+            : {
+                amount: 1,
+                data: item
+              };
+        }
+      }
+
+      return Object.values(map)
+        .filter((item) => item.amount >= minAmount)
+        .map((item) => item.data);
+    };
+
+    // multi query recall
+    const embeddingRecallResList: SearchDataResponseItemType[][] = [];
+    const fullTextRecallResList: SearchDataResponseItemType[][] = [];
+    let embTokens = 0;
+    for await (const query of queries) {
+      const [{ tokens, embeddingRecallResults }, { fullTextRecallResults }] = await Promise.all([
+        embeddingRecall({
+          query,
+          limit: embeddingLimit
+        }),
+        fullTextRecall({
+          query,
+          limit: fullTextLimit
+        })
+      ]);
+      embTokens += tokens;
+
+      embeddingRecallResList.push(embeddingRecallResults);
+      fullTextRecallResList.push(fullTextRecallResults);
+    }
+
+    return {
+      tokens: embTokens,
+      embeddingRecallResults: embeddingRecallResList[0],
+      fullTextRecallResults: fullTextRecallResList[0]
+    };
+  };
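
getIntersection keeps an item once it appears in at least minAmount of the per-query result lists:

    // e.g. getIntersection([[A, B], [B, C]], 2) → [B]

Note, though, that this version of multiQueryRecall only returns the first query's lists (index [0]); the helper is defined but no longer applied, unlike the deleted pg.ts implementation below, which used getIntersection(..., 2) across all query rewrites.
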
+  const rrfConcat = (
+    arr: { k: number; list: SearchDataResponseItemType[] }[]
+  ): SearchDataResponseItemType[] => {
+    const map = new Map<string, SearchDataResponseItemType & { rrfScore: number }>();
+
+    // rrf
+    arr.forEach((item) => {
+      const k = item.k;
+
+      item.list.forEach((data, index) => {
+        const rank = index + 1;
+        const score = 1 / (k + rank);
+
+        const record = map.get(data.id);
+        if (record) {
+          // merge the two score lists; for scores of the same type, keep the larger value
+          const concatScore = [...record.score];
+          for (const dataItem of data.score) {
+            const sameScore = concatScore.find((item) => item.type === dataItem.type);
+            if (sameScore) {
+              sameScore.value = Math.max(sameScore.value, dataItem.value);
+            } else {
+              concatScore.push(dataItem);
+            }
+          }
+
+          map.set(data.id, {
+            ...record,
+            score: concatScore,
+            rrfScore: record.rrfScore + score
+          });
+        } else {
+          map.set(data.id, {
+            ...data,
+            rrfScore: score
+          });
+        }
+      });
+    });
+
+    // sort
+    const mapArray = Array.from(map.values());
+    const results = mapArray.sort((a, b) => b.rrfScore - a.rrfScore);
+
+    return results.map((item, index) => {
+      item.score.push({
+        type: SearchScoreTypeEnum.rrf,
+        value: item.rrfScore,
+        index
+      });
+      // @ts-ignore
+      delete item.rrfScore;
+      return item;
+    });
+  };
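
rrfConcat implements Reciprocal Rank Fusion: each list contributes 1 / (k + rank) per item, and an item's fused score is the sum over all lists that contain it. With k = 60, an item ranked 1st in the embedding list and 3rd in the full-text list scores 1/61 + 1/63 ≈ 0.0323, beating an item that is 2nd in only one list (1/62 ≈ 0.0161). The constant k = 60 damps the gap between top ranks, so consistent mid-rank agreement can outweigh a single high rank.
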
+
+  /* main step */
+  // count limit
+  const { embeddingLimit, fullTextLimit } = countRecallLimit();
+
+  // recall
+  const { embeddingRecallResults, fullTextRecallResults, tokens } = await multiQueryRecall({
+    embeddingLimit,
+    fullTextLimit
+  });
+
+  // ReRank results
+  const reRankResults = await (async () => {
+    if (!usingReRank) return [];
+
+    set = new Set<string>(embeddingRecallResults.map((item) => item.id));
+    const concatRecallResults = embeddingRecallResults.concat(
+      fullTextRecallResults.filter((item) => !set.has(item.id))
+    );
+
+    // remove same q and a data
+    set = new Set<string>();
+    const filterSameDataResults = concatRecallResults.filter((item) => {
+      // strip punctuation and whitespace so only the text content is compared
+      const str = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
+      if (set.has(str)) return false;
+      set.add(str);
+      return true;
+    });
+    return reRankSearchResult({
+      query: rawQuery,
+      data: filterSameDataResults
+    });
+  })();
+
+  // embedding recall and fullText recall rrf concat
+  const rrfConcatResults = rrfConcat([
+    { k: 60, list: embeddingRecallResults },
+    { k: 60, list: fullTextRecallResults },
+    { k: 60, list: reRankResults }
+  ]);
+
+  // remove same q and a data
+  set = new Set<string>();
+  const filterSameDataResults = rrfConcatResults.filter((item) => {
+    // strip punctuation and whitespace so only the text content is compared
+    const str = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
+    if (set.has(str)) return false;
+    set.add(str);
+    return true;
+  });
+
+  // score filter
+  const scoreFilter = (() => {
+    if (usingReRank) {
+      usingSimilarityFilter = true;
+
+      return filterSameDataResults.filter((item) => {
+        const reRankScore = item.score.find((item) => item.type === SearchScoreTypeEnum.reRank);
+        if (reRankScore && reRankScore.value < similarity) return false;
+        return true;
+      });
+    }
+    if (searchMode === DatasetSearchModeEnum.embedding) {
+      return filterSameDataResults.filter((item) => {
+        usingSimilarityFilter = true;
+
+        const embeddingScore = item.score.find(
+          (item) => item.type === SearchScoreTypeEnum.embedding
+        );
+        if (embeddingScore && embeddingScore.value < similarity) return false;
+        return true;
+      });
+    }
+    return filterSameDataResults;
+  })();
+
+  return {
+    searchRes: filterResultsByMaxTokens(scoreFilter, maxTokens),
+    tokens,
+    usingSimilarityFilter
+  };
+}
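
End to end, the new searchDatasetData pipeline is: size the candidate pools, recall per query rewrite (embedding plus full text), optionally rerank the deduplicated union, fuse all three rankings with RRF, deduplicate again, apply the similarity cutoff (against the rerank score when reranking, the embedding score in embedding-only mode), and finally pack to the token budget. A hypothetical call site (the 'mixedRecall' literal and the variable names are assumptions, not confirmed by this diff):

    const { searchRes, tokens, usingSimilarityFilter } = await searchDatasetData({
      model: 'text-embedding-ada-002',
      similarity: 0.5,
      limit: 1500, // max tokens in the returned context
      datasetIds: [datasetId],
      searchMode: 'mixedRecall',
      usingReRank: true,
      rawQuery: userQuestion,
      queries: [userQuestion, ...rewrittenQueries]
    });
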
@@ -1,478 +0,0 @@
-import { DatasetSearchModeEnum, PgDatasetTableName } from '@fastgpt/global/core/dataset/constant';
-import type {
-  DatasetDataSchemaType,
-  SearchDataResponseItemType
-} from '@fastgpt/global/core/dataset/type.d';
-import { PgClient } from '@fastgpt/service/common/pg';
-import { getVectorsByText } from '@/service/core/ai/vector';
-import { delay } from '@fastgpt/global/common/system/utils';
-import { PgSearchRawType } from '@fastgpt/global/core/dataset/api';
-import { MongoDatasetCollection } from '@fastgpt/service/core/dataset/collection/schema';
-import { MongoDatasetData } from '@fastgpt/service/core/dataset/data/schema';
-import { jiebaSplit } from '../utils';
-import { reRankRecall } from '../../ai/rerank';
-import { countPromptTokens } from '@fastgpt/global/common/string/tiktoken';
-import { hashStr } from '@fastgpt/global/common/string/tools';
-
-export async function insertData2Pg(props: {
-  mongoDataId: string;
-  input: string;
-  model: string;
-  teamId: string;
-  tmbId: string;
-  datasetId: string;
-  collectionId: string;
-  retry?: number;
-}): Promise<{ insertId: string; vectors: number[][]; tokenLen: number }> {
-  const { mongoDataId, input, model, teamId, tmbId, datasetId, collectionId, retry = 3 } = props;
-  try {
-    // get vector
-    const { vectors, tokenLen } = await getVectorsByText({
-      model,
-      input: [input]
-    });
-    const { rows } = await PgClient.insert(PgDatasetTableName, {
-      values: [
-        [
-          { key: 'vector', value: `[${vectors[0]}]` },
-          { key: 'team_id', value: String(teamId) },
-          { key: 'tmb_id', value: String(tmbId) },
-          { key: 'dataset_id', value: datasetId },
-          { key: 'collection_id', value: collectionId },
-          { key: 'data_id', value: String(mongoDataId) }
-        ]
-      ]
-    });
-    return {
-      insertId: rows[0].id,
-      vectors,
-      tokenLen
-    };
-  } catch (error) {
-    if (retry <= 0) {
-      return Promise.reject(error);
-    }
-    await delay(500);
-    return insertData2Pg({
-      ...props,
-      retry: retry - 1
-    });
-  }
-}
-
-export async function updatePgDataById({
-  id,
-  input,
-  model
-}: {
-  id: string;
-  input: string;
-  model: string;
-}) {
-  let retry = 2;
-  async function updatePg(): Promise<{ vectors: number[][]; tokenLen: number }> {
-    try {
-      // get vector
-      const { vectors, tokenLen } = await getVectorsByText({
-        model,
-        input: [input]
-      });
-      // update pg
-      await PgClient.update(PgDatasetTableName, {
-        where: [['id', id]],
-        values: [{ key: 'vector', value: `[${vectors[0]}]` }]
-      });
-      return {
-        vectors,
-        tokenLen
-      };
-    } catch (error) {
-      if (--retry < 0) {
-        return Promise.reject(error);
-      }
-      await delay(500);
-      return updatePg();
-    }
-  }
-  return updatePg();
-}
-
-// ------------------ search start ------------------
-type SearchProps = {
-  model: string;
-  similarity?: number; // min distance
-  limit: number; // max Token limit
-  datasetIds: string[];
-  searchMode?: `${DatasetSearchModeEnum}`;
-};
-export async function searchDatasetData(
-  props: SearchProps & { rawQuery: string; queries: string[] }
-) {
-  let {
-    rawQuery,
-    queries,
-    model,
-    similarity = 0,
-    limit: maxTokens,
-    searchMode = DatasetSearchModeEnum.embedding,
-    datasetIds = []
-  } = props;
-
-  /* init params */
-  searchMode = global.systemEnv?.pluginBaseUrl ? searchMode : DatasetSearchModeEnum.embedding;
-  // Compatible with topk limit
-  if (maxTokens < 50) {
-    maxTokens = 1500;
-  }
-  const rerank =
-    global.reRankModels?.[0] &&
-    (searchMode === DatasetSearchModeEnum.embeddingReRank ||
-      searchMode === DatasetSearchModeEnum.embFullTextReRank);
-  let set = new Set<string>();
-
-  /* function */
-  const countRecallLimit = () => {
-    const oneChunkToken = 50;
-    const estimatedLen = Math.max(20, Math.ceil(maxTokens / oneChunkToken));
-
-    // Increase search range, reduce hnsw loss. 20 ~ 100
-    if (searchMode === DatasetSearchModeEnum.embedding) {
-      return {
-        embeddingLimit: Math.min(estimatedLen, 100),
-        fullTextLimit: 0
-      };
-    }
-    // 50 < 2*limit < value < 100
-    if (searchMode === DatasetSearchModeEnum.embeddingReRank) {
-      return {
-        embeddingLimit: Math.min(100, Math.max(50, estimatedLen * 2)),
-        fullTextLimit: 0
-      };
-    }
-    // 50 < 2*limit < embedding < 80
-    // 20 < limit < fullTextLimit < 40
-    return {
-      embeddingLimit: Math.min(80, Math.max(50, estimatedLen * 2)),
-      fullTextLimit: Math.min(40, Math.max(20, estimatedLen))
-    };
-  };
-  const embeddingRecall = async ({ query, limit }: { query: string; limit: number }) => {
-    const { vectors, tokenLen } = await getVectorsByText({
-      model,
-      input: [query]
-    });
-
-    const results: any = await PgClient.query(
-      `BEGIN;
-        SET LOCAL hnsw.ef_search = ${global.systemEnv.pgHNSWEfSearch || 100};
-        select id, collection_id, data_id, (vector <#> '[${vectors[0]}]') * -1 AS score
-          from ${PgDatasetTableName}
-          where dataset_id IN (${datasetIds.map((id) => `'${String(id)}'`).join(',')})
-              ${rerank ? '' : `AND vector <#> '[${vectors[0]}]' < -${similarity}`}
-          order by score desc limit ${limit};
-        COMMIT;`
-    );
-
-    const rows = results?.[2]?.rows as PgSearchRawType[];
-
-    // concat same data_id
-    const filterRows: PgSearchRawType[] = [];
-    let set = new Set<string>();
-    for (const row of rows) {
-      if (!set.has(row.data_id)) {
-        filterRows.push(row);
-        set.add(row.data_id);
-      }
-    }
-
-    // get q and a
-    const [collections, dataList] = await Promise.all([
-      MongoDatasetCollection.find(
-        {
-          _id: { $in: filterRows.map((item) => item.collection_id) }
-        },
-        'name fileId rawLink'
-      ).lean(),
-      MongoDatasetData.find(
-        {
-          _id: { $in: filterRows.map((item) => item.data_id?.trim()) }
-        },
-        'datasetId collectionId q a chunkIndex indexes'
-      ).lean()
-    ]);
-    const formatResult = filterRows
-      .map((item) => {
-        const collection = collections.find(
-          (collection) => String(collection._id) === item.collection_id
-        );
-        const data = dataList.find((data) => String(data._id) === item.data_id);
-
-        // if collection or data UnExist, the relational mongo data already deleted
-        if (!collection || !data) return null;
-
-        return {
-          id: String(data._id),
-          q: data.q,
-          a: data.a,
-          chunkIndex: data.chunkIndex,
-          indexes: data.indexes,
-          datasetId: String(data.datasetId),
-          collectionId: String(data.collectionId),
-          sourceName: collection.name || '',
-          sourceId: collection?.fileId || collection?.rawLink,
-          score: item.score
-        };
-      })
-      .filter((item) => item !== null) as SearchDataResponseItemType[];
-
-    return {
-      embeddingRecallResults: formatResult,
-      tokenLen
-    };
-  };
-  const fullTextRecall = async ({
-    query,
-    limit
-  }: {
-    query: string;
-    limit: number;
-  }): Promise<{
-    fullTextRecallResults: SearchDataResponseItemType[];
-    tokenLen: number;
-  }> => {
-    if (limit === 0) {
-      return {
-        fullTextRecallResults: [],
-        tokenLen: 0
-      };
-    }
-
-    let searchResults = (
-      await Promise.all(
-        datasetIds.map((id) =>
-          MongoDatasetData.find(
-            {
-              datasetId: id,
-              $text: { $search: jiebaSplit({ text: query }) }
-            },
-            {
-              score: { $meta: 'textScore' },
-              _id: 1,
-              datasetId: 1,
-              collectionId: 1,
-              q: 1,
-              a: 1,
-              indexes: 1,
-              chunkIndex: 1
-            }
-          )
-            .sort({ score: { $meta: 'textScore' } })
-            .limit(limit)
-            .lean()
-        )
-      )
-    ).flat() as (DatasetDataSchemaType & { score: number })[];
-
-    // resort
-    searchResults.sort((a, b) => b.score - a.score);
-    searchResults.slice(0, limit);
-
-    const collections = await MongoDatasetCollection.find(
-      {
-        _id: { $in: searchResults.map((item) => item.collectionId) }
-      },
-      '_id name fileId rawLink'
-    );
-
-    return {
-      fullTextRecallResults: searchResults.map((item) => {
-        const collection = collections.find((col) => String(col._id) === String(item.collectionId));
-        return {
-          id: String(item._id),
-          datasetId: String(item.datasetId),
-          collectionId: String(item.collectionId),
-          sourceName: collection?.name || '',
-          sourceId: collection?.fileId || collection?.rawLink,
-          q: item.q,
-          a: item.a,
-          chunkIndex: item.chunkIndex,
-          indexes: item.indexes,
-          // @ts-ignore
-          score: item.score
-        };
-      }),
-      tokenLen: 0
-    };
-  };
-  const reRankSearchResult = async ({
-    data,
-    query
-  }: {
-    data: SearchDataResponseItemType[];
-    query: string;
-  }): Promise<SearchDataResponseItemType[]> => {
-    try {
-      const results = await reRankRecall({
-        query,
-        inputs: data.map((item) => ({
-          id: item.id,
-          text: `${item.q}\n${item.a}`
-        }))
-      });
-
-      if (!Array.isArray(results)) return data;
-
-      // add new score to data
-      const mergeResult = results
-        .map((item) => {
-          const target = data.find((dataItem) => dataItem.id === item.id);
-          if (!target) return null;
-          return {
-            ...target,
-            score: item.score || target.score
-          };
-        })
-        .filter(Boolean) as SearchDataResponseItemType[];
-
-      return mergeResult;
-    } catch (error) {
-      return data;
-    }
-  };
-  const filterResultsByMaxTokens = (list: SearchDataResponseItemType[], maxTokens: number) => {
-    const results: SearchDataResponseItemType[] = [];
-    let totalTokens = 0;
-
-    for (let i = 0; i < list.length; i++) {
-      const item = list[i];
-      totalTokens += countPromptTokens(item.q + item.a);
-      if (totalTokens > maxTokens + 500) {
-        break;
-      }
-      results.push(item);
-      if (totalTokens > maxTokens) {
-        break;
-      }
-    }
-
-    return results.length === 0 ? list.slice(0, 1) : results;
-  };
-  const multiQueryRecall = async ({
-    embeddingLimit,
-    fullTextLimit
-  }: {
-    embeddingLimit: number;
-    fullTextLimit: number;
-  }) => {
-    // In a group n recall, as long as one of the data appears minAmount of times, it is retained
-    const getIntersection = (resultList: SearchDataResponseItemType[][], minAmount = 1) => {
-      minAmount = Math.min(resultList.length, minAmount);
-
-      const map: Record<
-        string,
-        {
-          amount: number;
-          data: SearchDataResponseItemType;
-        }
-      > = {};
-
-      for (const list of resultList) {
-        for (const item of list) {
-          map[item.id] = map[item.id]
-            ? {
-                amount: map[item.id].amount + 1,
-                data: item
-              }
-            : {
-                amount: 1,
-                data: item
-              };
-        }
-      }
-
-      return Object.values(map)
-        .filter((item) => item.amount >= minAmount)
-        .map((item) => item.data);
-    };
-
-    // multi query recall
-    const embeddingRecallResList: SearchDataResponseItemType[][] = [];
-    const fullTextRecallResList: SearchDataResponseItemType[][] = [];
-    let embTokens = 0;
-    for await (const query of queries) {
-      const [{ tokenLen, embeddingRecallResults }, { fullTextRecallResults }] = await Promise.all([
-        embeddingRecall({
-          query,
-          limit: embeddingLimit
-        }),
-        fullTextRecall({
-          query,
-          limit: fullTextLimit
-        })
-      ]);
-      embTokens += tokenLen;
-
-      embeddingRecallResList.push(embeddingRecallResults);
-      fullTextRecallResList.push(fullTextRecallResults);
-    }
-
-    return {
-      tokens: embTokens,
-      embeddingRecallResults: getIntersection(embeddingRecallResList, 2),
-      fullTextRecallResults: getIntersection(fullTextRecallResList, 2)
-    };
-  };
-
-  /* main step */
-  // count limit
-  const { embeddingLimit, fullTextLimit } = countRecallLimit();
-
-  // recall
-  const { embeddingRecallResults, fullTextRecallResults, tokens } = await multiQueryRecall({
-    embeddingLimit,
-    fullTextLimit
-  });
-
-  // concat recall results
-  set = new Set<string>(embeddingRecallResults.map((item) => item.id));
-  const concatRecallResults = embeddingRecallResults.concat(
-    fullTextRecallResults.filter((item) => !set.has(item.id))
-  );
-
-  // remove same q and a data
-  set = new Set<string>();
-  const filterSameDataResults = concatRecallResults.filter((item) => {
-    // strip punctuation and whitespace so only the text content is compared
-    const str = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
-    if (set.has(str)) return false;
-    set.add(str);
-    return true;
-  });
-
-  if (!rerank) {
-    return {
-      searchRes: filterResultsByMaxTokens(
-        filterSameDataResults.filter((item) => item.score >= similarity),
-        maxTokens
-      ),
-      tokenLen: tokens
-    };
-  }
-
-  // ReRank results
-  const reRankResults = (
-    await reRankSearchResult({
-      query: rawQuery,
-      data: filterSameDataResults
-    })
-  ).filter((item) => item.score > similarity);
-
-  return {
-    searchRes: filterResultsByMaxTokens(
-      reRankResults.filter((item) => item.score >= similarity),
-      maxTokens
-    ),
-    tokenLen: tokens
-  };
-}
-// ------------------ search end ------------------
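
The replacement changes the scoring model fundamentally: the deleted pg.ts implementation carried a single numeric score per result and let the rerank score overwrite it, while the new controller keeps an array of typed score entries, so every ranking stage stays inspectable in the response. The shape difference, side by side (sketch):

    // old: one opaque number, origin unknown after reranking
    score: number;
    // new: every stage (embedding, fullText, reRank, rrf) records its contribution
    score: { type: `${SearchScoreTypeEnum}`; value: number; index: number }[];
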
@@ -1,6 +1,4 @@
 import { MongoDatasetData } from '@fastgpt/service/core/dataset/data/schema';
-import { cut } from '@node-rs/jieba';
-import { stopWords } from '@fastgpt/global/common/string/jieba';
 
 /**
  * Same value judgment
@@ -24,14 +22,3 @@ export async function hasSameValue({
     return Promise.reject('已经存在完全一致的数据');
   }
 }
-
-export function jiebaSplit({ text }: { text: string }) {
-  const tokens = cut(text, true);
-
-  return (
-    tokens
-      .map((item) => item.replace(/[^\u4e00-\u9fa5a-zA-Z0-9\s]/g, '').trim())
-      .filter((item) => item && !stopWords.has(item))
-      .join(' ') || ''
-  );
-}
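
The jiebaSplit helper (now moved to @/service/common/string/jieba) turns a CJK query into the space-separated term string that MongoDB's $text operator expects: jieba's cut segments the text, punctuation is stripped, stop words are dropped, and the tokens are re-joined with spaces. For example (the output is illustrative of the approach, not an exact trace):

    jiebaSplit({ text: '如何训练知识库模型?' })
    // → '如何 训练 知识库 模型'
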