mirror of
https://github.com/labring/FastGPT.git
synced 2025-10-17 08:37:59 +00:00
fix: 训练后模型没选中
This commit is contained in:
3
src/api/response/chat.d.ts
vendored
3
src/api/response/chat.d.ts
vendored
@@ -8,7 +8,8 @@ export type InitChatResponse = {
|
|||||||
avatar: string;
|
avatar: string;
|
||||||
intro: string;
|
intro: string;
|
||||||
secret: ModelSchema.secret;
|
secret: ModelSchema.secret;
|
||||||
chatModel: ModelSchema.service.ChatModel; // 模型名
|
chatModel: ModelSchema.service.chatModel; // 对话模型名
|
||||||
|
modelName: ModelSchema.service.modelName; // 底层模型
|
||||||
history: ChatItemType[];
|
history: ChatItemType[];
|
||||||
isExpiredTime: boolean;
|
isExpiredTime: boolean;
|
||||||
};
|
};
|
||||||
|
@@ -51,11 +51,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
prompts.length > maxContext ? prompts.slice(prompts.length - maxContext) : prompts;
|
prompts.length > maxContext ? prompts.slice(prompts.length - maxContext) : prompts;
|
||||||
|
|
||||||
// 格式化文本内容
|
// 格式化文本内容
|
||||||
const map = {
|
|
||||||
Human: 'Human',
|
|
||||||
AI: 'AI',
|
|
||||||
SYSTEM: 'SYSTEM'
|
|
||||||
};
|
|
||||||
const formatPrompts: string[] = filterPrompts.map((item: ChatItemType) => item.value);
|
const formatPrompts: string[] = filterPrompts.map((item: ChatItemType) => item.value);
|
||||||
// 如果有系统提示词,自动插入
|
// 如果有系统提示词,自动插入
|
||||||
if (model.systemPrompt) {
|
if (model.systemPrompt) {
|
||||||
@@ -85,7 +80,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
max_tokens: modelConstantsData.maxToken,
|
max_tokens: modelConstantsData.maxToken,
|
||||||
presence_penalty: 0, // 越大,越容易出现新内容
|
presence_penalty: 0, // 越大,越容易出现新内容
|
||||||
frequency_penalty: 0, // 越大,重复内容越少
|
frequency_penalty: 0, // 越大,重复内容越少
|
||||||
stop: ['。!?.!.', `</s>`]
|
stop: [`</s>`, '。!?.!.']
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
timeout: 40000,
|
timeout: 40000,
|
||||||
@@ -113,10 +108,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
try {
|
try {
|
||||||
const json = JSON.parse(data);
|
const json = JSON.parse(data);
|
||||||
const content: string = json?.choices?.[0].text || '';
|
const content: string = json?.choices?.[0].text || '';
|
||||||
|
console.log('content:', content);
|
||||||
if (!content || (responseContent === '' && content === '\n')) return;
|
if (!content || (responseContent === '' && content === '\n')) return;
|
||||||
|
|
||||||
responseContent += content;
|
responseContent += content;
|
||||||
// console.log('content:', content);
|
|
||||||
!stream.destroyed && stream.push(content.replace(/\n/g, '<br/>'));
|
!stream.destroyed && stream.push(content.replace(/\n/g, '<br/>'));
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
error;
|
error;
|
||||||
@@ -143,7 +138,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
// 只有使用平台的 key 才计费
|
// 只有使用平台的 key 才计费
|
||||||
!userApiKey &&
|
!userApiKey &&
|
||||||
pushChatBill({
|
pushChatBill({
|
||||||
modelName: model.service.modelName,
|
modelName: model.service.chatModel,
|
||||||
userId,
|
userId,
|
||||||
chatId,
|
chatId,
|
||||||
text: promptText + responseContent
|
text: promptText + responseContent
|
||||||
|
@@ -52,6 +52,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
avatar: model.avatar,
|
avatar: model.avatar,
|
||||||
intro: model.intro,
|
intro: model.intro,
|
||||||
secret: model.security,
|
secret: model.security,
|
||||||
|
modelName: model.service.modelName,
|
||||||
chatModel: model.service.chatModel,
|
chatModel: model.service.chatModel,
|
||||||
history: chat.content
|
history: chat.content
|
||||||
}
|
}
|
||||||
|
@@ -1,15 +1,7 @@
|
|||||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||||
import { jsonRes } from '@/service/response';
|
import { jsonRes } from '@/service/response';
|
||||||
import { connectToDatabase, Model, Training } from '@/service/mongo';
|
import { connectToDatabase, Training } from '@/service/mongo';
|
||||||
import { getOpenAIApi } from '@/service/utils/chat';
|
import { authToken } from '@/service/utils/tools';
|
||||||
import formidable from 'formidable';
|
|
||||||
import { authToken, getUserOpenaiKey } from '@/service/utils/tools';
|
|
||||||
import { join } from 'path';
|
|
||||||
import fs from 'fs';
|
|
||||||
import type { ModelSchema } from '@/types/mongoSchema';
|
|
||||||
import type { OpenAIApi } from 'openai';
|
|
||||||
import { ModelStatusEnum, TrainingStatusEnum } from '@/constants/model';
|
|
||||||
import { httpsAgent } from '@/service/utils/tools';
|
|
||||||
|
|
||||||
// 关闭next默认的bodyParser处理方式
|
// 关闭next默认的bodyParser处理方式
|
||||||
export const config = {
|
export const config = {
|
||||||
@@ -18,7 +10,7 @@ export const config = {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* 上传文件,开始微调 */
|
/* 获取模型训练记录 */
|
||||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||||
try {
|
try {
|
||||||
const { authorization } = req.headers;
|
const { authorization } = req.headers;
|
||||||
@@ -30,7 +22,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
if (!modelId) {
|
if (!modelId) {
|
||||||
throw new Error('参数错误');
|
throw new Error('参数错误');
|
||||||
}
|
}
|
||||||
const userId = await authToken(authorization);
|
await authToken(authorization);
|
||||||
|
|
||||||
await connectToDatabase();
|
await connectToDatabase();
|
||||||
|
|
||||||
|
@@ -52,7 +52,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
// 删除训练文件
|
// 删除训练文件
|
||||||
openai.deleteFile(data.training_files[0].id, { httpsAgent });
|
openai.deleteFile(data.training_files[0].id, { httpsAgent });
|
||||||
|
|
||||||
// 更新模型
|
// 更新模型状态和模型内容
|
||||||
await Model.findByIdAndUpdate(modelId, {
|
await Model.findByIdAndUpdate(modelId, {
|
||||||
status: ModelStatusEnum.running,
|
status: ModelStatusEnum.running,
|
||||||
updateTime: new Date(),
|
updateTime: new Date(),
|
||||||
@@ -72,6 +72,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* 取消微调 */
|
||||||
if (data.status === OpenAiTuneStatusEnum.cancelled) {
|
if (data.status === OpenAiTuneStatusEnum.cancelled) {
|
||||||
// 删除训练文件
|
// 删除训练文件
|
||||||
openai.deleteFile(data.training_files[0].id, { httpsAgent });
|
openai.deleteFile(data.training_files[0].id, { httpsAgent });
|
||||||
@@ -87,11 +88,13 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
});
|
});
|
||||||
|
|
||||||
return jsonRes(res, {
|
return jsonRes(res, {
|
||||||
data: '模型微调取消'
|
data: '模型微调已取消'
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
throw new Error('模型还在训练中');
|
jsonRes(res, {
|
||||||
|
data: '模型还在训练中'
|
||||||
|
});
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
jsonRes(res, {
|
jsonRes(res, {
|
||||||
code: 500,
|
code: 500,
|
||||||
|
@@ -30,6 +30,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
throw new Error('无权操作');
|
throw new Error('无权操作');
|
||||||
}
|
}
|
||||||
const { modelId } = req.query;
|
const { modelId } = req.query;
|
||||||
|
|
||||||
if (!modelId) {
|
if (!modelId) {
|
||||||
throw new Error('参数错误');
|
throw new Error('参数错误');
|
||||||
}
|
}
|
||||||
@@ -67,7 +68,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
});
|
});
|
||||||
const file = files.file;
|
const file = files.file;
|
||||||
|
|
||||||
// 上传文件
|
// 上传文件到 openai
|
||||||
// @ts-ignore
|
// @ts-ignore
|
||||||
const uploadRes = await openai.createFile(
|
const uploadRes = await openai.createFile(
|
||||||
// @ts-ignore
|
// @ts-ignore
|
||||||
|
@@ -62,6 +62,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
|
|||||||
intro: '',
|
intro: '',
|
||||||
secret: {},
|
secret: {},
|
||||||
chatModel: '',
|
chatModel: '',
|
||||||
|
modelName: '',
|
||||||
history: [],
|
history: [],
|
||||||
isExpiredTime: false
|
isExpiredTime: false
|
||||||
}); // 聊天框整体数据
|
}); // 聊天框整体数据
|
||||||
@@ -156,7 +157,8 @@ const Chat = ({ chatId }: { chatId: string }) => {
|
|||||||
[ChatModelNameEnum.GPT35]: '/api/chat/chatGpt',
|
[ChatModelNameEnum.GPT35]: '/api/chat/chatGpt',
|
||||||
[ChatModelNameEnum.GPT3]: '/api/chat/gpt3'
|
[ChatModelNameEnum.GPT3]: '/api/chat/gpt3'
|
||||||
};
|
};
|
||||||
if (!urlMap[chatData.chatModel]) return Promise.reject('找不到模型');
|
|
||||||
|
if (!urlMap[chatData.modelName]) return Promise.reject('找不到模型');
|
||||||
|
|
||||||
const prompt = {
|
const prompt = {
|
||||||
obj: prompts.obj,
|
obj: prompts.obj,
|
||||||
@@ -164,7 +166,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
|
|||||||
};
|
};
|
||||||
// 流请求,获取数据
|
// 流请求,获取数据
|
||||||
const res = await streamFetch({
|
const res = await streamFetch({
|
||||||
url: urlMap[chatData.chatModel],
|
url: urlMap[chatData.modelName],
|
||||||
data: {
|
data: {
|
||||||
prompt,
|
prompt,
|
||||||
chatId
|
chatId
|
||||||
@@ -217,7 +219,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
|
|||||||
})
|
})
|
||||||
}));
|
}));
|
||||||
},
|
},
|
||||||
[chatData.chatModel, chatId, toast]
|
[chatData.modelName, chatId, toast]
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@@ -108,9 +108,9 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
|
|||||||
|
|
||||||
// 重新获取模型
|
// 重新获取模型
|
||||||
loadModel();
|
loadModel();
|
||||||
} catch (err) {
|
} catch (err: any) {
|
||||||
toast({
|
toast({
|
||||||
title: typeof err === 'string' ? err : '文件格式错误',
|
title: err?.message || '上传文件失败',
|
||||||
status: 'error'
|
status: 'error'
|
||||||
});
|
});
|
||||||
console.log('error->', err);
|
console.log('error->', err);
|
||||||
@@ -126,7 +126,12 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
|
|||||||
setLoading(true);
|
setLoading(true);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await putModelTrainingStatus(model._id);
|
const res = await putModelTrainingStatus(model._id);
|
||||||
|
typeof res === 'string' &&
|
||||||
|
toast({
|
||||||
|
title: res,
|
||||||
|
status: 'info'
|
||||||
|
});
|
||||||
loadModel();
|
loadModel();
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
console.log('error->', error);
|
console.log('error->', error);
|
||||||
@@ -284,6 +289,9 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
{/* 提示 */}
|
{/* 提示 */}
|
||||||
<Box mt={3} py={3} color={'blackAlpha.600'}>
|
<Box mt={3} py={3} color={'blackAlpha.600'}>
|
||||||
|
<Box as={'li'} lineHeight={1.9}>
|
||||||
|
暂时需要使用自己的openai key
|
||||||
|
</Box>
|
||||||
<Box as={'li'} lineHeight={1.9}>
|
<Box as={'li'} lineHeight={1.9}>
|
||||||
可以使用
|
可以使用
|
||||||
<Box
|
<Box
|
||||||
|
@@ -50,7 +50,7 @@ const ModelSchema = new Schema({
|
|||||||
enum: ['openai']
|
enum: ['openai']
|
||||||
},
|
},
|
||||||
trainId: {
|
trainId: {
|
||||||
// 训练时需要的 ID
|
// 训练时需要的 ID, 不能训练的模型没有这个值。
|
||||||
type: String,
|
type: String,
|
||||||
required: false
|
required: false
|
||||||
},
|
},
|
||||||
|
Reference in New Issue
Block a user