This commit is contained in:
Archer
2023-11-09 09:46:57 +08:00
committed by GitHub
parent 661ee79943
commit 8bb5588305
402 changed files with 9899 additions and 5967 deletions

View File

@@ -1,15 +1,17 @@
import {
defaultAudioSpeechModels,
defaultChatModels,
defaultCQModels,
defaultExtractModels,
defaultQAModels,
defaultQGModels,
defaultVectorModels
} from '@/constants/model';
} from '@fastgpt/global/core/ai/model';
// Resolve a chat model config by its id; fall back to the first configured
// model, then to the built-in default list.
export const getChatModel = (model?: string) => {
  const available = global.chatModels || defaultChatModels;
  const matched = available.find((item) => item.model === model);
  return matched || global.chatModels?.[0] || defaultChatModels[0];
};
@@ -50,6 +52,14 @@ export const getVectorModel = (model?: string) => {
);
};
// Resolve an audio-speech model config by id, falling back to the first
// configured model, then to the built-in defaults.
// Fix: guard against `global.audioSpeechModels` being undefined — the sibling
// getters (e.g. getChatModel) use `(global.x || defaultX).find`, but this one
// called `.find` on the bare global and crashed before models were loaded.
export function getAudioSpeechModel(model?: string) {
  return (
    (global.audioSpeechModels || defaultAudioSpeechModels).find((item) => item.model === model) ||
    global.audioSpeechModels?.[0] ||
    defaultAudioSpeechModels[0]
  );
}
export enum ModelTypeEnum {
chat = 'chat',
qa = 'qa',

View File

@@ -0,0 +1,70 @@
import { getAIApi } from '@fastgpt/service/core/ai/config';
// Arguments for getVectorsByText.
export type GetVectorProps = {
// embedding model id (defaults to 'text-embedding-ada-002' at the call site)
model: string;
// a single text or a batch of texts to embed
input: string | string[];
};
/**
 * Convert text (one string or a batch) into embedding vectors.
 *
 * Fixes over the original:
 * - an empty array input (`[]`) previously slipped past validation (the
 *   element loop ran zero times) and was sent to the API; it is now rejected
 *   with the same 'input array is empty' message used for empty elements.
 * - the `.then` chain mixed with `await` is flattened into plain `await`.
 *
 * Returns { tokenLen, vectors } where each vector is padded to 1536 dims.
 */
// text to vector
export async function getVectorsByText({
  model = 'text-embedding-ada-002',
  input
}: GetVectorProps) {
  try {
    // Validate input: reject an empty string, an empty batch, or any empty element.
    if (typeof input === 'string' && !input) {
      return Promise.reject({
        code: 500,
        message: 'input is empty'
      });
    } else if (Array.isArray(input)) {
      if (input.length === 0 || input.some((text) => !text)) {
        return Promise.reject({
          code: 500,
          message: 'input array is empty'
        });
      }
    }
    // Get the AI client.
    const ai = getAIApi();
    // Convert the input text(s) into embedding vectors.
    const res = await ai.embeddings.create({
      model,
      input
    });
    if (!res.data) {
      return Promise.reject('Embedding API 404');
    }
    if (!res?.data?.[0]?.embedding) {
      console.log(res?.data);
      // Some providers smuggle an error object into the data payload.
      // @ts-ignore
      return Promise.reject(res.data?.err?.message || 'Embedding API Error');
    }
    return {
      tokenLen: res.usage.total_tokens || 0,
      // Pad every vector to the fixed 1536-dim pgvector column width.
      vectors: await Promise.all(res.data.map((item) => unityDimensional(item.embedding)))
    };
  } catch (error) {
    console.log(`Embedding Error`, error);
    return Promise.reject(error);
  }
}
/**
 * Pad a vector with trailing zeros so it is exactly 1536-dimensional
 * (the fixed pgvector column width used by this dataset store).
 * Returns a rejected promise when the vector already exceeds 1536 dims,
 * which surfaces as a rejection through the Promise.all in the caller.
 * Fix: drop the pointless `resultVector` alias of `vector`.
 */
function unityDimensional(vector: number[]) {
  if (vector.length > 1536) return Promise.reject('向量维度不能超过 1536');
  const zeroVector = new Array(1536 - vector.length).fill(0);
  return vector.concat(zeroVector);
}

View File

@@ -0,0 +1,168 @@
import { PgDatasetTableName } from '@fastgpt/global/core/dataset/constant';
import { getVectorsByText } from '@/service/core/ai/vector';
import { PgClient } from '@fastgpt/service/common/pg';
import { delay } from '@/utils/tools';
import {
DatasetDataItemType,
PgDataItemType,
PgRawDataItemType
} from '@fastgpt/global/core/dataset/type';
import { MongoDatasetCollection } from '@fastgpt/service/core/dataset/collection/schema';
// Map a raw Postgres row (snake_case columns) to the camelCase shape the
// application uses. Async to match its awaited call sites.
export async function formatPgRawData(data: PgRawDataItemType) {
  const { id, q, a, team_id, tmb_id, dataset_id, collection_id } = data;
  return {
    id,
    q,
    a,
    teamId: team_id,
    tmbId: tmb_id,
    datasetId: dataset_id,
    collectionId: collection_id
  };
}
/* get one dataset data row from pg by id; rejects when no row matches */
export async function getDatasetPgData({ id }: { id: string }): Promise<PgDataItemType> {
  const { rows } = await PgClient.select<PgRawDataItemType>(PgDatasetTableName, {
    fields: ['id', 'q', 'a', 'team_id', 'tmb_id', 'dataset_id', 'collection_id'],
    where: [['id', id]],
    limit: 1
  });
  if (rows.length === 0) return Promise.reject('Data not found');
  return formatPgRawData(rows[0]);
}
export async function getPgDataWithCollection({
pgDataList
}: {
pgDataList: PgRawDataItemType[];
}): Promise<DatasetDataItemType[]> {
const collections = await MongoDatasetCollection.find(
{
_id: { $in: pgDataList.map((item) => item.collection_id) }
},
'_id name datasetId metadata'
).lean();
return pgDataList.map((item) => {
const collection = collections.find(
(collection) => String(collection._id) === item.collection_id
);
return {
id: item.id,
q: item.q,
a: item.a,
datasetId: collection?.datasetId || '',
collectionId: item.collection_id,
sourceName: collection?.name || '',
sourceId: collection?.metadata?.fileId || collection?.metadata?.rawLink
};
});
}
// Shared args for insert/update of dataset data rows.
type Props = {
// data content; used to build the embedding vector
q: string;
// optional answer/auxiliary content; stored but not embedded
a?: string;
// embedding model id
model: string;
};
/**
 * Update q and/or a of a dataset data row; re-embeds q only when q changed.
 * Returns the vectors used and the token count spent on embedding.
 */
export async function updateData2Dataset({ dataId, q, a = '', model }: Props & { dataId: string }) {
  let vectors: number[][] = [[]];
  let tokenLen = 0;
  // Only recompute the embedding when q is being updated.
  if (q) {
    const res = await getVectorsByText({
      input: [q],
      model
    });
    vectors = res.vectors ?? [];
    tokenLen = res.tokenLen ?? 0;
  }
  // Single quotes are swapped for double quotes as a crude SQL escape.
  const values = [{ key: 'a', value: a.replace(/'/g, '"') }];
  if (q) {
    values.push(
      { key: 'q', value: q.replace(/'/g, '"') },
      { key: 'vector', value: `[${vectors[0]}]` }
    );
  }
  await PgClient.update(PgDatasetTableName, {
    where: [['id', dataId]],
    values
  });
  return {
    vectors,
    tokenLen
  };
}
/* insert one data row (with its embedding) into pg; retries transient failures */
export async function insertData2Dataset({
  teamId,
  tmbId,
  datasetId,
  collectionId,
  q,
  a = '',
  model
}: Props & {
  teamId: string;
  tmbId: string;
  datasetId: string;
  collectionId: string;
}) {
  if (!q || !datasetId || !collectionId || !model) {
    return Promise.reject('q, datasetId, collectionId, model is required');
  }
  const { vectors, tokenLen } = await getVectorsByText({
    model,
    input: [q]
  });
  // Try the insert up to 3 times, pausing 500ms between attempts.
  const maxAttempts = 3;
  let insertId = '';
  for (let attempt = 1; ; attempt++) {
    try {
      const { rows } = await PgClient.insert(PgDatasetTableName, {
        values: [
          [
            { key: 'vector', value: `[${vectors[0]}]` },
            { key: 'team_id', value: String(teamId) },
            { key: 'tmb_id', value: String(tmbId) },
            { key: 'q', value: q },
            { key: 'a', value: a },
            { key: 'dataset_id', value: datasetId },
            { key: 'collection_id', value: collectionId }
          ]
        ]
      });
      insertId = rows[0].id;
      break;
    } catch (error) {
      if (attempt >= maxAttempts) {
        return Promise.reject(error);
      }
      await delay(500);
    }
  }
  return {
    insertId,
    tokenLen,
    vectors
  };
}
/**
* delete data by collectionIds
*/
// NOTE(review): the ids are interpolated directly into the SQL IN-list. This is
// safe only as long as every id is a Mongo ObjectId (no quotes); confirm PgClient
// offers a parameterized form before accepting ids from other sources.
export async function delDataByCollectionId({ collectionIds }: { collectionIds: string[] }) {
const ids = collectionIds.map((item) => String(item));
return PgClient.delete(PgDatasetTableName, {
where: [`collection_id IN ('${ids.join("','")}')`]
});
}

View File

@@ -1,7 +1,11 @@
import { PgDatasetTableName } from '@/constants/plugin';
import { getVector } from '@/pages/api/openapi/plugin/vector';
import { PgClient } from '@/service/pg';
import { delay } from '@/utils/tools';
import { PgDatasetTableName } from '@fastgpt/global/core/dataset/constant';
import {
SearchDataResponseItemType,
SearchDataResultItemType
} from '@fastgpt/global/core/dataset/type';
import { PgClient } from '@fastgpt/service/common/pg';
import { getVectorsByText } from '../../ai/vector';
import { getPgDataWithCollection } from './controller';
/**
* Same value judgment
@@ -27,99 +31,6 @@ export async function hasSameValue({
}
}
// Legacy (pre-refactor) props: rows were keyed by userId instead of team/tmb ids.
type Props = {
userId: string;
q: string;
a?: string;
model: string;
};
// Legacy version removed in this commit: keyed rows by user_id and passed
// billId through to the embedding call for billing. Replaced by the team/tmb
// variant in controller.ts.
export async function insertData2Dataset({
userId,
datasetId,
collectionId,
q,
a = '',
model,
billId
}: Props & {
datasetId: string;
collectionId: string;
billId?: string;
}) {
if (!q || !datasetId || !collectionId || !model) {
return Promise.reject('q, datasetId, collectionId, model is required');
}
const { vectors } = await getVector({
model,
input: [q],
userId,
billId
});
// Retry counter shared with the closure below: 3 attempts in total.
let retry = 2;
async function insertPg(): Promise<string> {
try {
const { rows } = await PgClient.insert(PgDatasetTableName, {
values: [
[
{ key: 'vector', value: `[${vectors[0]}]` },
{ key: 'user_id', value: userId },
{ key: 'q', value: q },
{ key: 'a', value: a },
{ key: 'dataset_id', value: datasetId },
{ key: 'collection_id', value: collectionId }
]
]
});
return rows[0].id;
} catch (error) {
// Out of retries: surface the last error.
if (--retry < 0) {
return Promise.reject(error);
}
await delay(500);
return insertPg();
}
}
return insertPg();
}
/**
* update q or a (legacy version removed in this commit; scoped updates by user_id)
*/
export async function updateData2Dataset({
dataId,
userId,
q,
a = '',
model
}: Props & { dataId: string }) {
// Re-embed only when q changed; otherwise keep a placeholder vector.
const { vectors = [] } = await (async () => {
if (q) {
return getVector({
userId,
input: [q],
model
});
}
return { vectors: [[]] };
})();
await PgClient.update(PgDatasetTableName, {
where: [['id', dataId], 'AND', ['user_id', userId]],
values: [
// single quotes swapped for double quotes as a crude SQL escape
{ key: 'a', value: a.replace(/'/g, '"') },
...(q
? [
{ key: 'q', value: q.replace(/'/g, '"') },
{ key: 'vector', value: `[${vectors[0]}]` }
]
: [])
]
});
}
/**
* count one collection amount of total data
*/
@@ -148,18 +59,46 @@ export async function countCollectionData({
return values;
}
/**
* delete data by collectionIds
*/
// NOTE(review): the diff extraction below INTERLEAVED the removed legacy
// delDataByCollectionId(userId, collectionIds) with the added searchDatasetData
// — lines from both functions appear mixed (e.g. the userId/collectionIds
// params and the PgClient.delete call belong to the removed function). Consult
// the original commit diff before editing this span.
export async function delDataByCollectionId({
userId,
collectionIds
// searchDatasetData: embed `text`, then run an hnsw-assisted inner-product
// search over the given datasets, returning rows above `similarity` joined
// with their Mongo collection info.
export async function searchDatasetData({
text,
model,
similarity = 0,
limit,
datasetIds = []
}: {
userId: string;
collectionIds: string[];
text: string;
model: string;
similarity?: number;
limit: number;
datasetIds: string[];
}) {
const ids = collectionIds.map((item) => String(item));
return PgClient.delete(PgDatasetTableName, {
where: [['user_id', userId], 'AND', `collection_id IN ('${ids.join("','")}')`]
const { vectors, tokenLen } = await getVectorsByText({
model,
input: [text]
});
// NOTE(review): similarity, limit, datasetIds and the vector are interpolated
// straight into the SQL text — safe only if they are guaranteed numeric /
// ObjectId values; verify callers before exposing to user input.
const results: any = await PgClient.query(
`BEGIN;
SET LOCAL hnsw.ef_search = ${global.systemEnv.pgHNSWEfSearch || 100};
select id, q, a, collection_id, (vector <#> '[${
vectors[0]
}]') * -1 AS score from ${PgDatasetTableName} where dataset_id IN (${datasetIds
.map((id) => `'${String(id)}'`)
.join(',')}) AND vector <#> '[${vectors[0]}]' < -${similarity} order by vector <#> '[${
vectors[0]
}]' limit ${limit};
COMMIT;`
);
// NOTE(review): index [2] assumes the driver returns one result set per
// statement (BEGIN / SELECT / COMMIT) — confirm against PgClient.query.
const rows = results?.[2]?.rows as SearchDataResultItemType[];
const collectionsData = await getPgDataWithCollection({ pgDataList: rows });
const searchRes: SearchDataResponseItemType[] = collectionsData.map((item, index) => ({
...item,
score: rows[index].score
}));
return {
searchRes,
tokenLen
};
}