Mirror of https://github.com/labring/FastGPT.git, synced 2025-08-01 03:48:24 +00:00.
V4.6.6-2 (#673)
This commit is contained in:
13
projects/app/src/service/common/string/jieba.ts
Normal file
13
projects/app/src/service/common/string/jieba.ts
Normal file
@@ -0,0 +1,13 @@
|
||||
import { cut } from '@node-rs/jieba';
|
||||
import { stopWords } from '@fastgpt/global/common/string/jieba';
|
||||
|
||||
export function jiebaSplit({ text }: { text: string }) {
|
||||
const tokens = cut(text, true);
|
||||
|
||||
return (
|
||||
tokens
|
||||
.map((item) => item.replace(/[^\u4e00-\u9fa5a-zA-Z0-9\s]/g, '').trim())
|
||||
.filter((item) => item && !stopWords.has(item))
|
||||
.join(' ') || ''
|
||||
);
|
||||
}
|
@@ -24,13 +24,24 @@ export function getAudioSpeechModel(model?: string) {
|
||||
);
|
||||
}
|
||||
|
||||
export function getWhisperModel(model?: string) {
|
||||
return global.whisperModel;
|
||||
}
|
||||
|
||||
export function getReRankModel(model?: string) {
|
||||
return global.reRankModels.find((item) => item.model === model);
|
||||
}
|
||||
|
||||
export enum ModelTypeEnum {
|
||||
chat = 'chat',
|
||||
qa = 'qa',
|
||||
cq = 'cq',
|
||||
extract = 'extract',
|
||||
qg = 'qg',
|
||||
vector = 'vector'
|
||||
vector = 'vector',
|
||||
audioSpeech = 'audioSpeech',
|
||||
whisper = 'whisper',
|
||||
rerank = 'rerank'
|
||||
}
|
||||
export const getModelMap = {
|
||||
[ModelTypeEnum.chat]: getChatModel,
|
||||
@@ -38,5 +49,8 @@ export const getModelMap = {
|
||||
[ModelTypeEnum.cq]: getCQModel,
|
||||
[ModelTypeEnum.extract]: getExtractModel,
|
||||
[ModelTypeEnum.qg]: getQGModel,
|
||||
[ModelTypeEnum.vector]: getVectorModel
|
||||
[ModelTypeEnum.vector]: getVectorModel,
|
||||
[ModelTypeEnum.audioSpeech]: getAudioSpeechModel,
|
||||
[ModelTypeEnum.whisper]: getWhisperModel,
|
||||
[ModelTypeEnum.rerank]: getReRankModel
|
||||
};
|
||||
|
@@ -20,8 +20,14 @@ export function reRankRecall({ query, inputs }: PostReRankProps) {
|
||||
Authorization: `Bearer ${model.requestAuth}`
|
||||
}
|
||||
}
|
||||
).then((data) => {
|
||||
console.log('rerank time:', Date.now() - start);
|
||||
return data;
|
||||
});
|
||||
)
|
||||
.then((data) => {
|
||||
console.log('rerank time:', Date.now() - start);
|
||||
return data;
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
|
||||
return [];
|
||||
});
|
||||
}
|
||||
|
@@ -1,73 +0,0 @@
|
||||
import { getAIApi } from '@fastgpt/service/core/ai/config';
|
||||
|
||||
export type GetVectorProps = {
|
||||
model: string;
|
||||
input: string | string[];
|
||||
};
|
||||
|
||||
// text to vector
|
||||
export async function getVectorsByText({
|
||||
model = 'text-embedding-ada-002',
|
||||
input
|
||||
}: GetVectorProps) {
|
||||
try {
|
||||
if (typeof input === 'string' && !input) {
|
||||
return Promise.reject({
|
||||
code: 500,
|
||||
message: 'input is empty'
|
||||
});
|
||||
} else if (Array.isArray(input)) {
|
||||
for (let i = 0; i < input.length; i++) {
|
||||
if (!input[i]) {
|
||||
return Promise.reject({
|
||||
code: 500,
|
||||
message: 'input array is empty'
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取 chatAPI
|
||||
const ai = getAIApi();
|
||||
|
||||
// 把输入的内容转成向量
|
||||
const result = await ai.embeddings
|
||||
.create({
|
||||
model,
|
||||
input
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.data) {
|
||||
return Promise.reject('Embedding API 404');
|
||||
}
|
||||
if (!res?.data?.[0]?.embedding) {
|
||||
console.log(res?.data);
|
||||
// @ts-ignore
|
||||
return Promise.reject(res.data?.err?.message || 'Embedding API Error');
|
||||
}
|
||||
return {
|
||||
tokenLen: res.usage.total_tokens || 0,
|
||||
vectors: await Promise.all(res.data.map((item) => unityDimensional(item.embedding)))
|
||||
};
|
||||
});
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
console.log(`Embedding Error`, error);
|
||||
|
||||
return Promise.reject(error);
|
||||
}
|
||||
}
|
||||
|
||||
function unityDimensional(vector: number[]) {
|
||||
if (vector.length > 1536) {
|
||||
console.log(`当前向量维度为: ${vector.length}, 向量维度不能超过 1536, 已自动截取前 1536 维度`);
|
||||
return vector.slice(0, 1536);
|
||||
}
|
||||
let resultVector = vector;
|
||||
const vectorLen = vector.length;
|
||||
|
||||
const zeroVector = new Array(1536 - vectorLen).fill(0);
|
||||
|
||||
return resultVector.concat(zeroVector);
|
||||
}
|
@@ -4,12 +4,30 @@ import {
|
||||
PatchIndexesProps,
|
||||
UpdateDatasetDataProps
|
||||
} from '@fastgpt/global/core/dataset/controller';
|
||||
import { deletePgDataById } from '@fastgpt/service/core/dataset/data/pg';
|
||||
import { insertData2Pg, updatePgDataById } from './pg';
|
||||
import {
|
||||
insertDatasetDataVector,
|
||||
recallFromVectorStore,
|
||||
updateDatasetDataVector
|
||||
} from '@fastgpt/service/common/vectorStore/controller';
|
||||
import { Types } from 'mongoose';
|
||||
import { DatasetDataIndexTypeEnum } from '@fastgpt/global/core/dataset/constant';
|
||||
import {
|
||||
DatasetDataIndexTypeEnum,
|
||||
DatasetSearchModeEnum,
|
||||
DatasetSearchModeMap,
|
||||
SearchScoreTypeEnum
|
||||
} from '@fastgpt/global/core/dataset/constant';
|
||||
import { getDefaultIndex } from '@fastgpt/global/core/dataset/utils';
|
||||
import { jiebaSplit } from '../utils';
|
||||
import { jiebaSplit } from '@/service/common/string/jieba';
|
||||
import { deleteDatasetDataVector } from '@fastgpt/service/common/vectorStore/controller';
|
||||
import { getVectorsByText } from '@fastgpt/service/core/ai/embedding';
|
||||
import { MongoDatasetCollection } from '@fastgpt/service/core/dataset/collection/schema';
|
||||
import {
|
||||
DatasetDataSchemaType,
|
||||
SearchDataResponseItemType
|
||||
} from '@fastgpt/global/core/dataset/type';
|
||||
import { reRankRecall } from '../../ai/rerank';
|
||||
import { countPromptTokens } from '@fastgpt/global/common/string/tiktoken';
|
||||
import { hashStr } from '@fastgpt/global/common/string/tools';
|
||||
|
||||
/* insert data.
|
||||
* 1. create data id
|
||||
@@ -50,17 +68,17 @@ export async function insertData2Dataset({
|
||||
}))
|
||||
: [getDefaultIndex({ q, a })];
|
||||
|
||||
// insert to pg
|
||||
// insert to vector store
|
||||
const result = await Promise.all(
|
||||
indexes.map((item) =>
|
||||
insertData2Pg({
|
||||
mongoDataId: String(id),
|
||||
input: item.text,
|
||||
insertDatasetDataVector({
|
||||
query: item.text,
|
||||
model,
|
||||
teamId,
|
||||
tmbId,
|
||||
datasetId,
|
||||
collectionId
|
||||
collectionId,
|
||||
dataId: String(id)
|
||||
})
|
||||
)
|
||||
);
|
||||
@@ -84,7 +102,7 @@ export async function insertData2Dataset({
|
||||
|
||||
return {
|
||||
insertId: _id,
|
||||
tokenLen: result.reduce((acc, cur) => acc + cur.tokenLen, 0)
|
||||
tokens: result.reduce((acc, cur) => acc + cur.tokens, 0)
|
||||
};
|
||||
}
|
||||
|
||||
@@ -172,35 +190,40 @@ export async function updateData2Dataset({
|
||||
const result = await Promise.all(
|
||||
patchResult.map(async (item) => {
|
||||
if (item.type === 'create') {
|
||||
const result = await insertData2Pg({
|
||||
mongoDataId: dataId,
|
||||
input: item.index.text,
|
||||
const result = await insertDatasetDataVector({
|
||||
query: item.index.text,
|
||||
model,
|
||||
teamId: mongoData.teamId,
|
||||
tmbId: mongoData.tmbId,
|
||||
datasetId: mongoData.datasetId,
|
||||
collectionId: mongoData.collectionId
|
||||
collectionId: mongoData.collectionId,
|
||||
dataId
|
||||
});
|
||||
item.index.dataId = result.insertId;
|
||||
return result;
|
||||
}
|
||||
if (item.type === 'update' && item.index.dataId) {
|
||||
return updatePgDataById({
|
||||
return updateDatasetDataVector({
|
||||
id: item.index.dataId,
|
||||
input: item.index.text,
|
||||
query: item.index.text,
|
||||
model
|
||||
});
|
||||
}
|
||||
if (item.type === 'delete' && item.index.dataId) {
|
||||
return deletePgDataById(['id', item.index.dataId]);
|
||||
await deleteDatasetDataVector({
|
||||
id: item.index.dataId
|
||||
});
|
||||
return {
|
||||
tokens: 0
|
||||
};
|
||||
}
|
||||
return {
|
||||
tokenLen: 0
|
||||
tokens: 0
|
||||
};
|
||||
})
|
||||
);
|
||||
|
||||
const tokenLen = result.reduce((acc, cur) => acc + cur.tokenLen, 0);
|
||||
const tokens = result.reduce((acc, cur) => acc + cur.tokens, 0);
|
||||
|
||||
// update mongo
|
||||
mongoData.q = q || mongoData.q;
|
||||
@@ -211,6 +234,457 @@ export async function updateData2Dataset({
|
||||
await mongoData.save();
|
||||
|
||||
return {
|
||||
tokenLen
|
||||
tokens
|
||||
};
|
||||
}
|
||||
|
||||
export async function searchDatasetData(props: {
|
||||
model: string;
|
||||
similarity?: number; // min distance
|
||||
limit: number; // max Token limit
|
||||
datasetIds: string[];
|
||||
searchMode?: `${DatasetSearchModeEnum}`;
|
||||
usingReRank?: boolean;
|
||||
rawQuery: string;
|
||||
queries: string[];
|
||||
}) {
|
||||
let {
|
||||
rawQuery,
|
||||
queries,
|
||||
model,
|
||||
similarity = 0,
|
||||
limit: maxTokens,
|
||||
searchMode = DatasetSearchModeEnum.embedding,
|
||||
usingReRank = false,
|
||||
datasetIds = []
|
||||
} = props;
|
||||
|
||||
/* init params */
|
||||
searchMode = DatasetSearchModeMap[searchMode] ? searchMode : DatasetSearchModeEnum.embedding;
|
||||
usingReRank = usingReRank && global.reRankModels.length > 0;
|
||||
|
||||
// Compatible with topk limit
|
||||
if (maxTokens < 50) {
|
||||
maxTokens = 1500;
|
||||
}
|
||||
let set = new Set<string>();
|
||||
let usingSimilarityFilter = false;
|
||||
|
||||
/* function */
|
||||
const countRecallLimit = () => {
|
||||
const oneChunkToken = 50;
|
||||
const estimatedLen = Math.max(20, Math.ceil(maxTokens / oneChunkToken));
|
||||
|
||||
// Increase search range, reduce hnsw loss. 20 ~ 100
|
||||
if (searchMode === DatasetSearchModeEnum.embedding) {
|
||||
return {
|
||||
embeddingLimit: Math.min(estimatedLen, 100),
|
||||
fullTextLimit: 0
|
||||
};
|
||||
}
|
||||
// 50 < 2*limit < value < 100
|
||||
if (searchMode === DatasetSearchModeEnum.fullTextRecall) {
|
||||
return {
|
||||
embeddingLimit: 0,
|
||||
fullTextLimit: Math.min(estimatedLen, 50)
|
||||
};
|
||||
}
|
||||
// mixed
|
||||
// 50 < 2*limit < embedding < 80
|
||||
// 20 < limit < fullTextLimit < 40
|
||||
return {
|
||||
embeddingLimit: Math.min(estimatedLen, 80),
|
||||
fullTextLimit: Math.min(estimatedLen, 40)
|
||||
};
|
||||
};
|
||||
const embeddingRecall = async ({ query, limit }: { query: string; limit: number }) => {
|
||||
const { vectors, tokens } = await getVectorsByText({
|
||||
model,
|
||||
input: [query]
|
||||
});
|
||||
|
||||
const { results } = await recallFromVectorStore({
|
||||
vectors,
|
||||
limit,
|
||||
datasetIds
|
||||
});
|
||||
|
||||
// get q and a
|
||||
const [collections, dataList] = await Promise.all([
|
||||
MongoDatasetCollection.find(
|
||||
{
|
||||
_id: { $in: results.map((item) => item.collectionId) }
|
||||
},
|
||||
'name fileId rawLink'
|
||||
).lean(),
|
||||
MongoDatasetData.find(
|
||||
{
|
||||
_id: { $in: results.map((item) => item.dataId?.trim()) }
|
||||
},
|
||||
'datasetId collectionId q a chunkIndex indexes'
|
||||
).lean()
|
||||
]);
|
||||
|
||||
const formatResult = results
|
||||
.map((item, index) => {
|
||||
const collection = collections.find(
|
||||
(collection) => String(collection._id) === item.collectionId
|
||||
);
|
||||
const data = dataList.find((data) => String(data._id) === item.dataId);
|
||||
|
||||
// if collection or data UnExist, the relational mongo data already deleted
|
||||
if (!collection || !data) return null;
|
||||
|
||||
const result: SearchDataResponseItemType = {
|
||||
id: String(data._id),
|
||||
q: data.q,
|
||||
a: data.a,
|
||||
chunkIndex: data.chunkIndex,
|
||||
indexes: data.indexes,
|
||||
datasetId: String(data.datasetId),
|
||||
collectionId: String(data.collectionId),
|
||||
sourceName: collection.name || '',
|
||||
sourceId: collection?.fileId || collection?.rawLink,
|
||||
score: [{ type: SearchScoreTypeEnum.embedding, value: item.score, index }]
|
||||
};
|
||||
|
||||
return result;
|
||||
})
|
||||
.filter((item) => item !== null) as SearchDataResponseItemType[];
|
||||
|
||||
return {
|
||||
embeddingRecallResults: formatResult,
|
||||
tokens
|
||||
};
|
||||
};
|
||||
const fullTextRecall = async ({
|
||||
query,
|
||||
limit
|
||||
}: {
|
||||
query: string;
|
||||
limit: number;
|
||||
}): Promise<{
|
||||
fullTextRecallResults: SearchDataResponseItemType[];
|
||||
tokenLen: number;
|
||||
}> => {
|
||||
if (limit === 0) {
|
||||
return {
|
||||
fullTextRecallResults: [],
|
||||
tokenLen: 0
|
||||
};
|
||||
}
|
||||
|
||||
let searchResults = (
|
||||
await Promise.all(
|
||||
datasetIds.map((id) =>
|
||||
MongoDatasetData.find(
|
||||
{
|
||||
datasetId: id,
|
||||
$text: { $search: jiebaSplit({ text: query }) }
|
||||
},
|
||||
{
|
||||
score: { $meta: 'textScore' },
|
||||
_id: 1,
|
||||
datasetId: 1,
|
||||
collectionId: 1,
|
||||
q: 1,
|
||||
a: 1,
|
||||
indexes: 1,
|
||||
chunkIndex: 1
|
||||
}
|
||||
)
|
||||
.sort({ score: { $meta: 'textScore' } })
|
||||
.limit(limit)
|
||||
.lean()
|
||||
)
|
||||
)
|
||||
).flat() as (DatasetDataSchemaType & { score: number })[];
|
||||
|
||||
// resort
|
||||
searchResults.sort((a, b) => b.score - a.score);
|
||||
searchResults.slice(0, limit);
|
||||
|
||||
const collections = await MongoDatasetCollection.find(
|
||||
{
|
||||
_id: { $in: searchResults.map((item) => item.collectionId) }
|
||||
},
|
||||
'_id name fileId rawLink'
|
||||
);
|
||||
|
||||
return {
|
||||
fullTextRecallResults: searchResults.map((item, index) => {
|
||||
const collection = collections.find((col) => String(col._id) === String(item.collectionId));
|
||||
return {
|
||||
id: String(item._id),
|
||||
datasetId: String(item.datasetId),
|
||||
collectionId: String(item.collectionId),
|
||||
sourceName: collection?.name || '',
|
||||
sourceId: collection?.fileId || collection?.rawLink,
|
||||
q: item.q,
|
||||
a: item.a,
|
||||
chunkIndex: item.chunkIndex,
|
||||
indexes: item.indexes,
|
||||
score: [{ type: SearchScoreTypeEnum.fullText, value: item.score, index }]
|
||||
};
|
||||
}),
|
||||
tokenLen: 0
|
||||
};
|
||||
};
|
||||
const reRankSearchResult = async ({
|
||||
data,
|
||||
query
|
||||
}: {
|
||||
data: SearchDataResponseItemType[];
|
||||
query: string;
|
||||
}): Promise<SearchDataResponseItemType[]> => {
|
||||
try {
|
||||
const results = await reRankRecall({
|
||||
query,
|
||||
inputs: data.map((item) => ({
|
||||
id: item.id,
|
||||
text: `${item.q}\n${item.a}`
|
||||
}))
|
||||
});
|
||||
|
||||
if (!Array.isArray(results)) return [];
|
||||
|
||||
// add new score to data
|
||||
const mergeResult = results
|
||||
.map((item, index) => {
|
||||
const target = data.find((dataItem) => dataItem.id === item.id);
|
||||
if (!target) return null;
|
||||
const score = item.score || 0;
|
||||
|
||||
return {
|
||||
...target,
|
||||
score: [{ type: SearchScoreTypeEnum.reRank, value: score, index }]
|
||||
};
|
||||
})
|
||||
.filter(Boolean) as SearchDataResponseItemType[];
|
||||
|
||||
return mergeResult;
|
||||
} catch (error) {
|
||||
return [];
|
||||
}
|
||||
};
|
||||
const filterResultsByMaxTokens = (list: SearchDataResponseItemType[], maxTokens: number) => {
|
||||
const results: SearchDataResponseItemType[] = [];
|
||||
let totalTokens = 0;
|
||||
|
||||
for (let i = 0; i < list.length; i++) {
|
||||
const item = list[i];
|
||||
totalTokens += countPromptTokens(item.q + item.a);
|
||||
if (totalTokens > maxTokens + 500) {
|
||||
break;
|
||||
}
|
||||
results.push(item);
|
||||
if (totalTokens > maxTokens) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return results.length === 0 ? list.slice(0, 1) : results;
|
||||
};
|
||||
const multiQueryRecall = async ({
|
||||
embeddingLimit,
|
||||
fullTextLimit
|
||||
}: {
|
||||
embeddingLimit: number;
|
||||
fullTextLimit: number;
|
||||
}) => {
|
||||
// In a group n recall, as long as one of the data appears minAmount of times, it is retained
|
||||
const getIntersection = (resultList: SearchDataResponseItemType[][], minAmount = 1) => {
|
||||
minAmount = Math.min(resultList.length, minAmount);
|
||||
|
||||
const map: Record<
|
||||
string,
|
||||
{
|
||||
amount: number;
|
||||
data: SearchDataResponseItemType;
|
||||
}
|
||||
> = {};
|
||||
|
||||
for (const list of resultList) {
|
||||
for (const item of list) {
|
||||
map[item.id] = map[item.id]
|
||||
? {
|
||||
amount: map[item.id].amount + 1,
|
||||
data: item
|
||||
}
|
||||
: {
|
||||
amount: 1,
|
||||
data: item
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return Object.values(map)
|
||||
.filter((item) => item.amount >= minAmount)
|
||||
.map((item) => item.data);
|
||||
};
|
||||
|
||||
// multi query recall
|
||||
const embeddingRecallResList: SearchDataResponseItemType[][] = [];
|
||||
const fullTextRecallResList: SearchDataResponseItemType[][] = [];
|
||||
let embTokens = 0;
|
||||
for await (const query of queries) {
|
||||
const [{ tokens, embeddingRecallResults }, { fullTextRecallResults }] = await Promise.all([
|
||||
embeddingRecall({
|
||||
query,
|
||||
limit: embeddingLimit
|
||||
}),
|
||||
fullTextRecall({
|
||||
query,
|
||||
limit: fullTextLimit
|
||||
})
|
||||
]);
|
||||
embTokens += tokens;
|
||||
|
||||
embeddingRecallResList.push(embeddingRecallResults);
|
||||
fullTextRecallResList.push(fullTextRecallResults);
|
||||
}
|
||||
|
||||
return {
|
||||
tokens: embTokens,
|
||||
embeddingRecallResults: embeddingRecallResList[0],
|
||||
fullTextRecallResults: fullTextRecallResList[0]
|
||||
};
|
||||
};
|
||||
const rrfConcat = (
|
||||
arr: { k: number; list: SearchDataResponseItemType[] }[]
|
||||
): SearchDataResponseItemType[] => {
|
||||
const map = new Map<string, SearchDataResponseItemType & { rrfScore: number }>();
|
||||
|
||||
// rrf
|
||||
arr.forEach((item) => {
|
||||
const k = item.k;
|
||||
|
||||
item.list.forEach((data, index) => {
|
||||
const rank = index + 1;
|
||||
const score = 1 / (k + rank);
|
||||
|
||||
const record = map.get(data.id);
|
||||
if (record) {
|
||||
// 合并两个score,有相同type的score,取最大值
|
||||
const concatScore = [...record.score];
|
||||
for (const dataItem of data.score) {
|
||||
const sameScore = concatScore.find((item) => item.type === dataItem.type);
|
||||
if (sameScore) {
|
||||
sameScore.value = Math.max(sameScore.value, dataItem.value);
|
||||
} else {
|
||||
concatScore.push(dataItem);
|
||||
}
|
||||
}
|
||||
|
||||
map.set(data.id, {
|
||||
...record,
|
||||
score: concatScore,
|
||||
rrfScore: record.rrfScore + score
|
||||
});
|
||||
} else {
|
||||
map.set(data.id, {
|
||||
...data,
|
||||
rrfScore: score
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// sort
|
||||
const mapArray = Array.from(map.values());
|
||||
const results = mapArray.sort((a, b) => b.rrfScore - a.rrfScore);
|
||||
|
||||
return results.map((item, index) => {
|
||||
item.score.push({
|
||||
type: SearchScoreTypeEnum.rrf,
|
||||
value: item.rrfScore,
|
||||
index
|
||||
});
|
||||
// @ts-ignore
|
||||
delete item.rrfScore;
|
||||
return item;
|
||||
});
|
||||
};
|
||||
|
||||
/* main step */
|
||||
// count limit
|
||||
const { embeddingLimit, fullTextLimit } = countRecallLimit();
|
||||
|
||||
// recall
|
||||
const { embeddingRecallResults, fullTextRecallResults, tokens } = await multiQueryRecall({
|
||||
embeddingLimit,
|
||||
fullTextLimit
|
||||
});
|
||||
|
||||
// ReRank results
|
||||
const reRankResults = await (async () => {
|
||||
if (!usingReRank) return [];
|
||||
|
||||
set = new Set<string>(embeddingRecallResults.map((item) => item.id));
|
||||
const concatRecallResults = embeddingRecallResults.concat(
|
||||
fullTextRecallResults.filter((item) => !set.has(item.id))
|
||||
);
|
||||
|
||||
// remove same q and a data
|
||||
set = new Set<string>();
|
||||
const filterSameDataResults = concatRecallResults.filter((item) => {
|
||||
// 删除所有的标点符号与空格等,只对文本进行比较
|
||||
const str = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
|
||||
if (set.has(str)) return false;
|
||||
set.add(str);
|
||||
return true;
|
||||
});
|
||||
return reRankSearchResult({
|
||||
query: rawQuery,
|
||||
data: filterSameDataResults
|
||||
});
|
||||
})();
|
||||
|
||||
// embedding recall and fullText recall rrf concat
|
||||
const rrfConcatResults = rrfConcat([
|
||||
{ k: 60, list: embeddingRecallResults },
|
||||
{ k: 60, list: fullTextRecallResults },
|
||||
{ k: 60, list: reRankResults }
|
||||
]);
|
||||
|
||||
// remove same q and a data
|
||||
set = new Set<string>();
|
||||
const filterSameDataResults = rrfConcatResults.filter((item) => {
|
||||
// 删除所有的标点符号与空格等,只对文本进行比较
|
||||
const str = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
|
||||
if (set.has(str)) return false;
|
||||
set.add(str);
|
||||
return true;
|
||||
});
|
||||
|
||||
// score filter
|
||||
const scoreFilter = (() => {
|
||||
if (usingReRank) {
|
||||
usingSimilarityFilter = true;
|
||||
|
||||
return filterSameDataResults.filter((item) => {
|
||||
const reRankScore = item.score.find((item) => item.type === SearchScoreTypeEnum.reRank);
|
||||
if (reRankScore && reRankScore.value < similarity) return false;
|
||||
return true;
|
||||
});
|
||||
}
|
||||
if (searchMode === DatasetSearchModeEnum.embedding) {
|
||||
return filterSameDataResults.filter((item) => {
|
||||
usingSimilarityFilter = true;
|
||||
|
||||
const embeddingScore = item.score.find(
|
||||
(item) => item.type === SearchScoreTypeEnum.embedding
|
||||
);
|
||||
if (embeddingScore && embeddingScore.value < similarity) return false;
|
||||
return true;
|
||||
});
|
||||
}
|
||||
return filterSameDataResults;
|
||||
})();
|
||||
|
||||
return {
|
||||
searchRes: filterResultsByMaxTokens(scoreFilter, maxTokens),
|
||||
tokens,
|
||||
usingSimilarityFilter
|
||||
};
|
||||
}
|
||||
|
@@ -1,478 +0,0 @@
|
||||
import { DatasetSearchModeEnum, PgDatasetTableName } from '@fastgpt/global/core/dataset/constant';
|
||||
import type {
|
||||
DatasetDataSchemaType,
|
||||
SearchDataResponseItemType
|
||||
} from '@fastgpt/global/core/dataset/type.d';
|
||||
import { PgClient } from '@fastgpt/service/common/pg';
|
||||
import { getVectorsByText } from '@/service/core/ai/vector';
|
||||
import { delay } from '@fastgpt/global/common/system/utils';
|
||||
import { PgSearchRawType } from '@fastgpt/global/core/dataset/api';
|
||||
import { MongoDatasetCollection } from '@fastgpt/service/core/dataset/collection/schema';
|
||||
import { MongoDatasetData } from '@fastgpt/service/core/dataset/data/schema';
|
||||
import { jiebaSplit } from '../utils';
|
||||
import { reRankRecall } from '../../ai/rerank';
|
||||
import { countPromptTokens } from '@fastgpt/global/common/string/tiktoken';
|
||||
import { hashStr } from '@fastgpt/global/common/string/tools';
|
||||
|
||||
export async function insertData2Pg(props: {
|
||||
mongoDataId: string;
|
||||
input: string;
|
||||
model: string;
|
||||
teamId: string;
|
||||
tmbId: string;
|
||||
datasetId: string;
|
||||
collectionId: string;
|
||||
retry?: number;
|
||||
}): Promise<{ insertId: string; vectors: number[][]; tokenLen: number }> {
|
||||
const { mongoDataId, input, model, teamId, tmbId, datasetId, collectionId, retry = 3 } = props;
|
||||
try {
|
||||
// get vector
|
||||
const { vectors, tokenLen } = await getVectorsByText({
|
||||
model,
|
||||
input: [input]
|
||||
});
|
||||
const { rows } = await PgClient.insert(PgDatasetTableName, {
|
||||
values: [
|
||||
[
|
||||
{ key: 'vector', value: `[${vectors[0]}]` },
|
||||
{ key: 'team_id', value: String(teamId) },
|
||||
{ key: 'tmb_id', value: String(tmbId) },
|
||||
{ key: 'dataset_id', value: datasetId },
|
||||
{ key: 'collection_id', value: collectionId },
|
||||
{ key: 'data_id', value: String(mongoDataId) }
|
||||
]
|
||||
]
|
||||
});
|
||||
return {
|
||||
insertId: rows[0].id,
|
||||
vectors,
|
||||
tokenLen
|
||||
};
|
||||
} catch (error) {
|
||||
if (retry <= 0) {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
await delay(500);
|
||||
return insertData2Pg({
|
||||
...props,
|
||||
retry: retry - 1
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export async function updatePgDataById({
|
||||
id,
|
||||
input,
|
||||
model
|
||||
}: {
|
||||
id: string;
|
||||
input: string;
|
||||
model: string;
|
||||
}) {
|
||||
let retry = 2;
|
||||
async function updatePg(): Promise<{ vectors: number[][]; tokenLen: number }> {
|
||||
try {
|
||||
// get vector
|
||||
const { vectors, tokenLen } = await getVectorsByText({
|
||||
model,
|
||||
input: [input]
|
||||
});
|
||||
// update pg
|
||||
await PgClient.update(PgDatasetTableName, {
|
||||
where: [['id', id]],
|
||||
values: [{ key: 'vector', value: `[${vectors[0]}]` }]
|
||||
});
|
||||
return {
|
||||
vectors,
|
||||
tokenLen
|
||||
};
|
||||
} catch (error) {
|
||||
if (--retry < 0) {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
await delay(500);
|
||||
return updatePg();
|
||||
}
|
||||
}
|
||||
return updatePg();
|
||||
}
|
||||
|
||||
// ------------------ search start ------------------
|
||||
type SearchProps = {
|
||||
model: string;
|
||||
similarity?: number; // min distance
|
||||
limit: number; // max Token limit
|
||||
datasetIds: string[];
|
||||
searchMode?: `${DatasetSearchModeEnum}`;
|
||||
};
|
||||
export async function searchDatasetData(
|
||||
props: SearchProps & { rawQuery: string; queries: string[] }
|
||||
) {
|
||||
let {
|
||||
rawQuery,
|
||||
queries,
|
||||
model,
|
||||
similarity = 0,
|
||||
limit: maxTokens,
|
||||
searchMode = DatasetSearchModeEnum.embedding,
|
||||
datasetIds = []
|
||||
} = props;
|
||||
|
||||
/* init params */
|
||||
searchMode = global.systemEnv?.pluginBaseUrl ? searchMode : DatasetSearchModeEnum.embedding;
|
||||
// Compatible with topk limit
|
||||
if (maxTokens < 50) {
|
||||
maxTokens = 1500;
|
||||
}
|
||||
const rerank =
|
||||
global.reRankModels?.[0] &&
|
||||
(searchMode === DatasetSearchModeEnum.embeddingReRank ||
|
||||
searchMode === DatasetSearchModeEnum.embFullTextReRank);
|
||||
let set = new Set<string>();
|
||||
|
||||
/* function */
|
||||
const countRecallLimit = () => {
|
||||
const oneChunkToken = 50;
|
||||
const estimatedLen = Math.max(20, Math.ceil(maxTokens / oneChunkToken));
|
||||
|
||||
// Increase search range, reduce hnsw loss. 20 ~ 100
|
||||
if (searchMode === DatasetSearchModeEnum.embedding) {
|
||||
return {
|
||||
embeddingLimit: Math.min(estimatedLen, 100),
|
||||
fullTextLimit: 0
|
||||
};
|
||||
}
|
||||
// 50 < 2*limit < value < 100
|
||||
if (searchMode === DatasetSearchModeEnum.embeddingReRank) {
|
||||
return {
|
||||
embeddingLimit: Math.min(100, Math.max(50, estimatedLen * 2)),
|
||||
fullTextLimit: 0
|
||||
};
|
||||
}
|
||||
// 50 < 2*limit < embedding < 80
|
||||
// 20 < limit < fullTextLimit < 40
|
||||
return {
|
||||
embeddingLimit: Math.min(80, Math.max(50, estimatedLen * 2)),
|
||||
fullTextLimit: Math.min(40, Math.max(20, estimatedLen))
|
||||
};
|
||||
};
|
||||
const embeddingRecall = async ({ query, limit }: { query: string; limit: number }) => {
|
||||
const { vectors, tokenLen } = await getVectorsByText({
|
||||
model,
|
||||
input: [query]
|
||||
});
|
||||
|
||||
const results: any = await PgClient.query(
|
||||
`BEGIN;
|
||||
SET LOCAL hnsw.ef_search = ${global.systemEnv.pgHNSWEfSearch || 100};
|
||||
select id, collection_id, data_id, (vector <#> '[${vectors[0]}]') * -1 AS score
|
||||
from ${PgDatasetTableName}
|
||||
where dataset_id IN (${datasetIds.map((id) => `'${String(id)}'`).join(',')})
|
||||
${rerank ? '' : `AND vector <#> '[${vectors[0]}]' < -${similarity}`}
|
||||
order by score desc limit ${limit};
|
||||
COMMIT;`
|
||||
);
|
||||
|
||||
const rows = results?.[2]?.rows as PgSearchRawType[];
|
||||
|
||||
// concat same data_id
|
||||
const filterRows: PgSearchRawType[] = [];
|
||||
let set = new Set<string>();
|
||||
for (const row of rows) {
|
||||
if (!set.has(row.data_id)) {
|
||||
filterRows.push(row);
|
||||
set.add(row.data_id);
|
||||
}
|
||||
}
|
||||
|
||||
// get q and a
|
||||
const [collections, dataList] = await Promise.all([
|
||||
MongoDatasetCollection.find(
|
||||
{
|
||||
_id: { $in: filterRows.map((item) => item.collection_id) }
|
||||
},
|
||||
'name fileId rawLink'
|
||||
).lean(),
|
||||
MongoDatasetData.find(
|
||||
{
|
||||
_id: { $in: filterRows.map((item) => item.data_id?.trim()) }
|
||||
},
|
||||
'datasetId collectionId q a chunkIndex indexes'
|
||||
).lean()
|
||||
]);
|
||||
const formatResult = filterRows
|
||||
.map((item) => {
|
||||
const collection = collections.find(
|
||||
(collection) => String(collection._id) === item.collection_id
|
||||
);
|
||||
const data = dataList.find((data) => String(data._id) === item.data_id);
|
||||
|
||||
// if collection or data UnExist, the relational mongo data already deleted
|
||||
if (!collection || !data) return null;
|
||||
|
||||
return {
|
||||
id: String(data._id),
|
||||
q: data.q,
|
||||
a: data.a,
|
||||
chunkIndex: data.chunkIndex,
|
||||
indexes: data.indexes,
|
||||
datasetId: String(data.datasetId),
|
||||
collectionId: String(data.collectionId),
|
||||
sourceName: collection.name || '',
|
||||
sourceId: collection?.fileId || collection?.rawLink,
|
||||
score: item.score
|
||||
};
|
||||
})
|
||||
.filter((item) => item !== null) as SearchDataResponseItemType[];
|
||||
|
||||
return {
|
||||
embeddingRecallResults: formatResult,
|
||||
tokenLen
|
||||
};
|
||||
};
|
||||
const fullTextRecall = async ({
|
||||
query,
|
||||
limit
|
||||
}: {
|
||||
query: string;
|
||||
limit: number;
|
||||
}): Promise<{
|
||||
fullTextRecallResults: SearchDataResponseItemType[];
|
||||
tokenLen: number;
|
||||
}> => {
|
||||
if (limit === 0) {
|
||||
return {
|
||||
fullTextRecallResults: [],
|
||||
tokenLen: 0
|
||||
};
|
||||
}
|
||||
|
||||
let searchResults = (
|
||||
await Promise.all(
|
||||
datasetIds.map((id) =>
|
||||
MongoDatasetData.find(
|
||||
{
|
||||
datasetId: id,
|
||||
$text: { $search: jiebaSplit({ text: query }) }
|
||||
},
|
||||
{
|
||||
score: { $meta: 'textScore' },
|
||||
_id: 1,
|
||||
datasetId: 1,
|
||||
collectionId: 1,
|
||||
q: 1,
|
||||
a: 1,
|
||||
indexes: 1,
|
||||
chunkIndex: 1
|
||||
}
|
||||
)
|
||||
.sort({ score: { $meta: 'textScore' } })
|
||||
.limit(limit)
|
||||
.lean()
|
||||
)
|
||||
)
|
||||
).flat() as (DatasetDataSchemaType & { score: number })[];
|
||||
|
||||
// resort
|
||||
searchResults.sort((a, b) => b.score - a.score);
|
||||
searchResults.slice(0, limit);
|
||||
|
||||
const collections = await MongoDatasetCollection.find(
|
||||
{
|
||||
_id: { $in: searchResults.map((item) => item.collectionId) }
|
||||
},
|
||||
'_id name fileId rawLink'
|
||||
);
|
||||
|
||||
return {
|
||||
fullTextRecallResults: searchResults.map((item) => {
|
||||
const collection = collections.find((col) => String(col._id) === String(item.collectionId));
|
||||
return {
|
||||
id: String(item._id),
|
||||
datasetId: String(item.datasetId),
|
||||
collectionId: String(item.collectionId),
|
||||
sourceName: collection?.name || '',
|
||||
sourceId: collection?.fileId || collection?.rawLink,
|
||||
q: item.q,
|
||||
a: item.a,
|
||||
chunkIndex: item.chunkIndex,
|
||||
indexes: item.indexes,
|
||||
// @ts-ignore
|
||||
score: item.score
|
||||
};
|
||||
}),
|
||||
tokenLen: 0
|
||||
};
|
||||
};
|
||||
const reRankSearchResult = async ({
|
||||
data,
|
||||
query
|
||||
}: {
|
||||
data: SearchDataResponseItemType[];
|
||||
query: string;
|
||||
}): Promise<SearchDataResponseItemType[]> => {
|
||||
try {
|
||||
const results = await reRankRecall({
|
||||
query,
|
||||
inputs: data.map((item) => ({
|
||||
id: item.id,
|
||||
text: `${item.q}\n${item.a}`
|
||||
}))
|
||||
});
|
||||
|
||||
if (!Array.isArray(results)) return data;
|
||||
|
||||
// add new score to data
|
||||
const mergeResult = results
|
||||
.map((item) => {
|
||||
const target = data.find((dataItem) => dataItem.id === item.id);
|
||||
if (!target) return null;
|
||||
return {
|
||||
...target,
|
||||
score: item.score || target.score
|
||||
};
|
||||
})
|
||||
.filter(Boolean) as SearchDataResponseItemType[];
|
||||
|
||||
return mergeResult;
|
||||
} catch (error) {
|
||||
return data;
|
||||
}
|
||||
};
|
||||
const filterResultsByMaxTokens = (list: SearchDataResponseItemType[], maxTokens: number) => {
|
||||
const results: SearchDataResponseItemType[] = [];
|
||||
let totalTokens = 0;
|
||||
|
||||
for (let i = 0; i < list.length; i++) {
|
||||
const item = list[i];
|
||||
totalTokens += countPromptTokens(item.q + item.a);
|
||||
if (totalTokens > maxTokens + 500) {
|
||||
break;
|
||||
}
|
||||
results.push(item);
|
||||
if (totalTokens > maxTokens) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return results.length === 0 ? list.slice(0, 1) : results;
|
||||
};
|
||||
const multiQueryRecall = async ({
|
||||
embeddingLimit,
|
||||
fullTextLimit
|
||||
}: {
|
||||
embeddingLimit: number;
|
||||
fullTextLimit: number;
|
||||
}) => {
|
||||
// In a group n recall, as long as one of the data appears minAmount of times, it is retained
|
||||
const getIntersection = (resultList: SearchDataResponseItemType[][], minAmount = 1) => {
|
||||
minAmount = Math.min(resultList.length, minAmount);
|
||||
|
||||
const map: Record<
|
||||
string,
|
||||
{
|
||||
amount: number;
|
||||
data: SearchDataResponseItemType;
|
||||
}
|
||||
> = {};
|
||||
|
||||
for (const list of resultList) {
|
||||
for (const item of list) {
|
||||
map[item.id] = map[item.id]
|
||||
? {
|
||||
amount: map[item.id].amount + 1,
|
||||
data: item
|
||||
}
|
||||
: {
|
||||
amount: 1,
|
||||
data: item
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return Object.values(map)
|
||||
.filter((item) => item.amount >= minAmount)
|
||||
.map((item) => item.data);
|
||||
};
|
||||
|
||||
// multi query recall
|
||||
const embeddingRecallResList: SearchDataResponseItemType[][] = [];
|
||||
const fullTextRecallResList: SearchDataResponseItemType[][] = [];
|
||||
let embTokens = 0;
|
||||
for await (const query of queries) {
|
||||
const [{ tokenLen, embeddingRecallResults }, { fullTextRecallResults }] = await Promise.all([
|
||||
embeddingRecall({
|
||||
query,
|
||||
limit: embeddingLimit
|
||||
}),
|
||||
fullTextRecall({
|
||||
query,
|
||||
limit: fullTextLimit
|
||||
})
|
||||
]);
|
||||
embTokens += tokenLen;
|
||||
|
||||
embeddingRecallResList.push(embeddingRecallResults);
|
||||
fullTextRecallResList.push(fullTextRecallResults);
|
||||
}
|
||||
|
||||
return {
|
||||
tokens: embTokens,
|
||||
embeddingRecallResults: getIntersection(embeddingRecallResList, 2),
|
||||
fullTextRecallResults: getIntersection(fullTextRecallResList, 2)
|
||||
};
|
||||
};
|
||||
|
||||
/* main step */
|
||||
// count limit
|
||||
const { embeddingLimit, fullTextLimit } = countRecallLimit();
|
||||
|
||||
// recall
|
||||
const { embeddingRecallResults, fullTextRecallResults, tokens } = await multiQueryRecall({
|
||||
embeddingLimit,
|
||||
fullTextLimit
|
||||
});
|
||||
|
||||
// concat recall results
|
||||
set = new Set<string>(embeddingRecallResults.map((item) => item.id));
|
||||
const concatRecallResults = embeddingRecallResults.concat(
|
||||
fullTextRecallResults.filter((item) => !set.has(item.id))
|
||||
);
|
||||
|
||||
// remove same q and a data
|
||||
set = new Set<string>();
|
||||
const filterSameDataResults = concatRecallResults.filter((item) => {
|
||||
// 删除所有的标点符号与空格等,只对文本进行比较
|
||||
const str = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
|
||||
if (set.has(str)) return false;
|
||||
set.add(str);
|
||||
return true;
|
||||
});
|
||||
|
||||
if (!rerank) {
|
||||
return {
|
||||
searchRes: filterResultsByMaxTokens(
|
||||
filterSameDataResults.filter((item) => item.score >= similarity),
|
||||
maxTokens
|
||||
),
|
||||
tokenLen: tokens
|
||||
};
|
||||
}
|
||||
|
||||
// ReRank results
|
||||
const reRankResults = (
|
||||
await reRankSearchResult({
|
||||
query: rawQuery,
|
||||
data: filterSameDataResults
|
||||
})
|
||||
).filter((item) => item.score > similarity);
|
||||
|
||||
return {
|
||||
searchRes: filterResultsByMaxTokens(
|
||||
reRankResults.filter((item) => item.score >= similarity),
|
||||
maxTokens
|
||||
),
|
||||
tokenLen: tokens
|
||||
};
|
||||
}
|
||||
// ------------------ search end ------------------
|
@@ -1,6 +1,4 @@
|
||||
import { MongoDatasetData } from '@fastgpt/service/core/dataset/data/schema';
|
||||
import { cut } from '@node-rs/jieba';
|
||||
import { stopWords } from '@fastgpt/global/common/string/jieba';
|
||||
|
||||
/**
|
||||
* Same value judgment
|
||||
@@ -24,14 +22,3 @@ export async function hasSameValue({
|
||||
return Promise.reject('已经存在完全一致的数据');
|
||||
}
|
||||
}
|
||||
|
||||
export function jiebaSplit({ text }: { text: string }) {
|
||||
const tokens = cut(text, true);
|
||||
|
||||
return (
|
||||
tokens
|
||||
.map((item) => item.replace(/[^\u4e00-\u9fa5a-zA-Z0-9\s]/g, '').trim())
|
||||
.filter((item) => item && !stopWords.has(item))
|
||||
.join(' ') || ''
|
||||
);
|
||||
}
|
||||
|
@@ -136,7 +136,6 @@ ${replaceVariable(Prompt_AgentQA.fixedText, { text })}`;
|
||||
stream: false
|
||||
});
|
||||
const answer = chatResponse.choices?.[0].message?.content || '';
|
||||
const totalTokens = chatResponse.usage?.total_tokens || 0;
|
||||
|
||||
const qaArr = formatSplitText(answer, text); // 格式化后的QA对
|
||||
|
||||
@@ -167,7 +166,8 @@ ${replaceVariable(Prompt_AgentQA.fixedText, { text })}`;
|
||||
pushQABill({
|
||||
teamId: data.teamId,
|
||||
tmbId: data.tmbId,
|
||||
totalTokens,
|
||||
inputTokens: chatResponse.usage?.prompt_tokens || 0,
|
||||
outputTokens: chatResponse.usage?.completion_tokens || 0,
|
||||
billId: data.billId,
|
||||
model
|
||||
});
|
||||
|
@@ -129,7 +129,7 @@ export async function generateVector(): Promise<any> {
|
||||
}
|
||||
|
||||
// insert data to pg
|
||||
const { tokenLen } = await insertData2Dataset({
|
||||
const { tokens } = await insertData2Dataset({
|
||||
teamId: data.teamId,
|
||||
tmbId: data.tmbId,
|
||||
datasetId: data.datasetId,
|
||||
@@ -145,7 +145,7 @@ export async function generateVector(): Promise<any> {
|
||||
pushGenerateVectorBill({
|
||||
teamId: data.teamId,
|
||||
tmbId: data.tmbId,
|
||||
tokenLen: tokenLen,
|
||||
tokens,
|
||||
model: data.model,
|
||||
billId: data.billId
|
||||
});
|
||||
|
@@ -9,8 +9,9 @@ import type { ModuleDispatchProps } from '@fastgpt/global/core/module/type.d';
|
||||
import { replaceVariable } from '@fastgpt/global/common/string/tools';
|
||||
import { Prompt_CQJson } from '@/global/core/prompt/agent';
|
||||
import { FunctionModelItemType } from '@fastgpt/global/core/ai/model.d';
|
||||
import { getCQModel } from '@/service/core/ai/model';
|
||||
import { ModelTypeEnum, getCQModel } from '@/service/core/ai/model';
|
||||
import { getHistories } from '../utils';
|
||||
import { formatModelPrice2Store } from '@/service/support/wallet/bill/utils';
|
||||
|
||||
type Props = ModuleDispatchProps<{
|
||||
[ModuleInputKeyEnum.aiModel]: string;
|
||||
@@ -42,7 +43,7 @@ export const dispatchClassifyQuestion = async (props: Props): Promise<CQResponse
|
||||
|
||||
const chatHistories = getHistories(history, histories);
|
||||
|
||||
const { arg, tokens } = await (async () => {
|
||||
const { arg, inputTokens, outputTokens } = await (async () => {
|
||||
if (cqModel.toolChoice) {
|
||||
return toolChoice({
|
||||
...props,
|
||||
@@ -59,13 +60,21 @@ export const dispatchClassifyQuestion = async (props: Props): Promise<CQResponse
|
||||
|
||||
const result = agents.find((item) => item.key === arg?.type) || agents[agents.length - 1];
|
||||
|
||||
const { total, modelName } = formatModelPrice2Store({
|
||||
model: cqModel.model,
|
||||
inputLen: inputTokens,
|
||||
outputLen: outputTokens,
|
||||
type: ModelTypeEnum.cq
|
||||
});
|
||||
|
||||
return {
|
||||
[result.key]: result.value,
|
||||
[ModuleOutputKeyEnum.responseData]: {
|
||||
price: user.openaiAccount?.key ? 0 : cqModel.price * tokens,
|
||||
model: cqModel.name || '',
|
||||
price: user.openaiAccount?.key ? 0 : total,
|
||||
model: modelName,
|
||||
query: userChatInput,
|
||||
tokens,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cqList: agents,
|
||||
cqResult: result.value,
|
||||
contextTotalLen: chatHistories.length + 2
|
||||
@@ -140,7 +149,8 @@ ${systemPrompt}
|
||||
|
||||
return {
|
||||
arg,
|
||||
tokens: response.usage?.total_tokens || 0
|
||||
inputTokens: response.usage?.prompt_tokens || 0,
|
||||
outputTokens: response.usage?.completion_tokens || 0
|
||||
};
|
||||
} catch (error) {
|
||||
console.log(agentFunction.parameters);
|
||||
@@ -150,7 +160,8 @@ ${systemPrompt}
|
||||
|
||||
return {
|
||||
arg: {},
|
||||
tokens: 0
|
||||
inputTokens: 0,
|
||||
outputTokens: 0
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -182,12 +193,12 @@ Human:${userChatInput}`
|
||||
stream: false
|
||||
});
|
||||
const answer = data.choices?.[0].message?.content || '';
|
||||
const totalTokens = data.usage?.total_tokens || 0;
|
||||
|
||||
const id = agents.find((item) => answer.includes(item.key))?.key || '';
|
||||
|
||||
return {
|
||||
tokens: totalTokens,
|
||||
inputTokens: data.usage?.prompt_tokens || 0,
|
||||
outputTokens: data.usage?.completion_tokens || 0,
|
||||
arg: { type: id }
|
||||
};
|
||||
}
|
||||
|
@@ -10,7 +10,8 @@ import { Prompt_ExtractJson } from '@/global/core/prompt/agent';
|
||||
import { replaceVariable } from '@fastgpt/global/common/string/tools';
|
||||
import { FunctionModelItemType } from '@fastgpt/global/core/ai/model.d';
|
||||
import { getHistories } from '../utils';
|
||||
import { getExtractModel } from '@/service/core/ai/model';
|
||||
import { ModelTypeEnum, getExtractModel } from '@/service/core/ai/model';
|
||||
import { formatModelPrice2Store } from '@/service/support/wallet/bill/utils';
|
||||
|
||||
type Props = ModuleDispatchProps<{
|
||||
[ModuleInputKeyEnum.history]?: ChatItemType[];
|
||||
@@ -42,7 +43,7 @@ export async function dispatchContentExtract(props: Props): Promise<Response> {
|
||||
const extractModel = getExtractModel(model);
|
||||
const chatHistories = getHistories(history, histories);
|
||||
|
||||
const { arg, tokens } = await (async () => {
|
||||
const { arg, inputTokens, outputTokens } = await (async () => {
|
||||
if (extractModel.toolChoice) {
|
||||
return toolChoice({
|
||||
...props,
|
||||
@@ -79,16 +80,24 @@ export async function dispatchContentExtract(props: Props): Promise<Response> {
|
||||
}
|
||||
}
|
||||
|
||||
const { total, modelName } = formatModelPrice2Store({
|
||||
model: extractModel.model,
|
||||
inputLen: inputTokens,
|
||||
outputLen: outputTokens,
|
||||
type: ModelTypeEnum.extract
|
||||
});
|
||||
|
||||
return {
|
||||
[ModuleOutputKeyEnum.success]: success ? true : undefined,
|
||||
[ModuleOutputKeyEnum.failed]: success ? undefined : true,
|
||||
[ModuleOutputKeyEnum.contextExtractFields]: JSON.stringify(arg),
|
||||
...arg,
|
||||
[ModuleOutputKeyEnum.responseData]: {
|
||||
price: user.openaiAccount?.key ? 0 : extractModel.price * tokens,
|
||||
model: extractModel.name || '',
|
||||
price: user.openaiAccount?.key ? 0 : total,
|
||||
model: modelName,
|
||||
query: content,
|
||||
tokens,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
extractDescription: description,
|
||||
extractResult: arg,
|
||||
contextTotalLen: chatHistories.length + 2
|
||||
@@ -181,10 +190,10 @@ ${description || '根据用户要求获取适当的 JSON 字符串。'}
|
||||
}
|
||||
})();
|
||||
|
||||
const tokens = response.usage?.total_tokens || 0;
|
||||
return {
|
||||
rawResponse: response?.choices?.[0]?.message?.tool_calls?.[0]?.function?.arguments || '',
|
||||
tokens,
|
||||
inputTokens: response.usage?.prompt_tokens || 0,
|
||||
outputTokens: response.usage?.completion_tokens || 0,
|
||||
arg
|
||||
};
|
||||
}
|
||||
@@ -223,7 +232,8 @@ Human: ${content}`
|
||||
stream: false
|
||||
});
|
||||
const answer = data.choices?.[0].message?.content || '';
|
||||
const totalTokens = data.usage?.total_tokens || 0;
|
||||
const inputTokens = data.usage?.prompt_tokens || 0;
|
||||
const outputTokens = data.usage?.completion_tokens || 0;
|
||||
|
||||
// parse response
|
||||
const start = answer.indexOf('{');
|
||||
@@ -232,7 +242,8 @@ Human: ${content}`
|
||||
if (start === -1 || end === -1)
|
||||
return {
|
||||
rawResponse: answer,
|
||||
tokens: totalTokens,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
arg: {}
|
||||
};
|
||||
|
||||
@@ -244,13 +255,15 @@ Human: ${content}`
|
||||
try {
|
||||
return {
|
||||
rawResponse: answer,
|
||||
tokens: totalTokens,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
arg: JSON.parse(jsonStr) as Record<string, any>
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
rawResponse: answer,
|
||||
tokens: totalTokens,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
arg: {}
|
||||
};
|
||||
}
|
||||
|
@@ -6,7 +6,7 @@ import { sseResponseEventEnum } from '@fastgpt/service/common/response/constant'
|
||||
import { textAdaptGptResponse } from '@/utils/adapt';
|
||||
import { getAIApi } from '@fastgpt/service/core/ai/config';
|
||||
import type { ChatCompletion, StreamChatType } from '@fastgpt/global/core/ai/type.d';
|
||||
import { countModelPrice } from '@/service/support/wallet/bill/utils';
|
||||
import { formatModelPrice2Store } from '@/service/support/wallet/bill/utils';
|
||||
import type { ChatModelItemType } from '@fastgpt/global/core/ai/model.d';
|
||||
import { postTextCensor } from '@/service/common/censor';
|
||||
import { ChatCompletionRequestMessageRoleEnum } from '@fastgpt/global/core/ai/constant';
|
||||
@@ -151,7 +151,7 @@ export const dispatchChatCompletion = async (props: ChatProps): Promise<ChatResp
|
||||
}
|
||||
);
|
||||
|
||||
const { answerText, totalTokens, completeMessages } = await (async () => {
|
||||
const { answerText, inputTokens, outputTokens, completeMessages } = await (async () => {
|
||||
if (stream) {
|
||||
// sse response
|
||||
const { answer } = await streamResponse({
|
||||
@@ -165,21 +165,26 @@ export const dispatchChatCompletion = async (props: ChatProps): Promise<ChatResp
|
||||
value: answer
|
||||
});
|
||||
|
||||
const totalTokens = countMessagesTokens({
|
||||
messages: completeMessages
|
||||
});
|
||||
|
||||
targetResponse({ res, detail, outputs });
|
||||
|
||||
return {
|
||||
answerText: answer,
|
||||
totalTokens,
|
||||
inputTokens: countMessagesTokens({
|
||||
messages: filterMessages
|
||||
}),
|
||||
outputTokens: countMessagesTokens({
|
||||
messages: [
|
||||
{
|
||||
obj: ChatRoleEnum.AI,
|
||||
value: answer
|
||||
}
|
||||
]
|
||||
}),
|
||||
completeMessages
|
||||
};
|
||||
} else {
|
||||
const unStreamResponse = response as ChatCompletion;
|
||||
const answer = unStreamResponse.choices?.[0]?.message?.content || '';
|
||||
const totalTokens = unStreamResponse.usage?.total_tokens || 0;
|
||||
|
||||
const completeMessages = filterMessages.concat({
|
||||
obj: ChatRoleEnum.AI,
|
||||
@@ -188,20 +193,27 @@ export const dispatchChatCompletion = async (props: ChatProps): Promise<ChatResp
|
||||
|
||||
return {
|
||||
answerText: answer,
|
||||
totalTokens,
|
||||
inputTokens: unStreamResponse.usage?.prompt_tokens || 0,
|
||||
outputTokens: unStreamResponse.usage?.completion_tokens || 0,
|
||||
completeMessages
|
||||
};
|
||||
}
|
||||
})();
|
||||
|
||||
const { total, modelName } = formatModelPrice2Store({
|
||||
model,
|
||||
inputLen: inputTokens,
|
||||
outputLen: outputTokens,
|
||||
type: ModelTypeEnum.chat
|
||||
});
|
||||
|
||||
return {
|
||||
answerText,
|
||||
responseData: {
|
||||
price: user.openaiAccount?.key
|
||||
? 0
|
||||
: countModelPrice({ model, tokens: totalTokens, type: ModelTypeEnum.chat }),
|
||||
model: modelConstantsData.name,
|
||||
tokens: totalTokens,
|
||||
price: user.openaiAccount?.key ? 0 : total,
|
||||
model: modelName,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
query: userChatInput,
|
||||
maxToken: max_tokens,
|
||||
quoteList: filterQuoteQA,
|
||||
@@ -227,8 +239,7 @@ function filterQuote({
|
||||
a: item.a,
|
||||
source: item.sourceName,
|
||||
sourceId: String(item.sourceId || 'UnKnow'),
|
||||
index: index + 1,
|
||||
score: item.score?.toFixed(4)
|
||||
index: index + 1
|
||||
});
|
||||
}
|
||||
|
||||
|
@@ -1,13 +1,12 @@
|
||||
import type { moduleDispatchResType } from '@fastgpt/global/core/chat/type.d';
|
||||
import { countModelPrice } from '@/service/support/wallet/bill/utils';
|
||||
import { formatModelPrice2Store } from '@/service/support/wallet/bill/utils';
|
||||
import type { SelectedDatasetType } from '@fastgpt/global/core/module/api.d';
|
||||
import type { SearchDataResponseItemType } from '@fastgpt/global/core/dataset/type';
|
||||
import type { ModuleDispatchProps } from '@fastgpt/global/core/module/type.d';
|
||||
import { ModelTypeEnum } from '@/service/core/ai/model';
|
||||
import { searchDatasetData } from '@/service/core/dataset/data/pg';
|
||||
import { searchDatasetData } from '@/service/core/dataset/data/controller';
|
||||
import { ModuleInputKeyEnum, ModuleOutputKeyEnum } from '@fastgpt/global/core/module/constants';
|
||||
import { DatasetSearchModeEnum } from '@fastgpt/global/core/dataset/constant';
|
||||
import { searchQueryExtension } from '@fastgpt/service/core/ai/functions/queryExtension';
|
||||
|
||||
type DatasetSearchProps = ModuleDispatchProps<{
|
||||
[ModuleInputKeyEnum.datasetSelectList]: SelectedDatasetType;
|
||||
@@ -15,6 +14,7 @@ type DatasetSearchProps = ModuleDispatchProps<{
|
||||
[ModuleInputKeyEnum.datasetLimit]: number;
|
||||
[ModuleInputKeyEnum.datasetSearchMode]: `${DatasetSearchModeEnum}`;
|
||||
[ModuleInputKeyEnum.userChatInput]: string;
|
||||
[ModuleInputKeyEnum.datasetSearchUsingReRank]: boolean;
|
||||
}>;
|
||||
export type DatasetSearchResponse = {
|
||||
[ModuleOutputKeyEnum.responseData]: moduleDispatchResType;
|
||||
@@ -27,7 +27,7 @@ export async function dispatchDatasetSearch(
|
||||
props: DatasetSearchProps
|
||||
): Promise<DatasetSearchResponse> {
|
||||
const {
|
||||
inputs: { datasets = [], similarity = 0.4, limit = 5, searchMode, userChatInput }
|
||||
inputs: { datasets = [], similarity, limit = 1500, usingReRank, searchMode, userChatInput }
|
||||
} = props as DatasetSearchProps;
|
||||
|
||||
if (!Array.isArray(datasets)) {
|
||||
@@ -52,14 +52,21 @@ export async function dispatchDatasetSearch(
|
||||
const concatQueries = [userChatInput];
|
||||
|
||||
// start search
|
||||
const { searchRes, tokenLen } = await searchDatasetData({
|
||||
const { searchRes, tokens, usingSimilarityFilter } = await searchDatasetData({
|
||||
rawQuery: userChatInput,
|
||||
queries: concatQueries,
|
||||
model: vectorModel.model,
|
||||
similarity,
|
||||
limit,
|
||||
datasetIds: datasets.map((item) => item.datasetId),
|
||||
searchMode
|
||||
searchMode,
|
||||
usingReRank
|
||||
});
|
||||
|
||||
const { total, modelName } = formatModelPrice2Store({
|
||||
model: vectorModel.model,
|
||||
inputLen: tokens,
|
||||
type: ModelTypeEnum.vector
|
||||
});
|
||||
|
||||
return {
|
||||
@@ -67,17 +74,14 @@ export async function dispatchDatasetSearch(
|
||||
unEmpty: searchRes.length > 0 ? true : undefined,
|
||||
quoteQA: searchRes,
|
||||
responseData: {
|
||||
price: countModelPrice({
|
||||
model: vectorModel.model,
|
||||
tokens: tokenLen,
|
||||
type: ModelTypeEnum.vector
|
||||
}),
|
||||
price: total,
|
||||
query: concatQueries.join('\n'),
|
||||
model: vectorModel.name,
|
||||
tokens: tokenLen,
|
||||
similarity,
|
||||
model: modelName,
|
||||
inputTokens: tokens,
|
||||
similarity: usingSimilarityFilter ? similarity : undefined,
|
||||
limit,
|
||||
searchMode
|
||||
searchMode,
|
||||
searchUsingReRank: usingReRank
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@@ -4,7 +4,8 @@ import { ModuleInputKeyEnum, ModuleOutputKeyEnum } from '@fastgpt/global/core/mo
|
||||
import { getHistories } from '../utils';
|
||||
import { getAIApi } from '@fastgpt/service/core/ai/config';
|
||||
import { replaceVariable } from '@fastgpt/global/common/string/tools';
|
||||
import { getExtractModel } from '@/service/core/ai/model';
|
||||
import { ModelTypeEnum, getExtractModel } from '@/service/core/ai/model';
|
||||
import { formatModelPrice2Store } from '@/service/support/wallet/bill/utils';
|
||||
|
||||
type Props = ModuleDispatchProps<{
|
||||
[ModuleInputKeyEnum.aiModel]: string;
|
||||
@@ -75,13 +76,22 @@ A: ${systemPrompt}
|
||||
// );
|
||||
// console.log(answer);
|
||||
|
||||
const tokens = result.usage?.total_tokens || 0;
|
||||
const inputTokens = result.usage?.prompt_tokens || 0;
|
||||
const outputTokens = result.usage?.completion_tokens || 0;
|
||||
|
||||
const { total, modelName } = formatModelPrice2Store({
|
||||
model: extractModel.model,
|
||||
inputLen: inputTokens,
|
||||
outputLen: outputTokens,
|
||||
type: ModelTypeEnum.extract
|
||||
});
|
||||
|
||||
return {
|
||||
[ModuleOutputKeyEnum.responseData]: {
|
||||
price: extractModel.price * tokens,
|
||||
model: extractModel.name || '',
|
||||
tokens,
|
||||
price: total,
|
||||
model: modelName,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
query: userChatInput,
|
||||
textOutput: answer
|
||||
},
|
||||
|
@@ -1,11 +1,11 @@
|
||||
import { startQueue } from './utils/tools';
|
||||
import { PRICE_SCALE } from '@fastgpt/global/support/wallet/bill/constants';
|
||||
import { initPg } from '@fastgpt/service/common/pg';
|
||||
import { MongoUser } from '@fastgpt/service/support/user/schema';
|
||||
import { connectMongo } from '@fastgpt/service/common/mongo/init';
|
||||
import { hashStr } from '@fastgpt/global/common/string/tools';
|
||||
import { createDefaultTeam } from '@fastgpt/service/support/user/team/controller';
|
||||
import { exit } from 'process';
|
||||
import { initVectorStore } from '@fastgpt/service/common/vectorStore/controller';
|
||||
|
||||
/**
|
||||
* connect MongoDB and init data
|
||||
@@ -14,7 +14,7 @@ export function connectToDatabase(): Promise<void> {
|
||||
return connectMongo({
|
||||
beforeHook: () => {},
|
||||
afterHook: () => {
|
||||
initPg();
|
||||
initVectorStore();
|
||||
// start queue
|
||||
startQueue();
|
||||
return initRootUser();
|
||||
|
22
projects/app/src/service/support/wallet/bill/controller.ts
Normal file
22
projects/app/src/service/support/wallet/bill/controller.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import { ConcatBillProps, CreateBillProps } from '@fastgpt/global/support/wallet/bill/api';
|
||||
import { addLog } from '@fastgpt/service/common/system/log';
|
||||
import { POST } from '@fastgpt/service/common/api/plusRequest';
|
||||
|
||||
export function createBill(data: CreateBillProps) {
|
||||
if (!global.systemEnv?.pluginBaseUrl) return;
|
||||
if (data.total === 0) {
|
||||
addLog.info('0 Bill', data);
|
||||
}
|
||||
try {
|
||||
POST('/support/wallet/bill/createBill', data);
|
||||
} catch (error) {}
|
||||
}
|
||||
export function concatBill(data: ConcatBillProps) {
|
||||
if (!global.systemEnv?.pluginBaseUrl) return;
|
||||
if (data.total === 0) {
|
||||
addLog.info('0 Bill', data);
|
||||
}
|
||||
try {
|
||||
POST('/support/wallet/bill/concatBill', data);
|
||||
} catch (error) {}
|
||||
}
|
@@ -1,30 +1,11 @@
|
||||
import { BillSourceEnum, PRICE_SCALE } from '@fastgpt/global/support/wallet/bill/constants';
|
||||
import { getAudioSpeechModel, getQAModel, getVectorModel } from '@/service/core/ai/model';
|
||||
import { BillSourceEnum } from '@fastgpt/global/support/wallet/bill/constants';
|
||||
import { ModelTypeEnum } from '@/service/core/ai/model';
|
||||
import type { ChatHistoryItemResType } from '@fastgpt/global/core/chat/type.d';
|
||||
import { formatPrice } from '@fastgpt/global/support/wallet/bill/tools';
|
||||
import { formatStorePrice2Read } from '@fastgpt/global/support/wallet/bill/tools';
|
||||
import { addLog } from '@fastgpt/service/common/system/log';
|
||||
import type { ConcatBillProps, CreateBillProps } from '@fastgpt/global/support/wallet/bill/api.d';
|
||||
import { POST } from '@fastgpt/service/common/api/plusRequest';
|
||||
import { PostReRankProps } from '@fastgpt/global/core/ai/api';
|
||||
|
||||
export function createBill(data: CreateBillProps) {
|
||||
if (!global.systemEnv?.pluginBaseUrl) return;
|
||||
if (data.total === 0) {
|
||||
addLog.info('0 Bill', data);
|
||||
}
|
||||
try {
|
||||
POST('/support/wallet/bill/createBill', data);
|
||||
} catch (error) {}
|
||||
}
|
||||
export function concatBill(data: ConcatBillProps) {
|
||||
if (!global.systemEnv?.pluginBaseUrl) return;
|
||||
if (data.total === 0) {
|
||||
addLog.info('0 Bill', data);
|
||||
}
|
||||
try {
|
||||
POST('/support/wallet/bill/concatBill', data);
|
||||
} catch (error) {}
|
||||
}
|
||||
import { createBill, concatBill } from './controller';
|
||||
import { formatModelPrice2Store } from '@/service/support/wallet/bill/utils';
|
||||
|
||||
export const pushChatBill = ({
|
||||
appName,
|
||||
@@ -54,14 +35,15 @@ export const pushChatBill = ({
|
||||
moduleName: item.moduleName,
|
||||
amount: item.price || 0,
|
||||
model: item.model,
|
||||
tokenLen: item.tokens
|
||||
inputTokens: item.inputTokens,
|
||||
outputTokens: item.outputTokens
|
||||
}))
|
||||
});
|
||||
addLog.info(`finish completions`, {
|
||||
source,
|
||||
teamId,
|
||||
tmbId,
|
||||
price: formatPrice(total)
|
||||
price: formatStorePrice2Read(total)
|
||||
});
|
||||
return { total };
|
||||
};
|
||||
@@ -70,26 +52,32 @@ export const pushQABill = async ({
|
||||
teamId,
|
||||
tmbId,
|
||||
model,
|
||||
totalTokens,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
billId
|
||||
}: {
|
||||
teamId: string;
|
||||
tmbId: string;
|
||||
model: string;
|
||||
totalTokens: number;
|
||||
inputTokens: number;
|
||||
outputTokens: number;
|
||||
billId: string;
|
||||
}) => {
|
||||
// 获取模型单价格
|
||||
const unitPrice = getQAModel(model).price;
|
||||
// 计算价格
|
||||
const total = unitPrice * totalTokens;
|
||||
const { total } = formatModelPrice2Store({
|
||||
model,
|
||||
inputLen: inputTokens,
|
||||
outputLen: outputTokens,
|
||||
type: ModelTypeEnum.qa
|
||||
});
|
||||
|
||||
concatBill({
|
||||
billId,
|
||||
teamId,
|
||||
tmbId,
|
||||
total,
|
||||
tokens: totalTokens,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
listIndex: 1
|
||||
});
|
||||
|
||||
@@ -100,22 +88,24 @@ export const pushGenerateVectorBill = ({
|
||||
billId,
|
||||
teamId,
|
||||
tmbId,
|
||||
tokenLen,
|
||||
tokens,
|
||||
model,
|
||||
source = BillSourceEnum.fastgpt
|
||||
}: {
|
||||
billId?: string;
|
||||
teamId: string;
|
||||
tmbId: string;
|
||||
tokenLen: number;
|
||||
tokens: number;
|
||||
model: string;
|
||||
source?: `${BillSourceEnum}`;
|
||||
}) => {
|
||||
// 计算价格. 至少为1
|
||||
const vectorModel = getVectorModel(model);
|
||||
const unitPrice = vectorModel.price || 0.2;
|
||||
let total = unitPrice * tokenLen;
|
||||
total = total > 1 ? total : 1;
|
||||
let { total, modelName } = formatModelPrice2Store({
|
||||
model,
|
||||
inputLen: tokens,
|
||||
type: ModelTypeEnum.vector
|
||||
});
|
||||
|
||||
total = total < 1 ? 1 : total;
|
||||
|
||||
// 插入 Bill 记录
|
||||
if (billId) {
|
||||
@@ -124,22 +114,22 @@ export const pushGenerateVectorBill = ({
|
||||
tmbId,
|
||||
total,
|
||||
billId,
|
||||
tokens: tokenLen,
|
||||
inputTokens: tokens,
|
||||
listIndex: 0
|
||||
});
|
||||
} else {
|
||||
createBill({
|
||||
teamId,
|
||||
tmbId,
|
||||
appName: '索引生成',
|
||||
appName: 'wallet.moduleName.index',
|
||||
total,
|
||||
source,
|
||||
list: [
|
||||
{
|
||||
moduleName: '索引生成',
|
||||
moduleName: 'wallet.moduleName.index',
|
||||
amount: total,
|
||||
model: vectorModel.name,
|
||||
tokenLen
|
||||
model: modelName,
|
||||
inputTokens: tokens
|
||||
}
|
||||
]
|
||||
});
|
||||
@@ -148,28 +138,37 @@ export const pushGenerateVectorBill = ({
|
||||
};
|
||||
|
||||
export const pushQuestionGuideBill = ({
|
||||
tokens,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
teamId,
|
||||
tmbId
|
||||
}: {
|
||||
tokens: number;
|
||||
inputTokens: number;
|
||||
outputTokens: number;
|
||||
teamId: string;
|
||||
tmbId: string;
|
||||
}) => {
|
||||
const qgModel = global.qgModels[0];
|
||||
const total = qgModel.price * tokens;
|
||||
const { total, modelName } = formatModelPrice2Store({
|
||||
inputLen: inputTokens,
|
||||
outputLen: outputTokens,
|
||||
model: qgModel.model,
|
||||
type: ModelTypeEnum.qg
|
||||
});
|
||||
|
||||
createBill({
|
||||
teamId,
|
||||
tmbId,
|
||||
appName: '下一步指引',
|
||||
appName: 'wallet.bill.Next Step Guide',
|
||||
total,
|
||||
source: BillSourceEnum.fastgpt,
|
||||
list: [
|
||||
{
|
||||
moduleName: '下一步指引',
|
||||
moduleName: 'wallet.bill.Next Step Guide',
|
||||
amount: total,
|
||||
model: qgModel.name,
|
||||
tokenLen: tokens
|
||||
model: modelName,
|
||||
inputTokens,
|
||||
outputTokens
|
||||
}
|
||||
]
|
||||
});
|
||||
@@ -178,20 +177,24 @@ export const pushQuestionGuideBill = ({
|
||||
export function pushAudioSpeechBill({
|
||||
appName = 'wallet.bill.Audio Speech',
|
||||
model,
|
||||
textLength,
|
||||
textLen,
|
||||
teamId,
|
||||
tmbId,
|
||||
source = BillSourceEnum.fastgpt
|
||||
}: {
|
||||
appName?: string;
|
||||
model: string;
|
||||
textLength: number;
|
||||
textLen: number;
|
||||
teamId: string;
|
||||
tmbId: string;
|
||||
source: `${BillSourceEnum}`;
|
||||
}) {
|
||||
const modelData = getAudioSpeechModel(model);
|
||||
const total = modelData.price * textLength;
|
||||
const { total, modelName } = formatModelPrice2Store({
|
||||
model,
|
||||
inputLen: textLen,
|
||||
type: ModelTypeEnum.audioSpeech
|
||||
});
|
||||
|
||||
createBill({
|
||||
teamId,
|
||||
tmbId,
|
||||
@@ -202,8 +205,8 @@ export function pushAudioSpeechBill({
|
||||
{
|
||||
moduleName: appName,
|
||||
amount: total,
|
||||
model: modelData.name,
|
||||
tokenLen: textLength
|
||||
model: modelName,
|
||||
textLen
|
||||
}
|
||||
]
|
||||
});
|
||||
@@ -218,11 +221,16 @@ export function pushWhisperBill({
|
||||
tmbId: string;
|
||||
duration: number;
|
||||
}) {
|
||||
const modelData = global.whisperModel;
|
||||
const whisperModel = global.whisperModel;
|
||||
|
||||
if (!modelData) return;
|
||||
if (!whisperModel) return;
|
||||
|
||||
const total = ((modelData.price * duration) / 60) * PRICE_SCALE;
|
||||
const { total, modelName } = formatModelPrice2Store({
|
||||
model: whisperModel.model,
|
||||
inputLen: duration,
|
||||
type: ModelTypeEnum.whisper,
|
||||
multiple: 60
|
||||
});
|
||||
|
||||
const name = 'wallet.bill.Whisper';
|
||||
|
||||
@@ -236,8 +244,8 @@ export function pushWhisperBill({
|
||||
{
|
||||
moduleName: name,
|
||||
amount: total,
|
||||
model: modelData.name,
|
||||
tokenLen: duration
|
||||
model: modelName,
|
||||
duration
|
||||
}
|
||||
]
|
||||
});
|
||||
@@ -254,13 +262,16 @@ export function pushReRankBill({
|
||||
source: `${BillSourceEnum}`;
|
||||
inputs: PostReRankProps['inputs'];
|
||||
}) {
|
||||
const model = global.reRankModels[0];
|
||||
if (!model) return { total: 0 };
|
||||
const reRankModel = global.reRankModels[0];
|
||||
if (!reRankModel) return { total: 0 };
|
||||
|
||||
const textLength = inputs.reduce((sum, item) => sum + item.text.length, 0);
|
||||
const ratio = textLength / 1000;
|
||||
const textLen = inputs.reduce((sum, item) => sum + item.text.length, 0);
|
||||
|
||||
const total = Math.ceil(model.price * PRICE_SCALE * ratio);
|
||||
const { total, modelName } = formatModelPrice2Store({
|
||||
model: reRankModel.model,
|
||||
inputLen: textLen,
|
||||
type: ModelTypeEnum.rerank
|
||||
});
|
||||
const name = 'wallet.bill.ReRank';
|
||||
|
||||
createBill({
|
||||
@@ -273,8 +284,8 @@ export function pushReRankBill({
|
||||
{
|
||||
moduleName: name,
|
||||
amount: total,
|
||||
model: model.name,
|
||||
tokenLen: textLength
|
||||
model: modelName,
|
||||
textLen
|
||||
}
|
||||
]
|
||||
});
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import { ModelTypeEnum, getModelMap } from '@/service/core/ai/model';
|
||||
import { AuthUserTypeEnum } from '@fastgpt/global/support/permission/constant';
|
||||
import { BillSourceEnum } from '@fastgpt/global/support/wallet/bill/constants';
|
||||
import { BillSourceEnum, PRICE_SCALE } from '@fastgpt/global/support/wallet/bill/constants';
|
||||
|
||||
export function authType2BillSource({
|
||||
authType,
|
||||
@@ -17,16 +17,38 @@ export function authType2BillSource({
|
||||
return BillSourceEnum.fastgpt;
|
||||
}
|
||||
|
||||
export const countModelPrice = ({
|
||||
export const formatModelPrice2Store = ({
|
||||
model,
|
||||
tokens,
|
||||
type
|
||||
inputLen = 0,
|
||||
outputLen = 0,
|
||||
type,
|
||||
multiple = 1000
|
||||
}: {
|
||||
model: string;
|
||||
tokens: number;
|
||||
inputLen: number;
|
||||
outputLen?: number;
|
||||
type: `${ModelTypeEnum}`;
|
||||
multiple?: number;
|
||||
}) => {
|
||||
const modelData = getModelMap?.[type]?.(model);
|
||||
if (!modelData) return 0;
|
||||
return modelData.price * tokens;
|
||||
if (!modelData)
|
||||
return {
|
||||
inputTotal: 0,
|
||||
outputTotal: 0,
|
||||
total: 0,
|
||||
modelName: ''
|
||||
};
|
||||
const inputTotal = modelData.inputPrice
|
||||
? Math.ceil(modelData.inputPrice * (inputLen / multiple) * PRICE_SCALE)
|
||||
: 0;
|
||||
const outputTotal = modelData.outputPrice
|
||||
? Math.ceil(modelData.outputPrice * (outputLen / multiple) * PRICE_SCALE)
|
||||
: 0;
|
||||
|
||||
return {
|
||||
modelName: modelData.name,
|
||||
inputTotal: inputTotal,
|
||||
outputTotal: outputTotal,
|
||||
total: inputTotal + outputTotal
|
||||
};
|
||||
};
|
||||
|
Reference in New Issue
Block a user