commit 9ccfda47b7 (parent ccca0468da)
Author: Archer
Date: 2023-12-31 14:12:51 +08:00
Committed by: GitHub
270 changed files with 8182 additions and 1295 deletions

View File

@@ -0,0 +1,13 @@
import { cut } from '@node-rs/jieba';
import { stopWords } from '@fastgpt/global/common/string/jieba';
export function jiebaSplit({ text }: { text: string }) {
const tokens = cut(text, true);
return (
tokens
.map((item) => item.replace(/[^\u4e00-\u9fa5a-zA-Z0-9\s]/g, '').trim())
.filter((item) => item && !stopWords.has(item))
.join(' ') || ''
);
}
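
A quick sketch of what this helper produces (illustrative input; the exact tokens depend on the jieba dictionary):

// cut(text, true) segments the text with HMM enabled; the regex keeps only
// CJK/alphanumeric characters, stop words are dropped, and the surviving
// tokens are space-joined so MongoDB's $text search can index them.
const searchText = jiebaSplit({ text: 'FastGPT 支持全文检索!' });
// searchText ≈ 'FastGPT 支持 全文 检索'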

View File

@@ -24,13 +24,24 @@ export function getAudioSpeechModel(model?: string) {
);
}
export function getWhisperModel(model?: string) {
return global.whisperModel;
}
export function getReRankModel(model?: string) {
return global.reRankModels.find((item) => item.model === model);
}
export enum ModelTypeEnum {
chat = 'chat',
qa = 'qa',
cq = 'cq',
extract = 'extract',
qg = 'qg',
vector = 'vector'
vector = 'vector',
audioSpeech = 'audioSpeech',
whisper = 'whisper',
rerank = 'rerank'
}
export const getModelMap = {
[ModelTypeEnum.chat]: getChatModel,
@@ -38,5 +49,8 @@ export const getModelMap = {
[ModelTypeEnum.cq]: getCQModel,
[ModelTypeEnum.extract]: getExtractModel,
[ModelTypeEnum.qg]: getQGModel,
[ModelTypeEnum.vector]: getVectorModel
[ModelTypeEnum.vector]: getVectorModel,
[ModelTypeEnum.audioSpeech]: getAudioSpeechModel,
[ModelTypeEnum.whisper]: getWhisperModel,
[ModelTypeEnum.rerank]: getReRankModel
};
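
With the three new entries, every model type now resolves through the same map; a minimal lookup sketch (hypothetical model id):

// Returns the matching entry from global.reRankModels, or undefined if unknown.
const reRankModelData = getModelMap[ModelTypeEnum.rerank]('bge-reranker-base');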

View File

@@ -20,8 +20,14 @@ export function reRankRecall({ query, inputs }: PostReRankProps) {
Authorization: `Bearer ${model.requestAuth}`
}
}
).then((data) => {
console.log('rerank time:', Date.now() - start);
return data;
});
)
.then((data) => {
console.log('rerank time:', Date.now() - start);
return data;
})
.catch((err) => {
console.log(err);
return [];
});
}
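
Note the behavior change here: a failed rerank request now resolves to an empty array instead of rejecting, so the downstream search treats a rerank outage as "no rerank signal" and falls back to the plain recall order.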

View File

@@ -1,73 +0,0 @@
import { getAIApi } from '@fastgpt/service/core/ai/config';
export type GetVectorProps = {
model: string;
input: string | string[];
};
// text to vector
export async function getVectorsByText({
model = 'text-embedding-ada-002',
input
}: GetVectorProps) {
try {
if (typeof input === 'string' && !input) {
return Promise.reject({
code: 500,
message: 'input is empty'
});
} else if (Array.isArray(input)) {
for (let i = 0; i < input.length; i++) {
if (!input[i]) {
return Promise.reject({
code: 500,
message: 'input array is empty'
});
}
}
}
// get the AI client (chat API)
const ai = getAIApi();
// convert the input content into vectors
const result = await ai.embeddings
.create({
model,
input
})
.then(async (res) => {
if (!res.data) {
return Promise.reject('Embedding API 404');
}
if (!res?.data?.[0]?.embedding) {
console.log(res?.data);
// @ts-ignore
return Promise.reject(res.data?.err?.message || 'Embedding API Error');
}
return {
tokenLen: res.usage.total_tokens || 0,
vectors: await Promise.all(res.data.map((item) => unityDimensional(item.embedding)))
};
});
return result;
} catch (error) {
console.log(`Embedding Error`, error);
return Promise.reject(error);
}
}
// Normalize embeddings to exactly 1536 dimensions: truncate longer vectors, zero-pad shorter ones.
function unityDimensional(vector: number[]) {
  if (vector.length > 1536) {
    console.log(
      `Vector dimension is ${vector.length}; dimensions above 1536 are not supported, truncated to the first 1536`
    );
    return vector.slice(0, 1536);
  }
  const zeroVector = new Array(1536 - vector.length).fill(0);
  return vector.concat(zeroVector);
}
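
A worked sketch of the (now deleted) dimension normalizer:

unityDimensional(new Array(2000).fill(0.1)).length; // 1536 — truncated, with a warning logged
unityDimensional([0.1, 0.2]).length;                // 1536 — zero-padded on the right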

View File

@@ -4,12 +4,30 @@ import {
PatchIndexesProps,
UpdateDatasetDataProps
} from '@fastgpt/global/core/dataset/controller';
import { deletePgDataById } from '@fastgpt/service/core/dataset/data/pg';
import { insertData2Pg, updatePgDataById } from './pg';
import {
insertDatasetDataVector,
recallFromVectorStore,
updateDatasetDataVector
} from '@fastgpt/service/common/vectorStore/controller';
import { Types } from 'mongoose';
import { DatasetDataIndexTypeEnum } from '@fastgpt/global/core/dataset/constant';
import {
DatasetDataIndexTypeEnum,
DatasetSearchModeEnum,
DatasetSearchModeMap,
SearchScoreTypeEnum
} from '@fastgpt/global/core/dataset/constant';
import { getDefaultIndex } from '@fastgpt/global/core/dataset/utils';
import { jiebaSplit } from '../utils';
import { jiebaSplit } from '@/service/common/string/jieba';
import { deleteDatasetDataVector } from '@fastgpt/service/common/vectorStore/controller';
import { getVectorsByText } from '@fastgpt/service/core/ai/embedding';
import { MongoDatasetCollection } from '@fastgpt/service/core/dataset/collection/schema';
import {
DatasetDataSchemaType,
SearchDataResponseItemType
} from '@fastgpt/global/core/dataset/type';
import { reRankRecall } from '../../ai/rerank';
import { countPromptTokens } from '@fastgpt/global/common/string/tiktoken';
import { hashStr } from '@fastgpt/global/common/string/tools';
/* insert data.
* 1. create data id
@@ -50,17 +68,17 @@ export async function insertData2Dataset({
}))
: [getDefaultIndex({ q, a })];
// insert to pg
// insert to vector store
const result = await Promise.all(
indexes.map((item) =>
insertData2Pg({
mongoDataId: String(id),
input: item.text,
insertDatasetDataVector({
query: item.text,
model,
teamId,
tmbId,
datasetId,
collectionId
collectionId,
dataId: String(id)
})
)
);
@@ -84,7 +102,7 @@ export async function insertData2Dataset({
return {
insertId: _id,
tokenLen: result.reduce((acc, cur) => acc + cur.tokenLen, 0)
tokens: result.reduce((acc, cur) => acc + cur.tokens, 0)
};
}
@@ -172,35 +190,40 @@ export async function updateData2Dataset({
const result = await Promise.all(
patchResult.map(async (item) => {
if (item.type === 'create') {
const result = await insertData2Pg({
mongoDataId: dataId,
input: item.index.text,
const result = await insertDatasetDataVector({
query: item.index.text,
model,
teamId: mongoData.teamId,
tmbId: mongoData.tmbId,
datasetId: mongoData.datasetId,
collectionId: mongoData.collectionId
collectionId: mongoData.collectionId,
dataId
});
item.index.dataId = result.insertId;
return result;
}
if (item.type === 'update' && item.index.dataId) {
return updatePgDataById({
return updateDatasetDataVector({
id: item.index.dataId,
input: item.index.text,
query: item.index.text,
model
});
}
if (item.type === 'delete' && item.index.dataId) {
return deletePgDataById(['id', item.index.dataId]);
await deleteDatasetDataVector({
id: item.index.dataId
});
return {
tokens: 0
};
}
return {
tokenLen: 0
tokens: 0
};
})
);
const tokenLen = result.reduce((acc, cur) => acc + cur.tokenLen, 0);
const tokens = result.reduce((acc, cur) => acc + cur.tokens, 0);
// update mongo
mongoData.q = q || mongoData.q;
@@ -211,6 +234,457 @@ export async function updateData2Dataset({
await mongoData.save();
return {
tokenLen
tokens
};
}
export async function searchDatasetData(props: {
model: string;
similarity?: number; // min distance
limit: number; // max Token limit
datasetIds: string[];
searchMode?: `${DatasetSearchModeEnum}`;
usingReRank?: boolean;
rawQuery: string;
queries: string[];
}) {
let {
rawQuery,
queries,
model,
similarity = 0,
limit: maxTokens,
searchMode = DatasetSearchModeEnum.embedding,
usingReRank = false,
datasetIds = []
} = props;
/* init params */
searchMode = DatasetSearchModeMap[searchMode] ? searchMode : DatasetSearchModeEnum.embedding;
usingReRank = usingReRank && global.reRankModels.length > 0;
// Backward compatibility: treat an old topK-style limit (< 50) as a 1500-token budget
if (maxTokens < 50) {
maxTokens = 1500;
}
let set = new Set<string>();
let usingSimilarityFilter = false;
/* function */
const countRecallLimit = () => {
const oneChunkToken = 50;
const estimatedLen = Math.max(20, Math.ceil(maxTokens / oneChunkToken));
// Increase the search range to reduce HNSW recall loss (20 ~ 100)
if (searchMode === DatasetSearchModeEnum.embedding) {
return {
embeddingLimit: Math.min(estimatedLen, 100),
fullTextLimit: 0
};
}
// full-text only: cap the recall size at 50 candidates
if (searchMode === DatasetSearchModeEnum.fullTextRecall) {
return {
embeddingLimit: 0,
fullTextLimit: Math.min(estimatedLen, 50)
};
}
// mixed recall: embedding capped at 80 candidates, full text capped at 40
return {
embeddingLimit: Math.min(estimatedLen, 80),
fullTextLimit: Math.min(estimatedLen, 40)
};
};
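// Worked example (illustrative): maxTokens = 1500 and oneChunkToken = 50 give
// estimatedLen = 30, so embedding mode recalls min(30, 100) = 30 chunks and
// mixed mode recalls 30 embedding + 30 full-text candidates per query.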
const embeddingRecall = async ({ query, limit }: { query: string; limit: number }) => {
const { vectors, tokens } = await getVectorsByText({
model,
input: [query]
});
const { results } = await recallFromVectorStore({
vectors,
limit,
datasetIds
});
// get q and a
const [collections, dataList] = await Promise.all([
MongoDatasetCollection.find(
{
_id: { $in: results.map((item) => item.collectionId) }
},
'name fileId rawLink'
).lean(),
MongoDatasetData.find(
{
_id: { $in: results.map((item) => item.dataId?.trim()) }
},
'datasetId collectionId q a chunkIndex indexes'
).lean()
]);
const formatResult = results
.map((item, index) => {
const collection = collections.find(
(collection) => String(collection._id) === item.collectionId
);
const data = dataList.find((data) => String(data._id) === item.dataId);
// if the collection or data no longer exists, the related mongo record was already deleted
if (!collection || !data) return null;
const result: SearchDataResponseItemType = {
id: String(data._id),
q: data.q,
a: data.a,
chunkIndex: data.chunkIndex,
indexes: data.indexes,
datasetId: String(data.datasetId),
collectionId: String(data.collectionId),
sourceName: collection.name || '',
sourceId: collection?.fileId || collection?.rawLink,
score: [{ type: SearchScoreTypeEnum.embedding, value: item.score, index }]
};
return result;
})
.filter((item) => item !== null) as SearchDataResponseItemType[];
return {
embeddingRecallResults: formatResult,
tokens
};
};
const fullTextRecall = async ({
query,
limit
}: {
query: string;
limit: number;
}): Promise<{
fullTextRecallResults: SearchDataResponseItemType[];
tokenLen: number;
}> => {
if (limit === 0) {
return {
fullTextRecallResults: [],
tokenLen: 0
};
}
let searchResults = (
await Promise.all(
datasetIds.map((id) =>
MongoDatasetData.find(
{
datasetId: id,
$text: { $search: jiebaSplit({ text: query }) }
},
{
score: { $meta: 'textScore' },
_id: 1,
datasetId: 1,
collectionId: 1,
q: 1,
a: 1,
indexes: 1,
chunkIndex: 1
}
)
.sort({ score: { $meta: 'textScore' } })
.limit(limit)
.lean()
)
)
).flat() as (DatasetDataSchemaType & { score: number })[];
// resort
searchResults.sort((a, b) => b.score - a.score);
searchResults = searchResults.slice(0, limit); // slice() returns a new array; reassign so the truncation actually applies
const collections = await MongoDatasetCollection.find(
{
_id: { $in: searchResults.map((item) => item.collectionId) }
},
'_id name fileId rawLink'
);
return {
fullTextRecallResults: searchResults.map((item, index) => {
const collection = collections.find((col) => String(col._id) === String(item.collectionId));
return {
id: String(item._id),
datasetId: String(item.datasetId),
collectionId: String(item.collectionId),
sourceName: collection?.name || '',
sourceId: collection?.fileId || collection?.rawLink,
q: item.q,
a: item.a,
chunkIndex: item.chunkIndex,
indexes: item.indexes,
score: [{ type: SearchScoreTypeEnum.fullText, value: item.score, index }]
};
}),
tokenLen: 0
};
};
const reRankSearchResult = async ({
data,
query
}: {
data: SearchDataResponseItemType[];
query: string;
}): Promise<SearchDataResponseItemType[]> => {
try {
const results = await reRankRecall({
query,
inputs: data.map((item) => ({
id: item.id,
text: `${item.q}\n${item.a}`
}))
});
if (!Array.isArray(results)) return [];
// add new score to data
const mergeResult = results
.map((item, index) => {
const target = data.find((dataItem) => dataItem.id === item.id);
if (!target) return null;
const score = item.score || 0;
return {
...target,
score: [{ type: SearchScoreTypeEnum.reRank, value: score, index }]
};
})
.filter(Boolean) as SearchDataResponseItemType[];
return mergeResult;
} catch (error) {
return [];
}
};
const filterResultsByMaxTokens = (list: SearchDataResponseItemType[], maxTokens: number) => {
const results: SearchDataResponseItemType[] = [];
let totalTokens = 0;
for (let i = 0; i < list.length; i++) {
const item = list[i];
totalTokens += countPromptTokens(item.q + item.a);
if (totalTokens > maxTokens + 500) {
break;
}
results.push(item);
if (totalTokens > maxTokens) {
break;
}
}
return results.length === 0 ? list.slice(0, 1) : results;
};
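// Example: with maxTokens = 1500, results accumulate until the running token
// count exceeds 1500 (hard stop past 2000); if even the first item is over
// budget, that single item is still returned so the search never comes back empty.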
const multiQueryRecall = async ({
embeddingLimit,
fullTextLimit
}: {
embeddingLimit: number;
fullTextLimit: number;
}) => {
// Across the n recall result lists, keep an item if it appears at least minAmount times
const getIntersection = (resultList: SearchDataResponseItemType[][], minAmount = 1) => {
minAmount = Math.min(resultList.length, minAmount);
const map: Record<
string,
{
amount: number;
data: SearchDataResponseItemType;
}
> = {};
for (const list of resultList) {
for (const item of list) {
map[item.id] = map[item.id]
? {
amount: map[item.id].amount + 1,
data: item
}
: {
amount: 1,
data: item
};
}
}
return Object.values(map)
.filter((item) => item.amount >= minAmount)
.map((item) => item.data);
};
// multi query recall
const embeddingRecallResList: SearchDataResponseItemType[][] = [];
const fullTextRecallResList: SearchDataResponseItemType[][] = [];
let embTokens = 0;
for await (const query of queries) {
const [{ tokens, embeddingRecallResults }, { fullTextRecallResults }] = await Promise.all([
embeddingRecall({
query,
limit: embeddingLimit
}),
fullTextRecall({
query,
limit: fullTextLimit
})
]);
embTokens += tokens;
embeddingRecallResList.push(embeddingRecallResults);
fullTextRecallResList.push(fullTextRecallResults);
}
return {
tokens: embTokens,
embeddingRecallResults: embeddingRecallResList[0],
fullTextRecallResults: fullTextRecallResList[0]
};
};
const rrfConcat = (
arr: { k: number; list: SearchDataResponseItemType[] }[]
): SearchDataResponseItemType[] => {
const map = new Map<string, SearchDataResponseItemType & { rrfScore: number }>();
// rrf
arr.forEach((item) => {
const k = item.k;
item.list.forEach((data, index) => {
const rank = index + 1;
const score = 1 / (k + rank);
const record = map.get(data.id);
if (record) {
// merge the two score lists; for scores of the same type, keep the larger value
const concatScore = [...record.score];
for (const dataItem of data.score) {
const sameScore = concatScore.find((item) => item.type === dataItem.type);
if (sameScore) {
sameScore.value = Math.max(sameScore.value, dataItem.value);
} else {
concatScore.push(dataItem);
}
}
map.set(data.id, {
...record,
score: concatScore,
rrfScore: record.rrfScore + score
});
} else {
map.set(data.id, {
...data,
rrfScore: score
});
}
});
});
// sort
const mapArray = Array.from(map.values());
const results = mapArray.sort((a, b) => b.rrfScore - a.rrfScore);
return results.map((item, index) => {
item.score.push({
type: SearchScoreTypeEnum.rrf,
value: item.rrfScore,
index
});
// @ts-ignore
delete item.rrfScore;
return item;
});
};
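// RRF sketch: with k = 60, a chunk ranked 1st by embedding and 3rd by full text
// gets rrfScore = 1/(60 + 1) + 1/(60 + 3) ≈ 0.0323; when the same id appears in
// several lists, its per-type scores are merged keeping the maximum value.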
/* main step */
// count limit
const { embeddingLimit, fullTextLimit } = countRecallLimit();
// recall
const { embeddingRecallResults, fullTextRecallResults, tokens } = await multiQueryRecall({
embeddingLimit,
fullTextLimit
});
// ReRank results
const reRankResults = await (async () => {
if (!usingReRank) return [];
set = new Set<string>(embeddingRecallResults.map((item) => item.id));
const concatRecallResults = embeddingRecallResults.concat(
fullTextRecallResults.filter((item) => !set.has(item.id))
);
// drop duplicates that share the same q and a
set = new Set<string>();
const filterSameDataResults = concatRecallResults.filter((item) => {
// strip all punctuation and whitespace; compare on the text content only
const str = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
if (set.has(str)) return false;
set.add(str);
return true;
});
return reRankSearchResult({
query: rawQuery,
data: filterSameDataResults
});
})();
// embedding recall and fullText recall rrf concat
const rrfConcatResults = rrfConcat([
{ k: 60, list: embeddingRecallResults },
{ k: 60, list: fullTextRecallResults },
{ k: 60, list: reRankResults }
]);
// drop duplicates that share the same q and a
set = new Set<string>();
const filterSameDataResults = rrfConcatResults.filter((item) => {
// strip all punctuation and whitespace; compare on the text content only
const str = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
if (set.has(str)) return false;
set.add(str);
return true;
});
// score filter
const scoreFilter = (() => {
if (usingReRank) {
usingSimilarityFilter = true;
return filterSameDataResults.filter((item) => {
const reRankScore = item.score.find((item) => item.type === SearchScoreTypeEnum.reRank);
if (reRankScore && reRankScore.value < similarity) return false;
return true;
});
}
if (searchMode === DatasetSearchModeEnum.embedding) {
return filterSameDataResults.filter((item) => {
usingSimilarityFilter = true;
const embeddingScore = item.score.find(
(item) => item.type === SearchScoreTypeEnum.embedding
);
if (embeddingScore && embeddingScore.value < similarity) return false;
return true;
});
}
return filterSameDataResults;
})();
return {
searchRes: filterResultsByMaxTokens(scoreFilter, maxTokens),
tokens,
usingSimilarityFilter
};
}
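
A minimal call sketch (hypothetical query and dataset id; assumes the global model lists are initialized):

const { searchRes, tokens, usingSimilarityFilter } = await searchDatasetData({
  model: 'text-embedding-ada-002',
  similarity: 0.5,
  limit: 1500, // a token budget, not a topK
  datasetIds: ['<datasetId>'],
  searchMode: DatasetSearchModeEnum.embedding,
  usingReRank: false,
  rawQuery: 'how to deploy',
  queries: ['how to deploy']
});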

View File

@@ -1,478 +0,0 @@
import { DatasetSearchModeEnum, PgDatasetTableName } from '@fastgpt/global/core/dataset/constant';
import type {
DatasetDataSchemaType,
SearchDataResponseItemType
} from '@fastgpt/global/core/dataset/type.d';
import { PgClient } from '@fastgpt/service/common/pg';
import { getVectorsByText } from '@/service/core/ai/vector';
import { delay } from '@fastgpt/global/common/system/utils';
import { PgSearchRawType } from '@fastgpt/global/core/dataset/api';
import { MongoDatasetCollection } from '@fastgpt/service/core/dataset/collection/schema';
import { MongoDatasetData } from '@fastgpt/service/core/dataset/data/schema';
import { jiebaSplit } from '../utils';
import { reRankRecall } from '../../ai/rerank';
import { countPromptTokens } from '@fastgpt/global/common/string/tiktoken';
import { hashStr } from '@fastgpt/global/common/string/tools';
export async function insertData2Pg(props: {
mongoDataId: string;
input: string;
model: string;
teamId: string;
tmbId: string;
datasetId: string;
collectionId: string;
retry?: number;
}): Promise<{ insertId: string; vectors: number[][]; tokenLen: number }> {
const { mongoDataId, input, model, teamId, tmbId, datasetId, collectionId, retry = 3 } = props;
try {
// get vector
const { vectors, tokenLen } = await getVectorsByText({
model,
input: [input]
});
const { rows } = await PgClient.insert(PgDatasetTableName, {
values: [
[
{ key: 'vector', value: `[${vectors[0]}]` },
{ key: 'team_id', value: String(teamId) },
{ key: 'tmb_id', value: String(tmbId) },
{ key: 'dataset_id', value: datasetId },
{ key: 'collection_id', value: collectionId },
{ key: 'data_id', value: String(mongoDataId) }
]
]
});
return {
insertId: rows[0].id,
vectors,
tokenLen
};
} catch (error) {
if (retry <= 0) {
return Promise.reject(error);
}
await delay(500);
return insertData2Pg({
...props,
retry: retry - 1
});
}
}
export async function updatePgDataById({
id,
input,
model
}: {
id: string;
input: string;
model: string;
}) {
let retry = 2;
async function updatePg(): Promise<{ vectors: number[][]; tokenLen: number }> {
try {
// get vector
const { vectors, tokenLen } = await getVectorsByText({
model,
input: [input]
});
// update pg
await PgClient.update(PgDatasetTableName, {
where: [['id', id]],
values: [{ key: 'vector', value: `[${vectors[0]}]` }]
});
return {
vectors,
tokenLen
};
} catch (error) {
if (--retry < 0) {
return Promise.reject(error);
}
await delay(500);
return updatePg();
}
}
return updatePg();
}
// ------------------ search start ------------------
type SearchProps = {
model: string;
similarity?: number; // min distance
limit: number; // max Token limit
datasetIds: string[];
searchMode?: `${DatasetSearchModeEnum}`;
};
export async function searchDatasetData(
props: SearchProps & { rawQuery: string; queries: string[] }
) {
let {
rawQuery,
queries,
model,
similarity = 0,
limit: maxTokens,
searchMode = DatasetSearchModeEnum.embedding,
datasetIds = []
} = props;
/* init params */
searchMode = global.systemEnv?.pluginBaseUrl ? searchMode : DatasetSearchModeEnum.embedding;
// Backward compatibility: treat an old topK-style limit (< 50) as a 1500-token budget
if (maxTokens < 50) {
maxTokens = 1500;
}
const rerank =
global.reRankModels?.[0] &&
(searchMode === DatasetSearchModeEnum.embeddingReRank ||
searchMode === DatasetSearchModeEnum.embFullTextReRank);
let set = new Set<string>();
/* function */
const countRecallLimit = () => {
const oneChunkToken = 50;
const estimatedLen = Math.max(20, Math.ceil(maxTokens / oneChunkToken));
// Increase the search range to reduce HNSW recall loss (20 ~ 100)
if (searchMode === DatasetSearchModeEnum.embedding) {
return {
embeddingLimit: Math.min(estimatedLen, 100),
fullTextLimit: 0
};
}
// 50 < 2*limit < value < 100
if (searchMode === DatasetSearchModeEnum.embeddingReRank) {
return {
embeddingLimit: Math.min(100, Math.max(50, estimatedLen * 2)),
fullTextLimit: 0
};
}
// 50 < 2*limit < embedding < 80
// 20 < limit < fullTextLimit < 40
return {
embeddingLimit: Math.min(80, Math.max(50, estimatedLen * 2)),
fullTextLimit: Math.min(40, Math.max(20, estimatedLen))
};
};
const embeddingRecall = async ({ query, limit }: { query: string; limit: number }) => {
const { vectors, tokenLen } = await getVectorsByText({
model,
input: [query]
});
const results: any = await PgClient.query(
`BEGIN;
SET LOCAL hnsw.ef_search = ${global.systemEnv.pgHNSWEfSearch || 100};
select id, collection_id, data_id, (vector <#> '[${vectors[0]}]') * -1 AS score
from ${PgDatasetTableName}
where dataset_id IN (${datasetIds.map((id) => `'${String(id)}'`).join(',')})
${rerank ? '' : `AND vector <#> '[${vectors[0]}]' < -${similarity}`}
order by score desc limit ${limit};
COMMIT;`
);
const rows = results?.[2]?.rows as PgSearchRawType[];
// keep only the first row per data_id
const filterRows: PgSearchRawType[] = [];
let set = new Set<string>();
for (const row of rows) {
if (!set.has(row.data_id)) {
filterRows.push(row);
set.add(row.data_id);
}
}
// get q and a
const [collections, dataList] = await Promise.all([
MongoDatasetCollection.find(
{
_id: { $in: filterRows.map((item) => item.collection_id) }
},
'name fileId rawLink'
).lean(),
MongoDatasetData.find(
{
_id: { $in: filterRows.map((item) => item.data_id?.trim()) }
},
'datasetId collectionId q a chunkIndex indexes'
).lean()
]);
const formatResult = filterRows
.map((item) => {
const collection = collections.find(
(collection) => String(collection._id) === item.collection_id
);
const data = dataList.find((data) => String(data._id) === item.data_id);
// if the collection or data no longer exists, the related mongo record was already deleted
if (!collection || !data) return null;
return {
id: String(data._id),
q: data.q,
a: data.a,
chunkIndex: data.chunkIndex,
indexes: data.indexes,
datasetId: String(data.datasetId),
collectionId: String(data.collectionId),
sourceName: collection.name || '',
sourceId: collection?.fileId || collection?.rawLink,
score: item.score
};
})
.filter((item) => item !== null) as SearchDataResponseItemType[];
return {
embeddingRecallResults: formatResult,
tokenLen
};
};
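// Note: pgvector's `<#>` operator returns the *negative* inner product (so that
// ascending order ranks most-similar first); multiplying by -1 restores a
// higher-is-better score before comparing against `similarity`.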
const fullTextRecall = async ({
query,
limit
}: {
query: string;
limit: number;
}): Promise<{
fullTextRecallResults: SearchDataResponseItemType[];
tokenLen: number;
}> => {
if (limit === 0) {
return {
fullTextRecallResults: [],
tokenLen: 0
};
}
let searchResults = (
await Promise.all(
datasetIds.map((id) =>
MongoDatasetData.find(
{
datasetId: id,
$text: { $search: jiebaSplit({ text: query }) }
},
{
score: { $meta: 'textScore' },
_id: 1,
datasetId: 1,
collectionId: 1,
q: 1,
a: 1,
indexes: 1,
chunkIndex: 1
}
)
.sort({ score: { $meta: 'textScore' } })
.limit(limit)
.lean()
)
)
).flat() as (DatasetDataSchemaType & { score: number })[];
// resort
searchResults.sort((a, b) => b.score - a.score);
searchResults = searchResults.slice(0, limit); // slice() returns a new array; reassign so the truncation actually applies
const collections = await MongoDatasetCollection.find(
{
_id: { $in: searchResults.map((item) => item.collectionId) }
},
'_id name fileId rawLink'
);
return {
fullTextRecallResults: searchResults.map((item) => {
const collection = collections.find((col) => String(col._id) === String(item.collectionId));
return {
id: String(item._id),
datasetId: String(item.datasetId),
collectionId: String(item.collectionId),
sourceName: collection?.name || '',
sourceId: collection?.fileId || collection?.rawLink,
q: item.q,
a: item.a,
chunkIndex: item.chunkIndex,
indexes: item.indexes,
// @ts-ignore
score: item.score
};
}),
tokenLen: 0
};
};
const reRankSearchResult = async ({
data,
query
}: {
data: SearchDataResponseItemType[];
query: string;
}): Promise<SearchDataResponseItemType[]> => {
try {
const results = await reRankRecall({
query,
inputs: data.map((item) => ({
id: item.id,
text: `${item.q}\n${item.a}`
}))
});
if (!Array.isArray(results)) return data;
// add new score to data
const mergeResult = results
.map((item) => {
const target = data.find((dataItem) => dataItem.id === item.id);
if (!target) return null;
return {
...target,
score: item.score || target.score
};
})
.filter(Boolean) as SearchDataResponseItemType[];
return mergeResult;
} catch (error) {
return data;
}
};
const filterResultsByMaxTokens = (list: SearchDataResponseItemType[], maxTokens: number) => {
const results: SearchDataResponseItemType[] = [];
let totalTokens = 0;
for (let i = 0; i < list.length; i++) {
const item = list[i];
totalTokens += countPromptTokens(item.q + item.a);
if (totalTokens > maxTokens + 500) {
break;
}
results.push(item);
if (totalTokens > maxTokens) {
break;
}
}
return results.length === 0 ? list.slice(0, 1) : results;
};
const multiQueryRecall = async ({
embeddingLimit,
fullTextLimit
}: {
embeddingLimit: number;
fullTextLimit: number;
}) => {
// Across the n recall result lists, keep an item if it appears at least minAmount times
const getIntersection = (resultList: SearchDataResponseItemType[][], minAmount = 1) => {
minAmount = Math.min(resultList.length, minAmount);
const map: Record<
string,
{
amount: number;
data: SearchDataResponseItemType;
}
> = {};
for (const list of resultList) {
for (const item of list) {
map[item.id] = map[item.id]
? {
amount: map[item.id].amount + 1,
data: item
}
: {
amount: 1,
data: item
};
}
}
return Object.values(map)
.filter((item) => item.amount >= minAmount)
.map((item) => item.data);
};
// multi query recall
const embeddingRecallResList: SearchDataResponseItemType[][] = [];
const fullTextRecallResList: SearchDataResponseItemType[][] = [];
let embTokens = 0;
for await (const query of queries) {
const [{ tokenLen, embeddingRecallResults }, { fullTextRecallResults }] = await Promise.all([
embeddingRecall({
query,
limit: embeddingLimit
}),
fullTextRecall({
query,
limit: fullTextLimit
})
]);
embTokens += tokenLen;
embeddingRecallResList.push(embeddingRecallResults);
fullTextRecallResList.push(fullTextRecallResults);
}
return {
tokens: embTokens,
embeddingRecallResults: getIntersection(embeddingRecallResList, 2),
fullTextRecallResults: getIntersection(fullTextRecallResList, 2)
};
};
/* main step */
// count limit
const { embeddingLimit, fullTextLimit } = countRecallLimit();
// recall
const { embeddingRecallResults, fullTextRecallResults, tokens } = await multiQueryRecall({
embeddingLimit,
fullTextLimit
});
// concat recall results
set = new Set<string>(embeddingRecallResults.map((item) => item.id));
const concatRecallResults = embeddingRecallResults.concat(
fullTextRecallResults.filter((item) => !set.has(item.id))
);
// drop duplicates that share the same q and a
set = new Set<string>();
const filterSameDataResults = concatRecallResults.filter((item) => {
// strip all punctuation and whitespace; compare on the text content only
const str = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
if (set.has(str)) return false;
set.add(str);
return true;
});
if (!rerank) {
return {
searchRes: filterResultsByMaxTokens(
filterSameDataResults.filter((item) => item.score >= similarity),
maxTokens
),
tokenLen: tokens
};
}
// ReRank results
const reRankResults = (
await reRankSearchResult({
query: rawQuery,
data: filterSameDataResults
})
).filter((item) => item.score > similarity);
return {
searchRes: filterResultsByMaxTokens(
reRankResults.filter((item) => item.score >= similarity),
maxTokens
),
tokenLen: tokens
};
}
// ------------------ search end ------------------

View File

@@ -1,6 +1,4 @@
import { MongoDatasetData } from '@fastgpt/service/core/dataset/data/schema';
import { cut } from '@node-rs/jieba';
import { stopWords } from '@fastgpt/global/common/string/jieba';
/**
* Same value judgment
@@ -24,14 +22,3 @@ export async function hasSameValue({
return Promise.reject('已经存在完全一致的数据'); // "identical data already exists"
}
}
export function jiebaSplit({ text }: { text: string }) {
const tokens = cut(text, true);
return (
tokens
.map((item) => item.replace(/[^\u4e00-\u9fa5a-zA-Z0-9\s]/g, '').trim())
.filter((item) => item && !stopWords.has(item))
.join(' ') || ''
);
}

View File

@@ -136,7 +136,6 @@ ${replaceVariable(Prompt_AgentQA.fixedText, { text })}`;
stream: false
});
const answer = chatResponse.choices?.[0].message?.content || '';
const totalTokens = chatResponse.usage?.total_tokens || 0;
const qaArr = formatSplitText(answer, text); // the formatted QA pairs
@@ -167,7 +166,8 @@ ${replaceVariable(Prompt_AgentQA.fixedText, { text })}`;
pushQABill({
teamId: data.teamId,
tmbId: data.tmbId,
totalTokens,
inputTokens: chatResponse.usage?.prompt_tokens || 0,
outputTokens: chatResponse.usage?.completion_tokens || 0,
billId: data.billId,
model
});

View File

@@ -129,7 +129,7 @@ export async function generateVector(): Promise<any> {
}
// insert data to pg
const { tokenLen } = await insertData2Dataset({
const { tokens } = await insertData2Dataset({
teamId: data.teamId,
tmbId: data.tmbId,
datasetId: data.datasetId,
@@ -145,7 +145,7 @@ export async function generateVector(): Promise<any> {
pushGenerateVectorBill({
teamId: data.teamId,
tmbId: data.tmbId,
tokenLen: tokenLen,
tokens,
model: data.model,
billId: data.billId
});

View File

@@ -9,8 +9,9 @@ import type { ModuleDispatchProps } from '@fastgpt/global/core/module/type.d';
import { replaceVariable } from '@fastgpt/global/common/string/tools';
import { Prompt_CQJson } from '@/global/core/prompt/agent';
import { FunctionModelItemType } from '@fastgpt/global/core/ai/model.d';
import { getCQModel } from '@/service/core/ai/model';
import { ModelTypeEnum, getCQModel } from '@/service/core/ai/model';
import { getHistories } from '../utils';
import { formatModelPrice2Store } from '@/service/support/wallet/bill/utils';
type Props = ModuleDispatchProps<{
[ModuleInputKeyEnum.aiModel]: string;
@@ -42,7 +43,7 @@ export const dispatchClassifyQuestion = async (props: Props): Promise<CQResponse
const chatHistories = getHistories(history, histories);
const { arg, tokens } = await (async () => {
const { arg, inputTokens, outputTokens } = await (async () => {
if (cqModel.toolChoice) {
return toolChoice({
...props,
@@ -59,13 +60,21 @@ export const dispatchClassifyQuestion = async (props: Props): Promise<CQResponse
const result = agents.find((item) => item.key === arg?.type) || agents[agents.length - 1];
const { total, modelName } = formatModelPrice2Store({
model: cqModel.model,
inputLen: inputTokens,
outputLen: outputTokens,
type: ModelTypeEnum.cq
});
return {
[result.key]: result.value,
[ModuleOutputKeyEnum.responseData]: {
price: user.openaiAccount?.key ? 0 : cqModel.price * tokens,
model: cqModel.name || '',
price: user.openaiAccount?.key ? 0 : total,
model: modelName,
query: userChatInput,
tokens,
inputTokens,
outputTokens,
cqList: agents,
cqResult: result.value,
contextTotalLen: chatHistories.length + 2
@@ -140,7 +149,8 @@ ${systemPrompt}
return {
arg,
tokens: response.usage?.total_tokens || 0
inputTokens: response.usage?.prompt_tokens || 0,
outputTokens: response.usage?.completion_tokens || 0
};
} catch (error) {
console.log(agentFunction.parameters);
@@ -150,7 +160,8 @@ ${systemPrompt}
return {
arg: {},
tokens: 0
inputTokens: 0,
outputTokens: 0
};
}
}
@@ -182,12 +193,12 @@ Human:${userChatInput}`
stream: false
});
const answer = data.choices?.[0].message?.content || '';
const totalTokens = data.usage?.total_tokens || 0;
const id = agents.find((item) => answer.includes(item.key))?.key || '';
return {
tokens: totalTokens,
inputTokens: data.usage?.prompt_tokens || 0,
outputTokens: data.usage?.completion_tokens || 0,
arg: { type: id }
};
}
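// Example (illustrative): with agents [{ key: 'a' }, { key: 'b' }] and
// answer 'Type: b', id resolves to 'b'; if no key substring matches, id is ''
// and the caller falls back to the last agent in the list.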

View File

@@ -10,7 +10,8 @@ import { Prompt_ExtractJson } from '@/global/core/prompt/agent';
import { replaceVariable } from '@fastgpt/global/common/string/tools';
import { FunctionModelItemType } from '@fastgpt/global/core/ai/model.d';
import { getHistories } from '../utils';
import { getExtractModel } from '@/service/core/ai/model';
import { ModelTypeEnum, getExtractModel } from '@/service/core/ai/model';
import { formatModelPrice2Store } from '@/service/support/wallet/bill/utils';
type Props = ModuleDispatchProps<{
[ModuleInputKeyEnum.history]?: ChatItemType[];
@@ -42,7 +43,7 @@ export async function dispatchContentExtract(props: Props): Promise<Response> {
const extractModel = getExtractModel(model);
const chatHistories = getHistories(history, histories);
const { arg, tokens } = await (async () => {
const { arg, inputTokens, outputTokens } = await (async () => {
if (extractModel.toolChoice) {
return toolChoice({
...props,
@@ -79,16 +80,24 @@ export async function dispatchContentExtract(props: Props): Promise<Response> {
}
}
const { total, modelName } = formatModelPrice2Store({
model: extractModel.model,
inputLen: inputTokens,
outputLen: outputTokens,
type: ModelTypeEnum.extract
});
return {
[ModuleOutputKeyEnum.success]: success ? true : undefined,
[ModuleOutputKeyEnum.failed]: success ? undefined : true,
[ModuleOutputKeyEnum.contextExtractFields]: JSON.stringify(arg),
...arg,
[ModuleOutputKeyEnum.responseData]: {
price: user.openaiAccount?.key ? 0 : extractModel.price * tokens,
model: extractModel.name || '',
price: user.openaiAccount?.key ? 0 : total,
model: modelName,
query: content,
tokens,
inputTokens,
outputTokens,
extractDescription: description,
extractResult: arg,
contextTotalLen: chatHistories.length + 2
@@ -181,10 +190,10 @@ ${description || '根据用户要求获取适当的 JSON 字符串。'}
}
})();
const tokens = response.usage?.total_tokens || 0;
return {
rawResponse: response?.choices?.[0]?.message?.tool_calls?.[0]?.function?.arguments || '',
tokens,
inputTokens: response.usage?.prompt_tokens || 0,
outputTokens: response.usage?.completion_tokens || 0,
arg
};
}
@@ -223,7 +232,8 @@ Human: ${content}`
stream: false
});
const answer = data.choices?.[0].message?.content || '';
const totalTokens = data.usage?.total_tokens || 0;
const inputTokens = data.usage?.prompt_tokens || 0;
const outputTokens = data.usage?.completion_tokens || 0;
// parse response
const start = answer.indexOf('{');
@@ -232,7 +242,8 @@ Human: ${content}`
if (start === -1 || end === -1)
return {
rawResponse: answer,
tokens: totalTokens,
inputTokens,
outputTokens,
arg: {}
};
@@ -244,13 +255,15 @@ Human: ${content}`
try {
return {
rawResponse: answer,
tokens: totalTokens,
inputTokens,
outputTokens,
arg: JSON.parse(jsonStr) as Record<string, any>
};
} catch (error) {
return {
rawResponse: answer,
tokens: totalTokens,
inputTokens,
outputTokens,
arg: {}
};
}
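// Example (illustrative): answer = 'Sure: {"name": "Tom"} hope that helps'
// yields jsonStr = '{"name": "Tom"}' and arg = { name: 'Tom' }; if no '{'/'}'
// pair is found, the raw answer is returned with arg = {}.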

View File

@@ -6,7 +6,7 @@ import { sseResponseEventEnum } from '@fastgpt/service/common/response/constant'
import { textAdaptGptResponse } from '@/utils/adapt';
import { getAIApi } from '@fastgpt/service/core/ai/config';
import type { ChatCompletion, StreamChatType } from '@fastgpt/global/core/ai/type.d';
import { countModelPrice } from '@/service/support/wallet/bill/utils';
import { formatModelPrice2Store } from '@/service/support/wallet/bill/utils';
import type { ChatModelItemType } from '@fastgpt/global/core/ai/model.d';
import { postTextCensor } from '@/service/common/censor';
import { ChatCompletionRequestMessageRoleEnum } from '@fastgpt/global/core/ai/constant';
@@ -151,7 +151,7 @@ export const dispatchChatCompletion = async (props: ChatProps): Promise<ChatResp
}
);
const { answerText, totalTokens, completeMessages } = await (async () => {
const { answerText, inputTokens, outputTokens, completeMessages } = await (async () => {
if (stream) {
// sse response
const { answer } = await streamResponse({
@@ -165,21 +165,26 @@ export const dispatchChatCompletion = async (props: ChatProps): Promise<ChatResp
value: answer
});
const totalTokens = countMessagesTokens({
messages: completeMessages
});
targetResponse({ res, detail, outputs });
return {
answerText: answer,
totalTokens,
inputTokens: countMessagesTokens({
messages: filterMessages
}),
outputTokens: countMessagesTokens({
messages: [
{
obj: ChatRoleEnum.AI,
value: answer
}
]
}),
completeMessages
};
} else {
const unStreamResponse = response as ChatCompletion;
const answer = unStreamResponse.choices?.[0]?.message?.content || '';
const totalTokens = unStreamResponse.usage?.total_tokens || 0;
const completeMessages = filterMessages.concat({
obj: ChatRoleEnum.AI,
@@ -188,20 +193,27 @@ export const dispatchChatCompletion = async (props: ChatProps): Promise<ChatResp
return {
answerText: answer,
totalTokens,
inputTokens: unStreamResponse.usage?.prompt_tokens || 0,
outputTokens: unStreamResponse.usage?.completion_tokens || 0,
completeMessages
};
}
})();
const { total, modelName } = formatModelPrice2Store({
model,
inputLen: inputTokens,
outputLen: outputTokens,
type: ModelTypeEnum.chat
});
return {
answerText,
responseData: {
price: user.openaiAccount?.key
? 0
: countModelPrice({ model, tokens: totalTokens, type: ModelTypeEnum.chat }),
model: modelConstantsData.name,
tokens: totalTokens,
price: user.openaiAccount?.key ? 0 : total,
model: modelName,
inputTokens,
outputTokens,
query: userChatInput,
maxToken: max_tokens,
quoteList: filterQuoteQA,
@@ -227,8 +239,7 @@ function filterQuote({
a: item.a,
source: item.sourceName,
sourceId: String(item.sourceId || 'UnKnow'),
index: index + 1,
score: item.score?.toFixed(4)
index: index + 1
});
}

View File

@@ -1,13 +1,12 @@
import type { moduleDispatchResType } from '@fastgpt/global/core/chat/type.d';
import { countModelPrice } from '@/service/support/wallet/bill/utils';
import { formatModelPrice2Store } from '@/service/support/wallet/bill/utils';
import type { SelectedDatasetType } from '@fastgpt/global/core/module/api.d';
import type { SearchDataResponseItemType } from '@fastgpt/global/core/dataset/type';
import type { ModuleDispatchProps } from '@fastgpt/global/core/module/type.d';
import { ModelTypeEnum } from '@/service/core/ai/model';
import { searchDatasetData } from '@/service/core/dataset/data/pg';
import { searchDatasetData } from '@/service/core/dataset/data/controller';
import { ModuleInputKeyEnum, ModuleOutputKeyEnum } from '@fastgpt/global/core/module/constants';
import { DatasetSearchModeEnum } from '@fastgpt/global/core/dataset/constant';
import { searchQueryExtension } from '@fastgpt/service/core/ai/functions/queryExtension';
type DatasetSearchProps = ModuleDispatchProps<{
[ModuleInputKeyEnum.datasetSelectList]: SelectedDatasetType;
@@ -15,6 +14,7 @@ type DatasetSearchProps = ModuleDispatchProps<{
[ModuleInputKeyEnum.datasetLimit]: number;
[ModuleInputKeyEnum.datasetSearchMode]: `${DatasetSearchModeEnum}`;
[ModuleInputKeyEnum.userChatInput]: string;
[ModuleInputKeyEnum.datasetSearchUsingReRank]: boolean;
}>;
export type DatasetSearchResponse = {
[ModuleOutputKeyEnum.responseData]: moduleDispatchResType;
@@ -27,7 +27,7 @@ export async function dispatchDatasetSearch(
props: DatasetSearchProps
): Promise<DatasetSearchResponse> {
const {
inputs: { datasets = [], similarity = 0.4, limit = 5, searchMode, userChatInput }
inputs: { datasets = [], similarity, limit = 1500, usingReRank, searchMode, userChatInput }
} = props as DatasetSearchProps;
if (!Array.isArray(datasets)) {
@@ -52,14 +52,21 @@ export async function dispatchDatasetSearch(
const concatQueries = [userChatInput];
// start search
const { searchRes, tokenLen } = await searchDatasetData({
const { searchRes, tokens, usingSimilarityFilter } = await searchDatasetData({
rawQuery: userChatInput,
queries: concatQueries,
model: vectorModel.model,
similarity,
limit,
datasetIds: datasets.map((item) => item.datasetId),
searchMode
searchMode,
usingReRank
});
const { total, modelName } = formatModelPrice2Store({
model: vectorModel.model,
inputLen: tokens,
type: ModelTypeEnum.vector
});
return {
@@ -67,17 +74,14 @@ export async function dispatchDatasetSearch(
unEmpty: searchRes.length > 0 ? true : undefined,
quoteQA: searchRes,
responseData: {
price: countModelPrice({
model: vectorModel.model,
tokens: tokenLen,
type: ModelTypeEnum.vector
}),
price: total,
query: concatQueries.join('\n'),
model: vectorModel.name,
tokens: tokenLen,
similarity,
model: modelName,
inputTokens: tokens,
similarity: usingSimilarityFilter ? similarity : undefined,
limit,
searchMode
searchMode,
searchUsingReRank: usingReRank
}
};
}

View File

@@ -4,7 +4,8 @@ import { ModuleInputKeyEnum, ModuleOutputKeyEnum } from '@fastgpt/global/core/mo
import { getHistories } from '../utils';
import { getAIApi } from '@fastgpt/service/core/ai/config';
import { replaceVariable } from '@fastgpt/global/common/string/tools';
import { getExtractModel } from '@/service/core/ai/model';
import { ModelTypeEnum, getExtractModel } from '@/service/core/ai/model';
import { formatModelPrice2Store } from '@/service/support/wallet/bill/utils';
type Props = ModuleDispatchProps<{
[ModuleInputKeyEnum.aiModel]: string;
@@ -75,13 +76,22 @@ A: ${systemPrompt}
// );
// console.log(answer);
const tokens = result.usage?.total_tokens || 0;
const inputTokens = result.usage?.prompt_tokens || 0;
const outputTokens = result.usage?.completion_tokens || 0;
const { total, modelName } = formatModelPrice2Store({
model: extractModel.model,
inputLen: inputTokens,
outputLen: outputTokens,
type: ModelTypeEnum.extract
});
return {
[ModuleOutputKeyEnum.responseData]: {
price: extractModel.price * tokens,
model: extractModel.name || '',
tokens,
price: total,
model: modelName,
inputTokens,
outputTokens,
query: userChatInput,
textOutput: answer
},

View File

@@ -1,11 +1,11 @@
import { startQueue } from './utils/tools';
import { PRICE_SCALE } from '@fastgpt/global/support/wallet/bill/constants';
import { initPg } from '@fastgpt/service/common/pg';
import { MongoUser } from '@fastgpt/service/support/user/schema';
import { connectMongo } from '@fastgpt/service/common/mongo/init';
import { hashStr } from '@fastgpt/global/common/string/tools';
import { createDefaultTeam } from '@fastgpt/service/support/user/team/controller';
import { exit } from 'process';
import { initVectorStore } from '@fastgpt/service/common/vectorStore/controller';
/**
* connect MongoDB and init data
@@ -14,7 +14,7 @@ export function connectToDatabase(): Promise<void> {
return connectMongo({
beforeHook: () => {},
afterHook: () => {
initPg();
initVectorStore();
// start queue
startQueue();
return initRootUser();

View File

@@ -0,0 +1,22 @@
import { ConcatBillProps, CreateBillProps } from '@fastgpt/global/support/wallet/bill/api';
import { addLog } from '@fastgpt/service/common/system/log';
import { POST } from '@fastgpt/service/common/api/plusRequest';
export function createBill(data: CreateBillProps) {
if (!global.systemEnv?.pluginBaseUrl) return;
if (data.total === 0) {
addLog.info('0 Bill', data);
}
try {
POST('/support/wallet/bill/createBill', data);
} catch (error) {}
}
export function concatBill(data: ConcatBillProps) {
if (!global.systemEnv?.pluginBaseUrl) return;
if (data.total === 0) {
addLog.info('0 Bill', data);
}
try {
POST('/support/wallet/bill/concatBill', data);
} catch (error) {}
}
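
Design note: both helpers fire-and-forget the POST without awaiting it, so the surrounding try/catch only guards synchronous throws; a rejected billing request is dropped silently rather than failing the chat flow that triggered it.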

View File

@@ -1,30 +1,11 @@
import { BillSourceEnum, PRICE_SCALE } from '@fastgpt/global/support/wallet/bill/constants';
import { getAudioSpeechModel, getQAModel, getVectorModel } from '@/service/core/ai/model';
import { BillSourceEnum } from '@fastgpt/global/support/wallet/bill/constants';
import { ModelTypeEnum } from '@/service/core/ai/model';
import type { ChatHistoryItemResType } from '@fastgpt/global/core/chat/type.d';
import { formatPrice } from '@fastgpt/global/support/wallet/bill/tools';
import { formatStorePrice2Read } from '@fastgpt/global/support/wallet/bill/tools';
import { addLog } from '@fastgpt/service/common/system/log';
import type { ConcatBillProps, CreateBillProps } from '@fastgpt/global/support/wallet/bill/api.d';
import { POST } from '@fastgpt/service/common/api/plusRequest';
import { PostReRankProps } from '@fastgpt/global/core/ai/api';
export function createBill(data: CreateBillProps) {
if (!global.systemEnv?.pluginBaseUrl) return;
if (data.total === 0) {
addLog.info('0 Bill', data);
}
try {
POST('/support/wallet/bill/createBill', data);
} catch (error) {}
}
export function concatBill(data: ConcatBillProps) {
if (!global.systemEnv?.pluginBaseUrl) return;
if (data.total === 0) {
addLog.info('0 Bill', data);
}
try {
POST('/support/wallet/bill/concatBill', data);
} catch (error) {}
}
import { createBill, concatBill } from './controller';
import { formatModelPrice2Store } from '@/service/support/wallet/bill/utils';
export const pushChatBill = ({
appName,
@@ -54,14 +35,15 @@ export const pushChatBill = ({
moduleName: item.moduleName,
amount: item.price || 0,
model: item.model,
tokenLen: item.tokens
inputTokens: item.inputTokens,
outputTokens: item.outputTokens
}))
});
addLog.info(`finish completions`, {
source,
teamId,
tmbId,
price: formatPrice(total)
price: formatStorePrice2Read(total)
});
return { total };
};
@@ -70,26 +52,32 @@ export const pushQABill = async ({
teamId,
tmbId,
model,
totalTokens,
inputTokens,
outputTokens,
billId
}: {
teamId: string;
tmbId: string;
model: string;
totalTokens: number;
inputTokens: number;
outputTokens: number;
billId: string;
}) => {
// get the model's unit price
const unitPrice = getQAModel(model).price;
// calculate the price
const total = unitPrice * totalTokens;
const { total } = formatModelPrice2Store({
model,
inputLen: inputTokens,
outputLen: outputTokens,
type: ModelTypeEnum.qa
});
concatBill({
billId,
teamId,
tmbId,
total,
tokens: totalTokens,
inputTokens,
outputTokens,
listIndex: 1
});
@@ -100,22 +88,24 @@ export const pushGenerateVectorBill = ({
billId,
teamId,
tmbId,
tokenLen,
tokens,
model,
source = BillSourceEnum.fastgpt
}: {
billId?: string;
teamId: string;
tmbId: string;
tokenLen: number;
tokens: number;
model: string;
source?: `${BillSourceEnum}`;
}) => {
// calculate the price; enforce a minimum of 1
const vectorModel = getVectorModel(model);
const unitPrice = vectorModel.price || 0.2;
let total = unitPrice * tokenLen;
total = total > 1 ? total : 1;
let { total, modelName } = formatModelPrice2Store({
model,
inputLen: tokens,
type: ModelTypeEnum.vector
});
total = total < 1 ? 1 : total;
// insert the Bill record
if (billId) {
@@ -124,22 +114,22 @@ export const pushGenerateVectorBill = ({
tmbId,
total,
billId,
tokens: tokenLen,
inputTokens: tokens,
listIndex: 0
});
} else {
createBill({
teamId,
tmbId,
appName: '索引生成',
appName: 'wallet.moduleName.index',
total,
source,
list: [
{
moduleName: '索引生成',
moduleName: 'wallet.moduleName.index',
amount: total,
model: vectorModel.name,
tokenLen
model: modelName,
inputTokens: tokens
}
]
});
@@ -148,28 +138,37 @@ export const pushGenerateVectorBill = ({
};
export const pushQuestionGuideBill = ({
tokens,
inputTokens,
outputTokens,
teamId,
tmbId
}: {
tokens: number;
inputTokens: number;
outputTokens: number;
teamId: string;
tmbId: string;
}) => {
const qgModel = global.qgModels[0];
const total = qgModel.price * tokens;
const { total, modelName } = formatModelPrice2Store({
inputLen: inputTokens,
outputLen: outputTokens,
model: qgModel.model,
type: ModelTypeEnum.qg
});
createBill({
teamId,
tmbId,
appName: '下一步指引',
appName: 'wallet.bill.Next Step Guide',
total,
source: BillSourceEnum.fastgpt,
list: [
{
moduleName: '下一步指引',
moduleName: 'wallet.bill.Next Step Guide',
amount: total,
model: qgModel.name,
tokenLen: tokens
model: modelName,
inputTokens,
outputTokens
}
]
});
@@ -178,20 +177,24 @@ export const pushQuestionGuideBill = ({
export function pushAudioSpeechBill({
appName = 'wallet.bill.Audio Speech',
model,
textLength,
textLen,
teamId,
tmbId,
source = BillSourceEnum.fastgpt
}: {
appName?: string;
model: string;
textLength: number;
textLen: number;
teamId: string;
tmbId: string;
source: `${BillSourceEnum}`;
}) {
const modelData = getAudioSpeechModel(model);
const total = modelData.price * textLength;
const { total, modelName } = formatModelPrice2Store({
model,
inputLen: textLen,
type: ModelTypeEnum.audioSpeech
});
createBill({
teamId,
tmbId,
@@ -202,8 +205,8 @@ export function pushAudioSpeechBill({
{
moduleName: appName,
amount: total,
model: modelData.name,
tokenLen: textLength
model: modelName,
textLen
}
]
});
@@ -218,11 +221,16 @@ export function pushWhisperBill({
tmbId: string;
duration: number;
}) {
const modelData = global.whisperModel;
const whisperModel = global.whisperModel;
if (!modelData) return;
if (!whisperModel) return;
const total = ((modelData.price * duration) / 60) * PRICE_SCALE;
const { total, modelName } = formatModelPrice2Store({
model: whisperModel.model,
inputLen: duration,
type: ModelTypeEnum.whisper,
multiple: 60
});
const name = 'wallet.bill.Whisper';
@@ -236,8 +244,8 @@ export function pushWhisperBill({
{
moduleName: name,
amount: total,
model: modelData.name,
tokenLen: duration
model: modelName,
duration
}
]
});
@@ -254,13 +262,16 @@ export function pushReRankBill({
source: `${BillSourceEnum}`;
inputs: PostReRankProps['inputs'];
}) {
const model = global.reRankModels[0];
if (!model) return { total: 0 };
const reRankModel = global.reRankModels[0];
if (!reRankModel) return { total: 0 };
const textLength = inputs.reduce((sum, item) => sum + item.text.length, 0);
const ratio = textLength / 1000;
const textLen = inputs.reduce((sum, item) => sum + item.text.length, 0);
const total = Math.ceil(model.price * PRICE_SCALE * ratio);
const { total, modelName } = formatModelPrice2Store({
model: reRankModel.model,
inputLen: textLen,
type: ModelTypeEnum.rerank
});
const name = 'wallet.bill.ReRank';
createBill({
@@ -273,8 +284,8 @@ export function pushReRankBill({
{
moduleName: name,
amount: total,
model: model.name,
tokenLen: textLength
model: modelName,
textLen
}
]
});

View File

@@ -1,6 +1,6 @@
import { ModelTypeEnum, getModelMap } from '@/service/core/ai/model';
import { AuthUserTypeEnum } from '@fastgpt/global/support/permission/constant';
import { BillSourceEnum } from '@fastgpt/global/support/wallet/bill/constants';
import { BillSourceEnum, PRICE_SCALE } from '@fastgpt/global/support/wallet/bill/constants';
export function authType2BillSource({
authType,
@@ -17,16 +17,38 @@ export function authType2BillSource({
return BillSourceEnum.fastgpt;
}
export const countModelPrice = ({
export const formatModelPrice2Store = ({
model,
tokens,
type
inputLen = 0,
outputLen = 0,
type,
multiple = 1000
}: {
model: string;
tokens: number;
inputLen: number;
outputLen?: number;
type: `${ModelTypeEnum}`;
multiple?: number;
}) => {
const modelData = getModelMap?.[type]?.(model);
if (!modelData) return 0;
return modelData.price * tokens;
if (!modelData)
return {
inputTotal: 0,
outputTotal: 0,
total: 0,
modelName: ''
};
const inputTotal = modelData.inputPrice
? Math.ceil(modelData.inputPrice * (inputLen / multiple) * PRICE_SCALE)
: 0;
const outputTotal = modelData.outputPrice
? Math.ceil(modelData.outputPrice * (outputLen / multiple) * PRICE_SCALE)
: 0;
return {
modelName: modelData.name,
inputTotal: inputTotal,
outputTotal: outputTotal,
total: inputTotal + outputTotal
};
};
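
A worked sketch of the store-price arithmetic (the unit prices and PRICE_SCALE value here are hypothetical; PRICE_SCALE is the fixed-point multiplier imported above):

// Assume inputPrice = 0.0015 and outputPrice = 0.002 per `multiple` (1000) tokens,
// and PRICE_SCALE = 100000:
// formatModelPrice2Store({ model, inputLen: 1200, outputLen: 300, type: ModelTypeEnum.chat })
//   inputTotal  = ceil(0.0015 * (1200 / 1000) * 100000) = 180
//   outputTotal = ceil(0.002  * ( 300 / 1000) * 100000) = 60
//   total       = 240  (a scaled integer; divide by PRICE_SCALE for display)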