4.6.5 - CoreferenceResolution Module (#631)

Archer
2023-12-22 10:47:31 +08:00
committed by GitHub
parent 41115a96c0
commit cd682d4275
112 changed files with 4163 additions and 2700 deletions

View File

@@ -1,62 +1,26 @@
-import {
-  defaultAudioSpeechModels,
-  defaultChatModels,
-  defaultCQModels,
-  defaultExtractModels,
-  defaultQAModels,
-  defaultQGModels,
-  defaultVectorModels
-} from '@fastgpt/global/core/ai/model';
-
 export const getChatModel = (model?: string) => {
-  return (
-    (global.chatModels || defaultChatModels).find((item) => item.model === model) ||
-    global.chatModels?.[0] ||
-    defaultChatModels[0]
-  );
+  return global.chatModels.find((item) => item.model === model) ?? global.chatModels[0];
 };
 export const getQAModel = (model?: string) => {
-  return (
-    (global.qaModels || defaultQAModels).find((item) => item.model === model) ||
-    global.qaModels?.[0] ||
-    defaultQAModels[0]
-  );
+  return global.qaModels.find((item) => item.model === model) || global.qaModels[0];
 };
 export const getCQModel = (model?: string) => {
-  return (
-    (global.cqModels || defaultCQModels).find((item) => item.model === model) ||
-    global.cqModels?.[0] ||
-    defaultCQModels[0]
-  );
+  return global.cqModels.find((item) => item.model === model) || global.cqModels[0];
 };
 export const getExtractModel = (model?: string) => {
-  return (
-    (global.extractModels || defaultExtractModels).find((item) => item.model === model) ||
-    global.extractModels?.[0] ||
-    defaultExtractModels[0]
-  );
+  return global.extractModels.find((item) => item.model === model) || global.extractModels[0];
 };
 export const getQGModel = (model?: string) => {
-  return (
-    (global.qgModels || defaultQGModels).find((item) => item.model === model) ||
-    global.qgModels?.[0] ||
-    defaultQGModels[0]
-  );
+  return global.qgModels.find((item) => item.model === model) || global.qgModels[0];
 };
 export const getVectorModel = (model?: string) => {
-  return (
-    global.vectorModels.find((item) => item.model === model) ||
-    global.vectorModels?.[0] ||
-    defaultVectorModels[0]
-  );
+  return global.vectorModels.find((item) => item.model === model) || global.vectorModels[0];
 };
 export function getAudioSpeechModel(model?: string) {
   return (
-    global.audioSpeechModels.find((item) => item.model === model) ||
-    global.audioSpeechModels?.[0] ||
-    defaultAudioSpeechModels[0]
+    global.audioSpeechModels.find((item) => item.model === model) || global.audioSpeechModels[0]
   );
 }
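
Each getter above follows the same "look up by model name, fall back to the first configured entry" pattern; after this change the hard-coded default-model lists are gone and only the runtime-configured global lists are consulted. The repetition could be folded into one generic helper. A minimal sketch, not part of the commit (createModelGetter is a hypothetical name):

type ModelItem = { model: string };

// Build a getter that resolves a model name against a lazily-read list,
// falling back to the first configured entry when the name is missing.
const createModelGetter =
  <T extends ModelItem>(getList: () => T[]) =>
  (model?: string): T =>
    getList().find((item) => item.model === model) ?? getList()[0];

// e.g. const getChatModel = createModelGetter(() => global.chatModels);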

View File

@@ -12,6 +12,7 @@ import { MongoDatasetData } from '@fastgpt/service/core/dataset/data/schema';
 import { jiebaSplit } from '../utils';
 import { reRankRecall } from '../../ai/rerank';
 import { countPromptTokens } from '@fastgpt/global/common/string/tiktoken';
+import { hashStr } from '@fastgpt/global/common/string/tools';

 export async function insertData2Pg(props: {
   mongoDataId: string;
@@ -98,34 +99,40 @@ export async function updatePgDataById({

 // ------------------ search start ------------------
 type SearchProps = {
-  text: string;
   model: string;
   similarity?: number; // min distance
   limit: number; // max Token limit
   datasetIds: string[];
   searchMode?: `${DatasetSearchModeEnum}`;
 };
-export async function searchDatasetData(props: SearchProps) {
+export async function searchDatasetData(
+  props: SearchProps & { rawQuery: string; queries: string[] }
+) {
   let {
-    text,
+    rawQuery,
+    queries,
     model,
     similarity = 0,
     limit: maxTokens,
-    searchMode = DatasetSearchModeEnum.embedding
+    searchMode = DatasetSearchModeEnum.embedding,
+    datasetIds = []
   } = props;
-  searchMode = global.systemEnv?.pluginBaseUrl ? searchMode : DatasetSearchModeEnum.embedding;
+
+  /* init params */
+  searchMode = global.systemEnv?.pluginBaseUrl ? searchMode : DatasetSearchModeEnum.embedding;

   // Compatible with topk limit
   if (maxTokens < 50) {
     maxTokens = 1500;
   }
   const rerank =
     global.reRankModels?.[0] &&
     (searchMode === DatasetSearchModeEnum.embeddingReRank ||
       searchMode === DatasetSearchModeEnum.embFullTextReRank);
+  let set = new Set<string>();

-  const oneChunkToken = 50;
-  const { embeddingLimit, fullTextLimit } = (() => {
+  /* function */
+  const countRecallLimit = () => {
+    const oneChunkToken = 50;
     const estimatedLen = Math.max(20, Math.ceil(maxTokens / oneChunkToken));
     // Increase search range, reduce hnsw loss. 20 ~ 100
@@ -148,34 +155,295 @@ export async function searchDatasetData(props: SearchProps) {
       embeddingLimit: Math.min(80, Math.max(50, estimatedLen * 2)),
       fullTextLimit: Math.min(40, Math.max(20, estimatedLen))
     };
-  })();
+  };

-  const [{ tokenLen, embeddingRecallResults }, { fullTextRecallResults }] = await Promise.all([
-    embeddingRecall({
-      ...props,
-      rerank,
-      limit: embeddingLimit
-    }),
-    fullTextRecall({
-      ...props,
-      limit: fullTextLimit
-    })
-  ]);
-
-  // concat embedding and fullText recall result
-  let set = new Set<string>(embeddingRecallResults.map((item) => item.id));
-  const concatRecallResults = embeddingRecallResults;
-  fullTextRecallResults.forEach((item) => {
-    if (!set.has(item.id) && item.score >= similarity) {
-      concatRecallResults.push(item);
-      set.add(item.id);
-    }
-  });
+  const embeddingRecall = async ({ query, limit }: { query: string; limit: number }) => {
+    const { vectors, tokenLen } = await getVectorsByText({
+      model,
+      input: [query]
+    });
+
+    const results: any = await PgClient.query(
+      `BEGIN;
+        SET LOCAL hnsw.ef_search = ${global.systemEnv.pgHNSWEfSearch || 100};
+        select id, collection_id, data_id, (vector <#> '[${vectors[0]}]') * -1 AS score
+          from ${PgDatasetTableName}
+          where dataset_id IN (${datasetIds.map((id) => `'${String(id)}'`).join(',')})
+              ${rerank ? '' : `AND vector <#> '[${vectors[0]}]' < -${similarity}`}
+          order by score desc limit ${limit};
+        COMMIT;`
+    );
+
+    const rows = results?.[2]?.rows as PgSearchRawType[];
+
+    // concat same data_id
+    const filterRows: PgSearchRawType[] = [];
+    let set = new Set<string>();
+    for (const row of rows) {
+      if (!set.has(row.data_id)) {
+        filterRows.push(row);
+        set.add(row.data_id);
+      }
+    }
+
+    // get q and a
+    const [collections, dataList] = await Promise.all([
+      MongoDatasetCollection.find(
+        {
+          _id: { $in: filterRows.map((item) => item.collection_id) }
+        },
+        'name fileId rawLink'
+      ).lean(),
+      MongoDatasetData.find(
+        {
+          _id: { $in: filterRows.map((item) => item.data_id?.trim()) }
+        },
+        'datasetId collectionId q a chunkIndex indexes'
+      ).lean()
+    ]);
+
+    const formatResult = filterRows
+      .map((item) => {
+        const collection = collections.find(
+          (collection) => String(collection._id) === item.collection_id
+        );
+        const data = dataList.find((data) => String(data._id) === item.data_id);
+
+        // if the collection or data no longer exists, the related mongo data was already deleted
+        if (!collection || !data) return null;
+
+        return {
+          id: String(data._id),
+          q: data.q,
+          a: data.a,
+          chunkIndex: data.chunkIndex,
+          indexes: data.indexes,
+          datasetId: String(data.datasetId),
+          collectionId: String(data.collectionId),
+          sourceName: collection.name || '',
+          sourceId: collection?.fileId || collection?.rawLink,
+          score: item.score
+        };
+      })
+      .filter((item) => item !== null) as SearchDataResponseItemType[];
+
+    return {
+      embeddingRecallResults: formatResult,
+      tokenLen
+    };
+  };
+  const fullTextRecall = async ({
+    query,
+    limit
+  }: {
+    query: string;
+    limit: number;
+  }): Promise<{
+    fullTextRecallResults: SearchDataResponseItemType[];
+    tokenLen: number;
+  }> => {
+    if (limit === 0) {
+      return {
+        fullTextRecallResults: [],
+        tokenLen: 0
+      };
+    }
+
+    let searchResults = (
+      await Promise.all(
+        datasetIds.map((id) =>
+          MongoDatasetData.find(
+            {
+              datasetId: id,
+              $text: { $search: jiebaSplit({ text: query }) }
+            },
+            {
+              score: { $meta: 'textScore' },
+              _id: 1,
+              datasetId: 1,
+              collectionId: 1,
+              q: 1,
+              a: 1,
+              indexes: 1,
+              chunkIndex: 1
+            }
+          )
+            .sort({ score: { $meta: 'textScore' } })
+            .limit(limit)
+            .lean()
+        )
+      )
+    ).flat() as (DatasetDataSchemaType & { score: number })[];
+
+    // resort
+    searchResults.sort((a, b) => b.score - a.score);
+    searchResults.slice(0, limit);
+
+    const collections = await MongoDatasetCollection.find(
+      {
+        _id: { $in: searchResults.map((item) => item.collectionId) }
+      },
+      '_id name fileId rawLink'
+    );
+
+    return {
+      fullTextRecallResults: searchResults.map((item) => {
+        const collection = collections.find((col) => String(col._id) === String(item.collectionId));
+        return {
+          id: String(item._id),
+          datasetId: String(item.datasetId),
+          collectionId: String(item.collectionId),
+          sourceName: collection?.name || '',
+          sourceId: collection?.fileId || collection?.rawLink,
+          q: item.q,
+          a: item.a,
+          chunkIndex: item.chunkIndex,
+          indexes: item.indexes,
+          // @ts-ignore
+          score: item.score
+        };
+      }),
+      tokenLen: 0
+    };
+  };
+  const reRankSearchResult = async ({
+    data,
+    query
+  }: {
+    data: SearchDataResponseItemType[];
+    query: string;
+  }): Promise<SearchDataResponseItemType[]> => {
+    try {
+      const results = await reRankRecall({
+        query,
+        inputs: data.map((item) => ({
+          id: item.id,
+          text: `${item.q}\n${item.a}`
+        }))
+      });
+
+      if (!Array.isArray(results)) return data;
+
+      // add new score to data
+      const mergeResult = results
+        .map((item) => {
+          const target = data.find((dataItem) => dataItem.id === item.id);
+          if (!target) return null;
+          return {
+            ...target,
+            score: item.score || target.score
+          };
+        })
+        .filter(Boolean) as SearchDataResponseItemType[];
+
+      return mergeResult;
+    } catch (error) {
+      return data;
+    }
+  };
+  const filterResultsByMaxTokens = (list: SearchDataResponseItemType[], maxTokens: number) => {
+    const results: SearchDataResponseItemType[] = [];
+    let totalTokens = 0;
+
+    for (let i = 0; i < list.length; i++) {
+      const item = list[i];
+      totalTokens += countPromptTokens(item.q + item.a);
+      if (totalTokens > maxTokens + 500) {
+        break;
+      }
+      results.push(item);
+      if (totalTokens > maxTokens) {
+        break;
+      }
+    }
+
+    return results.length === 0 ? list.slice(0, 1) : results;
+  };
+  const multiQueryRecall = async ({
+    embeddingLimit,
+    fullTextLimit
+  }: {
+    embeddingLimit: number;
+    fullTextLimit: number;
+  }) => {
+    // In a group of n recalls, a data item is retained as long as it appears at least minAmount times
+    const getIntersection = (resultList: SearchDataResponseItemType[][], minAmount = 1) => {
+      minAmount = Math.min(resultList.length, minAmount);
+
+      const map: Record<
+        string,
+        {
+          amount: number;
+          data: SearchDataResponseItemType;
+        }
+      > = {};
+
+      for (const list of resultList) {
+        for (const item of list) {
+          map[item.id] = map[item.id]
+            ? {
+                amount: map[item.id].amount + 1,
+                data: item
+              }
+            : {
+                amount: 1,
+                data: item
+              };
+        }
+      }
+
+      return Object.values(map)
+        .filter((item) => item.amount >= minAmount)
+        .map((item) => item.data);
+    };
+
+    // multi query recall
+    const embeddingRecallResList: SearchDataResponseItemType[][] = [];
+    const fullTextRecallResList: SearchDataResponseItemType[][] = [];
+    let embTokens = 0;
+    for await (const query of queries) {
+      const [{ tokenLen, embeddingRecallResults }, { fullTextRecallResults }] = await Promise.all([
+        embeddingRecall({
+          query,
+          limit: embeddingLimit
+        }),
+        fullTextRecall({
+          query,
+          limit: fullTextLimit
+        })
+      ]);
+      embTokens += tokenLen;
+
+      embeddingRecallResList.push(embeddingRecallResults);
+      fullTextRecallResList.push(fullTextRecallResults);
+    }
+
+    return {
+      tokens: embTokens,
+      embeddingRecallResults: getIntersection(embeddingRecallResList, 2),
+      fullTextRecallResults: getIntersection(fullTextRecallResList, 2)
+    };
+  };
+
+  /* main step */
+  // count limit
+  const { embeddingLimit, fullTextLimit } = countRecallLimit();
+
+  // recall
+  const { embeddingRecallResults, fullTextRecallResults, tokens } = await multiQueryRecall({
+    embeddingLimit,
+    fullTextLimit
+  });
+
+  // concat recall results
+  set = new Set<string>(embeddingRecallResults.map((item) => item.id));
+  const concatRecallResults = embeddingRecallResults.concat(
+    fullTextRecallResults.filter((item) => !set.has(item.id))
+  );

   // remove same q and a data
   set = new Set<string>();
   const filterSameDataResults = concatRecallResults.filter((item) => {
-    const str = `${item.q}${item.a}`.trim();
+    // strip all punctuation, whitespace, etc. and compare only the text content
+    const str = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
     if (set.has(str)) return false;
     set.add(str);
     return true;
@@ -187,14 +455,14 @@ export async function searchDatasetData(props: SearchProps) {
         filterSameDataResults.filter((item) => item.score >= similarity),
         maxTokens
       ),
-      tokenLen
+      tokenLen: tokens
     };
   }

-  // ReRank result
+  // ReRank results
   const reRankResults = (
     await reRankSearchResult({
-      query: text,
+      query: rawQuery,
       data: filterSameDataResults
     })
   ).filter((item) => item.score > similarity);
@@ -204,210 +472,7 @@ export async function searchDatasetData(props: SearchProps) {
       reRankResults.filter((item) => item.score >= similarity),
       maxTokens
     ),
-    tokenLen
+    tokenLen: tokens
   };
 }
-
-export async function embeddingRecall({
-  text,
-  model,
-  similarity = 0,
-  limit,
-  datasetIds = [],
-  rerank = false
-}: SearchProps & { rerank: boolean }) {
-  const { vectors, tokenLen } = await getVectorsByText({
-    model,
-    input: [text]
-  });
-
-  const results: any = await PgClient.query(
-    `BEGIN;
-      SET LOCAL hnsw.ef_search = ${global.systemEnv.pgHNSWEfSearch || 100};
-      select id, collection_id, data_id, (vector <#> '[${vectors[0]}]') * -1 AS score
-        from ${PgDatasetTableName}
-        where dataset_id IN (${datasetIds.map((id) => `'${String(id)}'`).join(',')})
-            ${rerank ? '' : `AND vector <#> '[${vectors[0]}]' < -${similarity}`}
-        order by score desc limit ${limit};
-      COMMIT;`
-  );
-
-  const rows = results?.[2]?.rows as PgSearchRawType[];
-
-  // concat same data_id
-  const filterRows: PgSearchRawType[] = [];
-  let set = new Set<string>();
-  for (const row of rows) {
-    if (!set.has(row.data_id)) {
-      filterRows.push(row);
-      set.add(row.data_id);
-    }
-  }
-
-  // get q and a
-  const [collections, dataList] = await Promise.all([
-    MongoDatasetCollection.find(
-      {
-        _id: { $in: filterRows.map((item) => item.collection_id) }
-      },
-      'name fileId rawLink'
-    ).lean(),
-    MongoDatasetData.find(
-      {
-        _id: { $in: filterRows.map((item) => item.data_id?.trim()) }
-      },
-      'datasetId collectionId q a chunkIndex indexes'
-    ).lean()
-  ]);
-
-  const formatResult = filterRows
-    .map((item) => {
-      const collection = collections.find(
-        (collection) => String(collection._id) === item.collection_id
-      );
-      const data = dataList.find((data) => String(data._id) === item.data_id);
-
-      // if the collection or data no longer exists, the related mongo data was already deleted
-      if (!collection || !data) return null;
-
-      return {
-        id: String(data._id),
-        q: data.q,
-        a: data.a,
-        chunkIndex: data.chunkIndex,
-        indexes: data.indexes,
-        datasetId: String(data.datasetId),
-        collectionId: String(data.collectionId),
-        sourceName: collection.name || '',
-        sourceId: collection?.fileId || collection?.rawLink,
-        score: item.score
-      };
-    })
-    .filter((item) => item !== null) as SearchDataResponseItemType[];
-
-  return {
-    embeddingRecallResults: formatResult,
-    tokenLen
-  };
-}
-
-export async function fullTextRecall({ text, limit, datasetIds = [] }: SearchProps): Promise<{
-  fullTextRecallResults: SearchDataResponseItemType[];
-  tokenLen: number;
-}> {
-  if (limit === 0) {
-    return {
-      fullTextRecallResults: [],
-      tokenLen: 0
-    };
-  }
-
-  let searchResults = (
-    await Promise.all(
-      datasetIds.map((id) =>
-        MongoDatasetData.find(
-          {
-            datasetId: id,
-            $text: { $search: jiebaSplit({ text }) }
-          },
-          {
-            score: { $meta: 'textScore' },
-            _id: 1,
-            datasetId: 1,
-            collectionId: 1,
-            q: 1,
-            a: 1,
-            indexes: 1,
-            chunkIndex: 1
-          }
-        )
-          .sort({ score: { $meta: 'textScore' } })
-          .limit(limit)
-          .lean()
-      )
-    )
-  ).flat() as (DatasetDataSchemaType & { score: number })[];
-
-  // resort
-  searchResults.sort((a, b) => b.score - a.score);
-  searchResults.slice(0, limit);
-
-  const collections = await MongoDatasetCollection.find(
-    {
-      _id: { $in: searchResults.map((item) => item.collectionId) }
-    },
-    '_id name fileId rawLink'
-  );
-
-  return {
-    fullTextRecallResults: searchResults.map((item) => {
-      const collection = collections.find((col) => String(col._id) === String(item.collectionId));
-      return {
-        id: String(item._id),
-        datasetId: String(item.datasetId),
-        collectionId: String(item.collectionId),
-        sourceName: collection?.name || '',
-        sourceId: collection?.fileId || collection?.rawLink,
-        q: item.q,
-        a: item.a,
-        chunkIndex: item.chunkIndex,
-        indexes: item.indexes,
-        // @ts-ignore
-        score: item.score
-      };
-    }),
-    tokenLen: 0
-  };
-}
-
-// plus reRank search result
-export async function reRankSearchResult({
-  data,
-  query
-}: {
-  data: SearchDataResponseItemType[];
-  query: string;
-}): Promise<SearchDataResponseItemType[]> {
-  try {
-    const results = await reRankRecall({
-      query,
-      inputs: data.map((item) => ({
-        id: item.id,
-        text: `${item.q}\n${item.a}`
-      }))
-    });
-
-    if (!Array.isArray(results)) return data;
-
-    // add new score to data
-    const mergeResult = results
-      .map((item) => {
-        const target = data.find((dataItem) => dataItem.id === item.id);
-        if (!target) return null;
-        return {
-          ...target,
-          score: item.score || target.score
-        };
-      })
-      .filter(Boolean) as SearchDataResponseItemType[];
-
-    return mergeResult;
-  } catch (error) {
-    return data;
-  }
-}
-
-export function filterResultsByMaxTokens(list: SearchDataResponseItemType[], maxTokens: number) {
-  const results: SearchDataResponseItemType[] = [];
-  let totalTokens = 0;
-
-  for (let i = 0; i < list.length; i++) {
-    const item = list[i];
-    totalTokens += countPromptTokens(item.q + item.a);
-    if (totalTokens > maxTokens + 200) {
-      break;
-    }
-    results.push(item);
-    if (totalTokens > maxTokens) {
-      break;
-    }
-  }
-
-  return results;
-}
-
 // ------------------ search end ------------------
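
A few notes on the new search pipeline follow. First, the recall limits: countRecallLimit derives both recall depths from the caller's token budget, assuming roughly 50 tokens per stored chunk, and clamps them so a single query cannot blow up the vector search. A standalone sketch of the same arithmetic, with a worked example:

function countRecallLimit(maxTokens: number) {
  const oneChunkToken = 50; // assumed average tokens per stored chunk
  const estimatedLen = Math.max(20, Math.ceil(maxTokens / oneChunkToken));
  return {
    embeddingLimit: Math.min(80, Math.max(50, estimatedLen * 2)), // clamped to 50..80
    fullTextLimit: Math.min(40, Math.max(20, estimatedLen)) // clamped to 20..40
  };
}

// maxTokens = 1500 (the topk-compat default) → estimatedLen = 30
// → { embeddingLimit: 60, fullTextLimit: 30 }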
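Second, the pgvector scoring. The <#> operator returns the negative inner product, so (vector <#> '[v]') * -1 in the select list recovers the raw inner-product similarity, and the pre-filter

  vector <#> '[v]' < -similarity   ⇔   -(v · d) < -similarity   ⇔   v · d > similarity

is the same threshold expressed on the negated value. When rerank is active, the SQL filter is deliberately omitted so that low-similarity candidates still reach the reranker; the similarity cut is then applied to the rerank scores instead.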
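Third, the multi-query vote. The commit's CoreferenceResolution module (which lives elsewhere in this change set) presumably rewrites the raw chat input into one or more self-contained queries; multiQueryRecall recalls each query separately and keeps only items that at least minAmount = 2 of the per-query result lists agree on. Because minAmount is clamped by Math.min(resultList.length, minAmount), a single-query search keeps everything, preserving the old behavior. A self-contained sketch of the voting step:

// Keep an item if it appears in at least minAmount of the per-query result lists.
function getIntersection<T extends { id: string }>(resultList: T[][], minAmount = 1): T[] {
  minAmount = Math.min(resultList.length, minAmount);
  const counts = new Map<string, { amount: number; data: T }>();
  for (const list of resultList) {
    for (const item of list) {
      const prev = counts.get(item.id);
      counts.set(item.id, { amount: (prev?.amount ?? 0) + 1, data: item });
    }
  }
  return [...counts.values()].filter((v) => v.amount >= minAmount).map((v) => v.data);
}

// Two rewritten queries both recall id "a"; only one recalls id "b":
// getIntersection([[{ id: 'a', score: 0.9 }, { id: 'b', score: 0.7 }], [{ id: 'a', score: 0.8 }]], 2)
// → [{ id: 'a', score: 0.8 }]   (the later occurrence's data wins, as in the commit's version)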
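Fourth, deduplication. Because several rewritten queries can recall near-identical chunks, the merged list is deduplicated by a normalized content key rather than by id: every character that is not a letter or digit in any script (the Unicode-aware /[^\p{L}\p{N}]/gu regex) is stripped from q + a before hashing with hashStr, so two chunks that differ only in punctuation or whitespace collapse into one. A sketch of the key, using Node's crypto as a stand-in for hashStr:

import { createHash } from 'crypto';

// Normalized dedup key: strip everything except letters and digits, then hash.
const dedupKey = (q: string, a = '') =>
  createHash('sha256')
    .update(`${q}${a}`.replace(/[^\p{L}\p{N}]/gu, ''))
    .digest('hex');

// dedupKey('What is FastGPT?') === dedupKey('What is FastGPT ?!')  → true
// dedupKey('什么是 FastGPT')    === dedupKey('什么是FastGPT')       → true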
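Finally, the token budget. The inner filterResultsByMaxTokens walks the ranked list, counting tokens with countPromptTokens, and applies two ceilings: an item that would overshoot the budget by more than 500 tokens is dropped and iteration stops, while an item that merely crosses maxTokens is still kept before stopping; and if even the first item exceeds the hard ceiling, the first result is returned anyway so the search never comes back empty. (The old exported version allowed only 200 tokens of overshoot and could return an empty list.) A standalone sketch under those rules:

// Greedy budget cut with a soft ceiling (maxTokens) and a hard ceiling (maxTokens + 500).
function filterByTokenBudget<T extends { q: string; a: string }>(
  list: T[],
  maxTokens: number,
  countTokens: (text: string) => number // stand-in for countPromptTokens
): T[] {
  const results: T[] = [];
  let total = 0;
  for (const item of list) {
    total += countTokens(item.q + item.a);
    if (total > maxTokens + 500) break; // hard ceiling: drop this item and stop
    results.push(item);
    if (total > maxTokens) break; // soft ceiling: keep this item, then stop
  }
  return results.length === 0 ? list.slice(0, 1) : results;
}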