mirror of https://github.com/labring/FastGPT.git
v4.6.9-alpha (#918)
Co-authored-by: Mufei <327958099@qq.com>
Co-authored-by: heheer <71265218+newfish-cmyk@users.noreply.github.com>
@@ -1,6 +1,6 @@
 import { VectorModelItemType } from '@fastgpt/global/core/ai/model.d';
 import { getAIApi } from '../config';
-import { replaceValidChars } from '../../chat/utils';
+import { countPromptTokens } from '@fastgpt/global/common/string/tiktoken';
 
 type GetVectorProps = {
   model: VectorModelItemType;
@@ -37,7 +37,7 @@ export async function getVectorsByText({ model, input }: GetVectorProps) {
     }
 
     return {
-      charsLength: replaceValidChars(input).length,
+      tokens: countPromptTokens(input),
      vectors: await Promise.all(res.data.map((item) => unityDimensional(item.embedding)))
     };
   });
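
Across every hunk in this commit, usage accounting switches from character counts (charsLength via replaceValidChars) to token counts from the shared tiktoken helpers. Below is a minimal sketch of what a countPromptTokens helper does, assuming the js-tiktoken package; FastGPT's real helper lives in @fastgpt/global/core/../common/string/tiktoken and older call sites pass it a role argument too, so treat this as an approximation, not the project's implementation.

// Minimal sketch, not FastGPT's code: count tokens with js-tiktoken.
import { getEncoding } from 'js-tiktoken';

const enc = getEncoding('cl100k_base'); // encoding family of GPT-3.5/4-era models

export function countPromptTokens(text: string): number {
  return enc.encode(text).length;
}

// countPromptTokens('hello world') -> 2 under cl100k_base
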
@@ -1,6 +1,6 @@
 import type { ChatMessageItemType } from '@fastgpt/global/core/ai/type.d';
 import { getAIApi } from '../config';
-import { countGptMessagesChars } from '../../chat/utils';
+import { countGptMessagesTokens } from '@fastgpt/global/common/string/tiktoken';
 
 export const Prompt_QuestionGuide = `我不太清楚问你什么问题,请帮我生成 3 个问题,引导我继续提问。问题的长度应小于20个字符,按 JSON 格式返回: ["问题1", "问题2", "问题3"]`;
 
@@ -34,12 +34,12 @@ export async function createQuestionGuide({
   const start = answer.indexOf('[');
   const end = answer.lastIndexOf(']');
 
-  const charsLength = countGptMessagesChars(concatMessages);
+  const tokens = countGptMessagesTokens(concatMessages);
 
   if (start === -1 || end === -1) {
     return {
       result: [],
-      charsLength: 0
+      tokens: 0
     };
   }
@@ -51,12 +51,12 @@ export async function createQuestionGuide({
   try {
     return {
       result: JSON.parse(jsonStr),
-      charsLength
+      tokens
     };
   } catch (error) {
     return {
       result: [],
-      charsLength: 0
+      tokens: 0
     };
   }
 }
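
The guard around JSON.parse matters because models often wrap the requested array in prose or code fences. A standalone sketch of the same strategy (parseGuideQuestions is a hypothetical helper, not part of the diff):

// Take the outermost [...] slice of the model's answer and parse it,
// falling back to an empty result when nothing parsable came back.
function parseGuideQuestions(answer: string): string[] {
  const start = answer.indexOf('[');
  const end = answer.lastIndexOf(']');
  if (start === -1 || end === -1) return [];
  try {
    const parsed = JSON.parse(answer.slice(start, end + 1));
    return Array.isArray(parsed) ? parsed.filter((q) => typeof q === 'string') : [];
  } catch {
    return []; // malformed JSON: treat as "no guide questions"
  }
}
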
@@ -1,7 +1,7 @@
 import { replaceVariable } from '@fastgpt/global/common/string/tools';
 import { getAIApi } from '../config';
 import { ChatItemType } from '@fastgpt/global/core/chat/type';
-import { countGptMessagesChars } from '../../chat/utils';
+import { countGptMessagesTokens } from '@fastgpt/global/common/string/tiktoken';
 
 /*
   query extension
@@ -106,7 +106,7 @@ export const queryExtension = async ({
   rawQuery: string;
   extensionQueries: string[];
   model: string;
-  charsLength: number;
+  tokens: number;
 }> => {
   const systemFewShot = chatBg
     ? `Q: 对话背景。
@@ -148,7 +148,7 @@ A: ${chatBg}
       rawQuery: query,
       extensionQueries: [],
       model,
-      charsLength: 0
+      tokens: 0
     };
   }
 
@@ -161,7 +161,7 @@ A: ${chatBg}
       rawQuery: query,
       extensionQueries: queries,
       model,
-      charsLength: countGptMessagesChars(messages)
+      tokens: countGptMessagesTokens(messages)
     };
   } catch (error) {
     console.log(error);
@@ -169,7 +169,7 @@ A: ${chatBg}
       rawQuery: query,
       extensionQueries: [],
       model,
-      charsLength: 0
+      tokens: 0
     };
   }
 };
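
For reference, the result contract of queryExtension after this commit, assembled from the hunks above (a reading aid, not a quote of the source file):

type QueryExtensionResult = {
  rawQuery: string; // the user's original question
  extensionQueries: string[]; // generated rephrasings used to widen retrieval
  model: string; // model that ran the extension prompt
  tokens: number; // token usage, replacing the old charsLength field
};
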
@@ -1,11 +1,7 @@
 import type { ChatItemType } from '@fastgpt/global/core/chat/type.d';
 import { ChatRoleEnum, IMG_BLOCK_KEY } from '@fastgpt/global/core/chat/constants';
-import { countMessagesTokens, countPromptTokens } from '@fastgpt/global/common/string/tiktoken';
-import { adaptRole_Chat2Message } from '@fastgpt/global/core/chat/adapt';
-import type {
-  ChatCompletionContentPart,
-  ChatMessageItemType
-} from '@fastgpt/global/core/ai/type.d';
+import { countMessagesTokens } from '@fastgpt/global/common/string/tiktoken';
+import type { ChatCompletionContentPart } from '@fastgpt/global/core/ai/type.d';
 import axios from 'axios';
 
 /* slice chat context by tokens */
@@ -32,26 +28,34 @@ export function ChatContextFilter({
   const chatPrompts: ChatItemType[] = messages.slice(chatStartIndex);
 
   // reduce token of systemPrompt
-  maxTokens -= countMessagesTokens({
-    messages: systemPrompts
-  });
+  maxTokens -= countMessagesTokens(systemPrompts);
 
-  // truncate content by token count
-  const chats: ChatItemType[] = [];
-
-  // walk the conversation from the end backwards
-  for (let i = chatPrompts.length - 1; i >= 0; i--) {
-    const item = chatPrompts[i];
-    chats.unshift(item);
+  // Save the last chat prompt(question)
+  const question = chatPrompts.pop();
+  if (!question) {
+    return systemPrompts;
+  }
+  const chats: ChatItemType[] = [question];
 
-    const tokens = countPromptTokens(item.value, adaptRole_Chat2Message(item.obj));
+  // walk the conversation from the end, taking two messages (assistant + user) per step
+  while (1) {
+    const assistant = chatPrompts.pop();
+    const user = chatPrompts.pop();
+    if (!assistant || !user) {
+      break;
+    }
+
+    const tokens = countMessagesTokens([assistant, user]);
     maxTokens -= tokens;
-    /* total tokens over budget; the system prompt must be kept */
-    if (maxTokens <= 0) {
-      if (chats.length > 1) {
-        chats.shift();
-      }
+    /* total tokens over budget: truncate here */
+    if (maxTokens < 0) {
       break;
     }
+
+    chats.unshift(assistant);
+    chats.unshift(user);
+
+    if (chatPrompts.length === 0) {
+      break;
+    }
   }
@@ -59,16 +63,6 @@ export function ChatContextFilter({
   return [...systemPrompts, ...chats];
 }
 
-export const replaceValidChars = (str: string) => {
-  const reg = /[\s\r\n]+/g;
-  return str.replace(reg, '');
-};
-export const countMessagesChars = (messages: ChatItemType[]) => {
-  return messages.reduce((sum, item) => sum + replaceValidChars(item.value).length, 0);
-};
-export const countGptMessagesChars = (messages: ChatMessageItemType[]) =>
-  messages.reduce((sum, item) => sum + replaceValidChars(item.content).length, 0);
-
 /**
   string to vision model. Follow the markdown code block rule for interception:
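
The rewritten ChatContextFilter walks the history backwards in assistant/user pairs instead of single messages, and always preserves the final question. A runnable sketch of that loop with a stubbed token counter (the Msg type and countTokens are stand-ins, not FastGPT code):

type Msg = { obj: 'Human' | 'AI' | 'System'; value: string };

// rough stand-in for FastGPT's tiktoken-based counter (~4 chars per token)
const countTokens = (msgs: Msg[]) =>
  msgs.reduce((sum, m) => sum + Math.ceil(m.value.length / 4), 0);

function filterByTokens(system: Msg[], history: Msg[], maxTokens: number): Msg[] {
  const rest = [...history]; // don't mutate the caller's array
  maxTokens -= countTokens(system);

  const question = rest.pop(); // the latest question is always kept
  if (!question) return system;
  const kept: Msg[] = [question];

  while (true) {
    const assistant = rest.pop();
    const user = rest.pop();
    if (!assistant || !user) break;

    maxTokens -= countTokens([assistant, user]);
    if (maxTokens < 0) break; // budget exhausted: drop this pair and older turns

    kept.unshift(assistant);
    kept.unshift(user);
  }
  return [...system, ...kept];
}

Popping in pairs keeps every retained assistant reply attached to the user turn that produced it, and the early return of systemPrompts covers an empty history.
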
@@ -4,6 +4,7 @@ import { DatasetSchemaType } from '@fastgpt/global/core/dataset/type.d';
 import {
   DatasetStatusEnum,
   DatasetStatusMap,
+  DatasetTypeEnum,
   DatasetTypeMap
 } from '@fastgpt/global/core/dataset/constants';
 import {
@@ -39,7 +40,7 @@ const DatasetSchema = new Schema({
     type: String,
     enum: Object.keys(DatasetTypeMap),
     required: true,
-    default: 'dataset'
+    default: DatasetTypeEnum.dataset
   },
   status: {
     type: String,
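
Replacing the bare string 'dataset' with DatasetTypeEnum.dataset means a rename of the enum value propagates to the schema default instead of silently diverging from the validator. A minimal mongoose sketch of the pattern (the enum values here are illustrative, not FastGPT's full schema):

import { Schema } from 'mongoose';

enum DatasetTypeEnum {
  dataset = 'dataset',
  folder = 'folder'
}

const ExampleSchema = new Schema({
  type: {
    type: String,
    enum: Object.values(DatasetTypeEnum), // validator rejects unknown values
    required: true,
    default: DatasetTypeEnum.dataset // refactors with the enum, unlike 'dataset'
  }
});
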
@@ -46,12 +46,16 @@ export async function pushDataListToTrainingQueue({
   } = await getCollectionWithDataset(collectionId);
 
   const checkModelValid = async () => {
-    if (trainingMode === TrainingModeEnum.chunk) {
-      const vectorModelData = vectorModelList?.find((item) => item.model === vectorModel);
-      if (!vectorModelData) {
-        return Promise.reject(`File model ${vectorModel} is inValid`);
-      }
+    const agentModelData = datasetModelList?.find((item) => item.model === agentModel);
+    if (!agentModelData) {
+      return Promise.reject(`Vector model ${agentModel} is inValid`);
+    }
+    const vectorModelData = vectorModelList?.find((item) => item.model === vectorModel);
+    if (!vectorModelData) {
+      return Promise.reject(`File model ${vectorModel} is inValid`);
+    }
+
+    if (trainingMode === TrainingModeEnum.chunk) {
       return {
         maxToken: vectorModelData.maxToken * 1.3,
         model: vectorModelData.model,
@@ -59,17 +63,14 @@ export async function pushDataListToTrainingQueue({
       };
     }
 
-    if (trainingMode === TrainingModeEnum.qa) {
-      const qaModelData = datasetModelList?.find((item) => item.model === agentModel);
-      if (!qaModelData) {
-        return Promise.reject(`Vector model ${agentModel} is inValid`);
-      }
+    if (trainingMode === TrainingModeEnum.qa || trainingMode === TrainingModeEnum.auto) {
       return {
-        maxToken: qaModelData.maxContext * 0.8,
-        model: qaModelData.model,
+        maxToken: agentModelData.maxContext * 0.8,
+        model: agentModelData.model,
         weight: 0
       };
     }
 
     return Promise.reject(`Training mode "${trainingMode}" is inValid`);
   };
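
The refactor hoists both model lookups out of the branches, so every training mode fails fast on a missing model, and the new auto mode shares the qa budget: 0.8 of the agent model's context window, versus 1.3 of the vector model's maxToken for chunk mode. A sketch of the dispatch (the multipliers come from the diff; the parameter shapes and the weight fallback are stand-ins):

type ModelBudget = { maxToken: number; model: string; weight: number };

function budgetForMode(
  mode: 'chunk' | 'qa' | 'auto',
  vectorModel: { model: string; maxToken: number; weight?: number },
  agentModel: { model: string; maxContext: number }
): ModelBudget {
  if (mode === 'chunk') {
    // embedding chunks only need head-room over the vector model's window
    return {
      maxToken: vectorModel.maxToken * 1.3,
      model: vectorModel.model,
      weight: vectorModel.weight ?? 0
    };
  }
  // qa/auto generate with the agent model; leave 20% of its context for output
  return { maxToken: agentModel.maxContext * 0.8, model: agentModel.model, weight: 0 };
}
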
packages/service/core/dataset/training/utils.ts (new file, 41 lines)
@@ -0,0 +1,41 @@
+import { DatasetTrainingSchemaType } from '@fastgpt/global/core/dataset/type';
+import { addLog } from '../../../common/system/log';
+import { getErrText } from '@fastgpt/global/common/error/utils';
+import { MongoDatasetTraining } from './schema';
+
+export const checkInvalidChunkAndLock = async ({
+  err,
+  errText,
+  data
+}: {
+  err: any;
+  errText: string;
+  data: DatasetTrainingSchemaType;
+}) => {
+  if (err?.response) {
+    addLog.error(`openai error: ${errText}`, {
+      status: err.response?.status,
+      statusText: err.response?.statusText,
+      data: err.response?.data
+    });
+  } else {
+    addLog.error(getErrText(err, errText), err);
+  }
+
+  if (
+    err?.message === 'invalid message format' ||
+    err?.type === 'invalid_request_error' ||
+    err?.code === 500
+  ) {
+    addLog.info('Lock training data');
+    console.log(err);
+
+    try {
+      await MongoDatasetTraining.findByIdAndUpdate(data._id, {
+        lockTime: new Date('2998/5/5')
+      });
+    } catch (error) {}
+    return true;
+  }
+  return false;
+};
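
checkInvalidChunkAndLock separates permanent failures (malformed messages, invalid requests, hard 500s) from transient ones: permanent failures get a lockTime far in the future ('2998/5/5'), which effectively parks the chunk outside the training queue without deleting it. A hypothetical call site; generateQA and trainingData are stand-ins, only checkInvalidChunkAndLock comes from the new file above:

async function handleChunk(trainingData: DatasetTrainingSchemaType) {
  try {
    await generateQA(trainingData); // stand-in for the real training step
  } catch (err) {
    const locked = await checkInvalidChunkAndLock({
      err,
      errText: 'QA generation failed',
      data: trainingData
    });
    if (!locked) {
      // transient failure: leave the chunk for the scheduler to retry
    }
  }
}
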