mirror of
https://github.com/labring/FastGPT.git
synced 2025-07-23 21:13:50 +00:00
feat: 模型数据管理
feat: 模型数据导入 feat: redis 向量入库 feat: 向量索引 feat: 文件导入模型 perf: 交互 perf: prompt
This commit is contained in:
@@ -4,3 +4,5 @@ MONGODB_URI=
|
||||
MY_MAIL=
|
||||
MAILE_CODE=
|
||||
TOKEN_KEY=
|
||||
OPENAIKEY=
|
||||
REDIS_URL=
|
@@ -41,6 +41,7 @@
|
||||
"react-hook-form": "^7.43.1",
|
||||
"react-markdown": "^8.0.5",
|
||||
"react-syntax-highlighter": "^15.5.0",
|
||||
"redis": "^4.6.5",
|
||||
"rehype-katex": "^6.0.2",
|
||||
"remark-gfm": "^3.0.1",
|
||||
"remark-math": "^5.1.1",
|
||||
|
95
pnpm-lock.yaml
generated
95
pnpm-lock.yaml
generated
@@ -47,6 +47,7 @@ specifiers:
|
||||
react-hook-form: ^7.43.1
|
||||
react-markdown: ^8.0.5
|
||||
react-syntax-highlighter: ^15.5.0
|
||||
redis: ^4.6.5
|
||||
rehype-katex: ^6.0.2
|
||||
remark-gfm: ^3.0.1
|
||||
remark-math: ^5.1.1
|
||||
@@ -87,6 +88,7 @@ dependencies:
|
||||
react-hook-form: registry.npmmirror.com/react-hook-form/7.43.1_react@18.2.0
|
||||
react-markdown: registry.npmmirror.com/react-markdown/8.0.5_pmekkgnqduwlme35zpnqhenc34
|
||||
react-syntax-highlighter: registry.npmmirror.com/react-syntax-highlighter/15.5.0_react@18.2.0
|
||||
redis: registry.npmmirror.com/redis/4.6.5
|
||||
rehype-katex: registry.npmmirror.com/rehype-katex/6.0.2
|
||||
remark-gfm: registry.npmmirror.com/remark-gfm/3.0.1
|
||||
remark-math: registry.npmmirror.com/remark-math/5.1.1
|
||||
@@ -4504,6 +4506,72 @@ packages:
|
||||
version: 2.11.6
|
||||
dev: false
|
||||
|
||||
registry.npmmirror.com/@redis/bloom/1.2.0_@redis+client@1.5.6:
|
||||
resolution: {integrity: sha512-HG2DFjYKbpNmVXsa0keLHp/3leGJz1mjh09f2RLGGLQZzSHpkmZWuwJbAvo3QcRY8p80m5+ZdXZdYOSBLlp7Cg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/bloom/-/bloom-1.2.0.tgz}
|
||||
id: registry.npmmirror.com/@redis/bloom/1.2.0
|
||||
name: '@redis/bloom'
|
||||
version: 1.2.0
|
||||
peerDependencies:
|
||||
'@redis/client': ^1.0.0
|
||||
dependencies:
|
||||
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
|
||||
dev: false
|
||||
|
||||
registry.npmmirror.com/@redis/client/1.5.6:
|
||||
resolution: {integrity: sha512-dFD1S6je+A47Lj22jN/upVU2fj4huR7S9APd7/ziUXsIXDL+11GPYti4Suv5y8FuXaN+0ZG4JF+y1houEJ7ToA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/client/-/client-1.5.6.tgz}
|
||||
name: '@redis/client'
|
||||
version: 1.5.6
|
||||
engines: {node: '>=14'}
|
||||
dependencies:
|
||||
cluster-key-slot: registry.npmmirror.com/cluster-key-slot/1.1.2
|
||||
generic-pool: registry.npmmirror.com/generic-pool/3.9.0
|
||||
yallist: registry.npmmirror.com/yallist/4.0.0
|
||||
dev: false
|
||||
|
||||
registry.npmmirror.com/@redis/graph/1.1.0_@redis+client@1.5.6:
|
||||
resolution: {integrity: sha512-16yZWngxyXPd+MJxeSr0dqh2AIOi8j9yXKcKCwVaKDbH3HTuETpDVPcLujhFYVPtYrngSco31BUcSa9TH31Gqg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/graph/-/graph-1.1.0.tgz}
|
||||
id: registry.npmmirror.com/@redis/graph/1.1.0
|
||||
name: '@redis/graph'
|
||||
version: 1.1.0
|
||||
peerDependencies:
|
||||
'@redis/client': ^1.0.0
|
||||
dependencies:
|
||||
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
|
||||
dev: false
|
||||
|
||||
registry.npmmirror.com/@redis/json/1.0.4_@redis+client@1.5.6:
|
||||
resolution: {integrity: sha512-LUZE2Gdrhg0Rx7AN+cZkb1e6HjoSKaeeW8rYnt89Tly13GBI5eP4CwDVr+MY8BAYfCg4/N15OUrtLoona9uSgw==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/json/-/json-1.0.4.tgz}
|
||||
id: registry.npmmirror.com/@redis/json/1.0.4
|
||||
name: '@redis/json'
|
||||
version: 1.0.4
|
||||
peerDependencies:
|
||||
'@redis/client': ^1.0.0
|
||||
dependencies:
|
||||
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
|
||||
dev: false
|
||||
|
||||
registry.npmmirror.com/@redis/search/1.1.2_@redis+client@1.5.6:
|
||||
resolution: {integrity: sha512-/cMfstG/fOh/SsE+4/BQGeuH/JJloeWuH+qJzM8dbxuWvdWibWAOAHHCZTMPhV3xIlH4/cUEIA8OV5QnYpaVoA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/search/-/search-1.1.2.tgz}
|
||||
id: registry.npmmirror.com/@redis/search/1.1.2
|
||||
name: '@redis/search'
|
||||
version: 1.1.2
|
||||
peerDependencies:
|
||||
'@redis/client': ^1.0.0
|
||||
dependencies:
|
||||
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
|
||||
dev: false
|
||||
|
||||
registry.npmmirror.com/@redis/time-series/1.0.4_@redis+client@1.5.6:
|
||||
resolution: {integrity: sha512-ThUIgo2U/g7cCuZavucQTQzA9g9JbDDY2f64u3AbAoz/8vE2lt2U37LamDUVChhaDA3IRT9R6VvJwqnUfTJzng==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/time-series/-/time-series-1.0.4.tgz}
|
||||
id: registry.npmmirror.com/@redis/time-series/1.0.4
|
||||
name: '@redis/time-series'
|
||||
version: 1.0.4
|
||||
peerDependencies:
|
||||
'@redis/client': ^1.0.0
|
||||
dependencies:
|
||||
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
|
||||
dev: false
|
||||
|
||||
registry.npmmirror.com/@rushstack/eslint-patch/1.2.0:
|
||||
resolution: {integrity: sha512-sXo/qW2/pAcmT43VoRKOJbDOfV3cYpq3szSVfIThQXNt+E4DfKj361vaAt3c88U5tPUxzEswam7GW48PJqtKAg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@rushstack/eslint-patch/-/eslint-patch-1.2.0.tgz}
|
||||
name: '@rushstack/eslint-patch'
|
||||
@@ -5562,6 +5630,13 @@ packages:
|
||||
version: 0.0.1
|
||||
dev: false
|
||||
|
||||
registry.npmmirror.com/cluster-key-slot/1.1.2:
|
||||
resolution: {integrity: sha512-RMr0FhtfXemyinomL4hrWcYJxmX6deFdCxpJzhDttxgO1+bcCnkk+9drydLVDmAMG7NE6aN/fl4F7ucU/90gAA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/cluster-key-slot/-/cluster-key-slot-1.1.2.tgz}
|
||||
name: cluster-key-slot
|
||||
version: 1.1.2
|
||||
engines: {node: '>=0.10.0'}
|
||||
dev: false
|
||||
|
||||
registry.npmmirror.com/color-convert/1.9.3:
|
||||
resolution: {integrity: sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/color-convert/-/color-convert-1.9.3.tgz}
|
||||
name: color-convert
|
||||
@@ -6799,6 +6874,13 @@ packages:
|
||||
version: 1.2.3
|
||||
dev: true
|
||||
|
||||
registry.npmmirror.com/generic-pool/3.9.0:
|
||||
resolution: {integrity: sha512-hymDOu5B53XvN4QT9dBmZxPX4CWhBPPLguTZ9MMFeFa/Kg0xWVfylOVNlJji/E7yTZWFd/q9GO5TxDLq156D7g==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/generic-pool/-/generic-pool-3.9.0.tgz}
|
||||
name: generic-pool
|
||||
version: 3.9.0
|
||||
engines: {node: '>= 4'}
|
||||
dev: false
|
||||
|
||||
registry.npmmirror.com/gensync/1.0.0-beta.2:
|
||||
resolution: {integrity: sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/gensync/-/gensync-1.0.0-beta.2.tgz}
|
||||
name: gensync
|
||||
@@ -9367,6 +9449,19 @@ packages:
|
||||
picomatch: registry.npmmirror.com/picomatch/2.3.1
|
||||
dev: false
|
||||
|
||||
registry.npmmirror.com/redis/4.6.5:
|
||||
resolution: {integrity: sha512-O0OWA36gDQbswOdUuAhRL6mTZpHFN525HlgZgDaVNgCJIAZR3ya06NTESb0R+TUZ+BFaDpz6NnnVvoMx9meUFg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/redis/-/redis-4.6.5.tgz}
|
||||
name: redis
|
||||
version: 4.6.5
|
||||
dependencies:
|
||||
'@redis/bloom': registry.npmmirror.com/@redis/bloom/1.2.0_@redis+client@1.5.6
|
||||
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
|
||||
'@redis/graph': registry.npmmirror.com/@redis/graph/1.1.0_@redis+client@1.5.6
|
||||
'@redis/json': registry.npmmirror.com/@redis/json/1.0.4_@redis+client@1.5.6
|
||||
'@redis/search': registry.npmmirror.com/@redis/search/1.1.2_@redis+client@1.5.6
|
||||
'@redis/time-series': registry.npmmirror.com/@redis/time-series/1.0.4_@redis+client@1.5.6
|
||||
dev: false
|
||||
|
||||
registry.npmmirror.com/refractor/3.6.0:
|
||||
resolution: {integrity: sha512-MY9W41IOWxxk31o+YvFCNyNzdkc9M20NoZK5vq6jkv4I/uh2zkWcfudj0Q1fovjUQJrNewS9NMzeTtqPf+n5EA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/refractor/-/refractor-3.6.0.tgz}
|
||||
name: refractor
|
||||
|
@@ -1,7 +1,10 @@
|
||||
import { GET, POST, DELETE, PUT } from './request';
|
||||
import type { ModelSchema } from '@/types/mongoSchema';
|
||||
import type { ModelSchema, ModelDataSchema } from '@/types/mongoSchema';
|
||||
import { ModelUpdateParams } from '@/types/model';
|
||||
import { TrainingItemType } from '../types/training';
|
||||
import { PagingData } from '@/types';
|
||||
import { RequestPaging } from '../types/index';
|
||||
import { Obj2Query } from '@/utils/tools';
|
||||
|
||||
export const getMyModels = () => GET<ModelSchema[]>('/model/list');
|
||||
|
||||
@@ -16,13 +19,35 @@ export const putModelById = (id: string, data: ModelUpdateParams) =>
|
||||
PUT(`/model/update?modelId=${id}`, data);
|
||||
|
||||
export const postTrainModel = (id: string, form: FormData) =>
|
||||
POST(`/model/train?modelId=${id}`, form, {
|
||||
POST(`/model/train/train?modelId=${id}`, form, {
|
||||
headers: {
|
||||
'content-type': 'multipart/form-data'
|
||||
}
|
||||
});
|
||||
|
||||
export const putModelTrainingStatus = (id: string) => PUT(`/model/putTrainStatus?modelId=${id}`);
|
||||
export const putModelTrainingStatus = (id: string) =>
|
||||
PUT(`/model/train/putTrainStatus?modelId=${id}`);
|
||||
|
||||
export const getModelTrainings = (id: string) =>
|
||||
GET<TrainingItemType[]>(`/model/getTrainings?modelId=${id}`);
|
||||
GET<TrainingItemType[]>(`/model/train/getTrainings?modelId=${id}`);
|
||||
|
||||
/* 模型 data */
|
||||
|
||||
type GetModelDataListProps = RequestPaging & {
|
||||
modelId: string;
|
||||
};
|
||||
export const getModelDataList = (props: GetModelDataListProps) =>
|
||||
GET(`/model/data/getModelData?${Obj2Query(props)}`);
|
||||
|
||||
export const postModelDataInput = (data: {
|
||||
modelId: string;
|
||||
data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
|
||||
}) => POST(`/model/data/pushModelDataInput`, data);
|
||||
|
||||
export const postModelDataFileText = (modelId: string, text: string) =>
|
||||
POST(`/model/data/splitData`, { modelId, text });
|
||||
|
||||
export const putModelDataById = (data: { dataId: string; text: string }) =>
|
||||
PUT('/model/data/putModelData', data);
|
||||
export const delOneModelData = (dataId: string) =>
|
||||
DELETE(`/model/data/delModelDataById?dataId=${dataId}`);
|
||||
|
@@ -26,12 +26,12 @@ const navbarList = [
|
||||
link: '/model/list',
|
||||
activeLink: ['/model/list', '/model/detail']
|
||||
},
|
||||
{
|
||||
label: '数据',
|
||||
icon: 'icon-datafull',
|
||||
link: '/data/list',
|
||||
activeLink: ['/data/list', '/data/detail']
|
||||
},
|
||||
// {
|
||||
// label: '数据',
|
||||
// icon: 'icon-datafull',
|
||||
// link: '/data/list',
|
||||
// activeLink: ['/data/list', '/data/detail']
|
||||
// },
|
||||
{
|
||||
label: '账号',
|
||||
icon: 'icon-yonghu-yuan',
|
||||
|
@@ -1,11 +1,17 @@
|
||||
import type { ServiceName } from '@/types/mongoSchema';
|
||||
import { ModelSchema } from '../types/mongoSchema';
|
||||
import type { ServiceName, ModelDataType, ModelSchema } from '@/types/mongoSchema';
|
||||
|
||||
export enum ChatModelNameEnum {
|
||||
GPT35 = 'gpt-3.5-turbo',
|
||||
VECTOR_GPT = 'VECTOR_GPT',
|
||||
GPT3 = 'text-davinci-003'
|
||||
}
|
||||
|
||||
export const ChatModelNameMap = {
|
||||
[ChatModelNameEnum.GPT35]: 'gpt-3.5-turbo',
|
||||
[ChatModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo',
|
||||
[ChatModelNameEnum.GPT3]: 'text-davinci-003'
|
||||
};
|
||||
|
||||
export type ModelConstantsData = {
|
||||
serviceCompany: `${ServiceName}`;
|
||||
name: string;
|
||||
@@ -29,6 +35,17 @@ export const modelList: ModelConstantsData[] = [
|
||||
trainedMaxToken: 2000,
|
||||
maxTemperature: 2,
|
||||
price: 3
|
||||
},
|
||||
{
|
||||
serviceCompany: 'openai',
|
||||
name: '知识库',
|
||||
model: ChatModelNameEnum.VECTOR_GPT,
|
||||
trainName: 'vector',
|
||||
maxToken: 4000,
|
||||
contextMaxToken: 7500,
|
||||
trainedMaxToken: 2000,
|
||||
maxTemperature: 1,
|
||||
price: 3
|
||||
}
|
||||
// {
|
||||
// serviceCompany: 'openai',
|
||||
@@ -76,6 +93,11 @@ export const formatModelStatus = {
|
||||
}
|
||||
};
|
||||
|
||||
export const ModelDataStatusMap = {
|
||||
0: '训练完成',
|
||||
1: '训练中'
|
||||
};
|
||||
|
||||
export const defaultModel: ModelSchema = {
|
||||
_id: '',
|
||||
userId: '',
|
||||
|
1
src/constants/redis.ts
Normal file
1
src/constants/redis.ts
Normal file
@@ -0,0 +1 @@
|
||||
// Base name of the Redis search index that holds model-data vectors.
// The vector chat endpoint queries it as `idx:${VecModelDataIndex}`.
export const VecModelDataIndex = 'model:data';
|
@@ -8,7 +8,7 @@ export const usePaging = <T = any>({
|
||||
pageSize = 10,
|
||||
params = {}
|
||||
}: {
|
||||
api: (data: any) => Promise<PagingData<T>>;
|
||||
api: (data: any) => any;
|
||||
pageSize?: number;
|
||||
params?: Record<string, any>;
|
||||
}) => {
|
||||
@@ -30,7 +30,7 @@ export const usePaging = <T = any>({
|
||||
setRequesting(true);
|
||||
|
||||
try {
|
||||
const res = await api({
|
||||
const res: PagingData<T> = await api({
|
||||
pageNum: num,
|
||||
pageSize,
|
||||
...params
|
||||
@@ -75,6 +75,7 @@ export const usePaging = <T = any>({
|
||||
requesting,
|
||||
isLoadAll,
|
||||
nextPage,
|
||||
initRequesting
|
||||
initRequesting,
|
||||
setData
|
||||
};
|
||||
};
|
||||
|
@@ -46,7 +46,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
const model: ModelSchema = chat.modelId;
|
||||
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
|
||||
if (!modelConstantsData) {
|
||||
throw new Error('模型异常,请用 chatgpt 模型');
|
||||
throw new Error('模型加载异常');
|
||||
}
|
||||
|
||||
// 读取对话内容
|
||||
|
241
src/pages/api/chat/vectorGpt.ts
Normal file
241
src/pages/api/chat/vectorGpt.ts
Normal file
@@ -0,0 +1,241 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
|
||||
import { connectToDatabase, ModelData } from '@/service/mongo';
|
||||
import { getOpenAIApi, authChat } from '@/service/utils/chat';
|
||||
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
|
||||
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
|
||||
import { ChatItemType } from '@/types/chat';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import type { ModelSchema } from '@/types/mongoSchema';
|
||||
import { PassThrough } from 'stream';
|
||||
import { modelList } from '@/constants/model';
|
||||
import { pushChatBill } from '@/service/events/pushBill';
|
||||
import { connectRedis } from '@/service/redis';
|
||||
import { VecModelDataIndex } from '@/constants/redis';
|
||||
import { vectorToBuffer } from '@/utils/tools';
|
||||
|
||||
/**
 * Extract the unique `dataId` values from a raw `FT.SEARCH` reply.
 *
 * The reply is read as `[total, key, fields, key, fields, ...]`: the field arrays sit at the even
 * indexes starting from 2 and, because the query uses `RETURN 1 dataId`, the value is at position 1
 * of each field array. First-seen order is kept so the most relevant hits stay first.
 */
function collectDataIds(searchReply: any[]): string[] {
  const ids = new Set<string>();
  for (let i = 2; i < searchReply.length; i += 2) {
    const dataId = searchReply[i]?.[1];
    if (dataId) ids.add(dataId);
  }
  return Array.from(ids);
}

/**
 * Knowledge-base chat endpoint.
 *
 * Flow: embed the user's prompt, look up similar vectors of the chat's model in Redis, load the
 * matching source texts from Mongo, prepend them as a system prompt and stream the chat completion
 * back to the client. Errors raised before streaming starts are returned as a JSON 500 response;
 * errors raised afterwards just end the stream.
 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  let step = 0; // becomes 1 once the streamed response has started
  const stream = new PassThrough();
  stream.on('error', () => {
    console.log('error: ', 'stream error');
    stream.destroy();
  });
  res.on('close', () => {
    stream.destroy();
  });
  res.on('error', () => {
    console.log('error: ', 'request error');
    stream.destroy();
  });

  try {
    const { chatId, prompt } = req.body as {
      prompt: ChatItemType;
      chatId: string;
    };

    const { authorization } = req.headers;
    if (!chatId || !prompt) {
      throw new Error('缺少参数');
    }

    await connectToDatabase();
    const redis = await connectRedis();

    const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization);

    const model: ModelSchema = chat.modelId;
    const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
    if (!modelConstantsData) {
      throw new Error('模型加载异常');
    }

    // Conversation so far plus the new prompt
    const prompts = [...chat.content, prompt];

    // Prefer the user's own key, fall back to the platform key
    const chatAPI = getOpenAIApi(userApiKey || systemKey);

    // Turn the prompt text into an embedding vector
    const promptVector = await chatAPI
      .createEmbedding(
        {
          model: 'text-embedding-ada-002',
          input: prompt.value
        },
        {
          timeout: 120000,
          httpsAgent
        }
      )
      .then((res) => res?.data?.data?.[0]?.embedding || []);

    const binary = vectorToBuffer(promptVector);

    // Vector range search restricted to this model's data; only `dataId` is returned per hit
    const redisData: any[] = await redis.sendCommand([
      'FT.SEARCH',
      `idx:${VecModelDataIndex}`,
      `@modelId:{${String(chat.modelId._id)}} @vector:[VECTOR_RANGE 0.2 $blob]`,
      // `@modelId:{${String(chat.modelId._id)}}=>[KNN 10 @vector $blob AS score]`,
      'RETURN',
      '1',
      'dataId',
      // 'SORTBY',
      // 'score',
      'PARAMS',
      '2',
      'blob',
      binary,
      'DIALECT',
      '2'
    ]);

    // De-duplicated ids of the matching data records
    const formatIds = collectDataIds(redisData);

    if (formatIds.length === 0) {
      throw new Error('对不起,我没有找到你的问题');
    }

    // Load the source text of each *unique* record from Mongo (previously the raw reply was
    // re-read here, so a record hit by several vectors was fetched and injected more than once)
    const textArr = (
      await Promise.all(
        formatIds.map((dataId) =>
          ModelData.findById(dataId)
            .select('text')
            .then((doc) => doc?.text || '')
        )
      )
    ).filter((item) => item);

    // Limit the knowledge text; 2800 is the budget handed to systemPromptFilter
    const systemPrompt = systemPromptFilter(textArr, 2800);

    prompts.unshift({
      obj: 'SYSTEM',
      value: `请根据下面的知识回答问题: ${systemPrompt}`
    });

    // Keep the prompt list within the model's context budget
    const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);

    // Map the internal speaker labels onto the chat-completion roles
    const map = {
      Human: ChatCompletionRequestMessageRoleEnum.User,
      AI: ChatCompletionRequestMessageRoleEnum.Assistant,
      SYSTEM: ChatCompletionRequestMessageRoleEnum.System
    };
    const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
      (item: ChatItemType) => ({
        role: map[item.obj],
        content: item.value
      })
    );
    // console.log(formatPrompts);

    // Scale the model's stored temperature (divided by 10) by the model's maximum temperature
    const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);

    const startTime = Date.now();
    // Send the completion request
    const chatResponse = await chatAPI.createChatCompletion(
      {
        model: model.service.chatModel,
        temperature: temperature,
        // max_tokens: modelConstantsData.maxToken,
        messages: formatPrompts,
        frequency_penalty: 0.5, // higher => less repeated content
        presence_penalty: -0.5, // higher => more likely to bring up new content
        stream: true
      },
      {
        timeout: 40000,
        responseType: 'stream',
        httpsAgent
      }
    );

    console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);

    // Start the streamed response
    res.setHeader('Content-Type', 'text/event-stream;charset=utf-8');
    res.setHeader('Access-Control-Allow-Origin', '*');
    res.setHeader('X-Accel-Buffering', 'no');
    res.setHeader('Cache-Control', 'no-cache, no-transform');
    step = 1;

    let responseContent = '';
    stream.pipe(res);

    const onParse = (event: ParsedEvent | ReconnectInterval) => {
      if (event.type !== 'event') return;
      const data = event.data;
      if (data === '[DONE]') return;
      try {
        const json = JSON.parse(data);
        const content: string = json?.choices?.[0]?.delta?.content || '';
        // Skip empty deltas and a leading bare newline
        if (!content || (responseContent === '' && content === '\n')) return;

        responseContent += content;
        // console.log('content:', content)
        !stream.destroyed && stream.push(content.replace(/\n/g, '<br/>'));
      } catch (error) {
        // A malformed event should not abort the whole answer, but it should be visible
        console.log('parse error', error);
      }
    };

    const decoder = new TextDecoder();
    // One parser for the whole response: it buffers partial events, so an event that is split
    // across two network chunks is still delivered
    const parser = createParser(onParse);
    try {
      for await (const chunk of chatResponse.data as any) {
        if (stream.destroyed) {
          // The client went away; ignore the rest of the upstream response
          break;
        }
        // `stream: true` keeps multi-byte characters intact when they straddle a chunk boundary
        parser.feed(decoder.decode(chunk, { stream: true }));
      }
    } catch (error) {
      console.log('pipe error', error);
    }
    // close stream
    !stream.destroyed && stream.push(null);
    stream.destroy();

    const promptsContent = formatPrompts.map((item) => item.content).join('');
    // Only bill when the platform's key was used
    pushChatBill({
      isPay: !userApiKey,
      modelName: model.service.modelName,
      userId,
      chatId,
      text: promptsContent + responseContent
    });
    // jsonRes(res);
  } catch (err: any) {
    if (step === 1) {
      // Streaming already started: just end the stream
      console.log('error,结束');
      stream.destroy();
    } else {
      res.status(500);
      jsonRes(res, {
        code: 500,
        error: err
      });
    }
  }
}
|
@@ -24,7 +24,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
if (!DataRecord) {
|
||||
throw new Error('找不到数据集');
|
||||
}
|
||||
const replaceText = text.replace(/[\r\n\\n]+/g, ' ');
|
||||
const replaceText = text.replace(/[\\n]+/g, ' ');
|
||||
|
||||
// 文本拆分成 chunk
|
||||
let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || [];
|
||||
@@ -35,7 +35,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
||||
chunks.forEach((chunk) => {
|
||||
splitText += chunk;
|
||||
const tokens = encode(splitText).length;
|
||||
if (tokens >= 980) {
|
||||
if (tokens >= 780) {
|
||||
dataItems.push({
|
||||
userId,
|
||||
dataId,
|
||||
|
@@ -3,7 +3,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase } from '@/service/mongo';
|
||||
import { authToken } from '@/service/utils/tools';
|
||||
import { ModelStatusEnum, modelList, ChatModelNameEnum } from '@/constants/model';
|
||||
import { ModelStatusEnum, modelList, ChatModelNameEnum, ChatModelNameMap } from '@/constants/model';
|
||||
import { Model } from '@/service/models/model';
|
||||
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
@@ -33,15 +33,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
|
||||
await connectToDatabase();
|
||||
|
||||
// 重名校验
|
||||
const authRepeatName = await Model.findOne({
|
||||
name,
|
||||
userId
|
||||
});
|
||||
if (authRepeatName) {
|
||||
throw new Error('模型名重复');
|
||||
}
|
||||
|
||||
// 上限校验
|
||||
const authCount = await Model.countDocuments({
|
||||
userId
|
||||
@@ -57,9 +48,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
status: ModelStatusEnum.running,
|
||||
service: {
|
||||
company: modelItem.serviceCompany,
|
||||
trainId: modelItem.trainName,
|
||||
chatModel: modelItem.model,
|
||||
modelName: modelItem.model
|
||||
trainId: '',
|
||||
chatModel: ChatModelNameMap[modelItem.model], // 聊天时用的模型
|
||||
modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作
|
||||
}
|
||||
});
|
||||
|
||||
|
@@ -5,8 +5,8 @@ import { authToken } from '@/service/utils/tools';
|
||||
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
let { modelId } = req.query as {
|
||||
modelId: string;
|
||||
let { dataId } = req.query as {
|
||||
dataId: string;
|
||||
};
|
||||
const { authorization } = req.headers;
|
||||
|
||||
@@ -14,7 +14,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
throw new Error('无权操作');
|
||||
}
|
||||
|
||||
if (!modelId) {
|
||||
if (!dataId) {
|
||||
throw new Error('缺少参数');
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
await connectToDatabase();
|
||||
|
||||
await ModelData.deleteOne({
|
||||
modelId,
|
||||
_id: dataId,
|
||||
userId
|
||||
});
|
||||
|
@@ -14,6 +14,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
pageNum: string;
|
||||
pageSize: string;
|
||||
};
|
||||
|
||||
const { authorization } = req.headers;
|
||||
|
||||
pageNum = +pageNum;
|
||||
@@ -41,7 +42,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
.limit(pageSize);
|
||||
|
||||
jsonRes(res, {
|
||||
data
|
||||
data: {
|
||||
pageNum,
|
||||
pageSize,
|
||||
data,
|
||||
total: await ModelData.countDocuments({
|
||||
modelId,
|
||||
userId
|
||||
})
|
||||
}
|
||||
});
|
||||
} catch (err) {
|
||||
jsonRes(res, {
|
||||
|
@@ -2,12 +2,14 @@ import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase, ModelData, Model } from '@/service/mongo';
|
||||
import { authToken } from '@/service/utils/tools';
|
||||
import { ModelDataSchema } from '@/types/mongoSchema';
|
||||
import { generateVector } from '@/service/events/generateVector';
|
||||
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
const { modelId, data } = req.body as {
|
||||
modelId: string;
|
||||
data: { q: string; a: string }[];
|
||||
data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
|
||||
};
|
||||
const { authorization } = req.headers;
|
||||
|
||||
@@ -43,6 +45,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
}))
|
||||
);
|
||||
|
||||
generateVector(true);
|
||||
|
||||
jsonRes(res, {
|
||||
data: model
|
||||
});
|
57
src/pages/api/model/data/pushModelDataSelectData.ts
Normal file
57
src/pages/api/model/data/pushModelDataSelectData.ts
Normal file
@@ -0,0 +1,57 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase, DataItem, ModelData } from '@/service/mongo';
|
||||
import { authToken } from '@/service/utils/tools';
|
||||
import { customAlphabet } from 'nanoid';
|
||||
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
|
||||
|
||||
/**
 * Copy the items of the selected data sets into a model's data collection.
 *
 * Body: `{ dataIds: string[]; modelId: string }`. For every `dataId`, the caller's `DataItem`
 * documents are loaded and each one becomes a `ModelData` document whose `q` list holds the item's
 * questions, each tagged with a fresh 12-character id. Responds with the source data items, or with
 * a 500 payload on any error.
 *
 * NOTE(review): unlike the sibling split endpoint, this does not check that `modelId` belongs to
 * the authenticated user — confirm whether that check is needed here.
 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  try {
    const { dataIds, modelId } = req.body as { dataIds: string[]; modelId: string };

    // Reject a missing model id or a non-array id list up front, so no document is ever
    // inserted with an undefined `modelId`
    if (!Array.isArray(dataIds) || !modelId) {
      throw new Error('参数错误');
    }
    await connectToDatabase();

    const { authorization } = req.headers;

    const userId = await authToken(authorization);

    // One query per data set, restricted to the caller's own items; results are flattened
    const dataItems = (
      await Promise.all(
        dataIds.map((dataId) =>
          DataItem.find<{ _id: string; result: { q: string }[]; text: string }>(
            {
              userId,
              dataId
            },
            'result text'
          )
        )
      )
    ).flat();

    // push data
    await ModelData.insertMany(
      dataItems.map((dataItem) => ({
        modelId: modelId,
        userId,
        text: dataItem.text,
        q: dataItem.result.map((qa) => ({
          id: nanoid(),
          text: qa.q
        }))
      }))
    );

    jsonRes(res, {
      data: dataItems
    });
  } catch (err) {
    jsonRes(res, {
      code: 500,
      error: err
    });
  }
}
|
@@ -5,9 +5,9 @@ import { authToken } from '@/service/utils/tools';
|
||||
|
||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||
try {
|
||||
let { modelId, answer } = req.body as {
|
||||
modelId: string;
|
||||
answer: string;
|
||||
let { dataId, text } = req.body as {
|
||||
dataId: string;
|
||||
text: string;
|
||||
};
|
||||
const { authorization } = req.headers;
|
||||
|
||||
@@ -15,7 +15,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
throw new Error('无权操作');
|
||||
}
|
||||
|
||||
if (!modelId) {
|
||||
if (!dataId) {
|
||||
throw new Error('缺少参数');
|
||||
}
|
||||
|
||||
@@ -26,11 +26,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
|
||||
await ModelData.updateOne(
|
||||
{
|
||||
modelId,
|
||||
_id: dataId,
|
||||
userId
|
||||
},
|
||||
{
|
||||
a: answer
|
||||
text
|
||||
}
|
||||
);
|
||||
|
||||
|
67
src/pages/api/model/data/splitData.ts
Normal file
67
src/pages/api/model/data/splitData.ts
Normal file
@@ -0,0 +1,67 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { connectToDatabase, SplitData, Model } from '@/service/mongo';
|
||||
import { authToken } from '@/service/utils/tools';
|
||||
import { generateQA } from '@/service/events/generateQA';
|
||||
import { encode } from 'gpt-token-utils';
|
||||
|
||||
/**
 * Split an uploaded text into token-bounded pieces and store them for later QA generation.
 *
 * Body: `{ text: string; modelId: string }`. The model must belong to the authenticated user.
 * Newlines (real or the literal two characters `\n`) are collapsed to spaces, the text is cut into
 * sentences, and sentences are accumulated until a piece reaches 980 tokens. Responds with the
 * sentence list and the normalised text, or with a 500 payload on any error.
 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  try {
    const { text, modelId } = req.body as { text: string; modelId: string };
    if (!text || !modelId) {
      throw new Error('参数错误');
    }
    await connectToDatabase();

    const { authorization } = req.headers;

    const userId = await authToken(authorization);

    // Make sure the model belongs to this user
    const model = await Model.findOne({
      _id: modelId,
      userId
    });

    if (!model) {
      throw new Error('无权操作该模型');
    }

    const replaceText = text.replace(/(\\n|\n)+/g, ' ');

    // Cut into sentences. The terminator is optional so that trailing text without closing
    // punctuation (or a text with no punctuation at all) is kept instead of silently dropped.
    const chunks = replaceText.match(/[^!?.。]+[!?.。]?/g) || [];

    const textList: string[] = [];
    let splitText = '';

    chunks.forEach((chunk) => {
      splitText += chunk;
      const tokens = encode(splitText).length;
      if (tokens >= 980) {
        textList.push(splitText);
        splitText = '';
      }
    });

    // Flush what is left after the last full piece; without this the tail of every document
    // (and the whole of any document shorter than 980 tokens) would never be stored
    if (splitText.trim()) {
      textList.push(splitText);
    }

    // Store the raw text together with its pieces
    await SplitData.create({
      userId,
      modelId,
      rawText: text,
      textList
    });

    // generateQA();

    jsonRes(res, {
      data: { chunks, replaceText }
    });
  } catch (err) {
    jsonRes(res, {
      code: 500,
      error: err
    });
  }
}
|
@@ -1,6 +1,6 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { jsonRes } from '@/service/response';
|
||||
import { Chat, Model, Training, connectToDatabase } from '@/service/mongo';
|
||||
import { Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo';
|
||||
import { authToken, getUserOpenaiKey } from '@/service/utils/tools';
|
||||
import { TrainingStatusEnum } from '@/constants/model';
|
||||
import { getOpenAIApi } from '@/service/utils/chat';
|
||||
@@ -26,16 +26,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
|
||||
await connectToDatabase();
|
||||
|
||||
// 删除模型
|
||||
await Model.deleteOne({
|
||||
_id: modelId,
|
||||
userId
|
||||
});
|
||||
|
||||
let requestQueue: any[] = [];
|
||||
// 删除对应的聊天
|
||||
await Chat.deleteMany({
|
||||
modelId
|
||||
});
|
||||
requestQueue.push(
|
||||
Chat.deleteMany({
|
||||
modelId
|
||||
})
|
||||
);
|
||||
|
||||
// 删除数据集
|
||||
requestQueue.push(
|
||||
ModelData.deleteMany({
|
||||
modelId
|
||||
})
|
||||
);
|
||||
|
||||
// 查看是否正在训练
|
||||
const training: TrainingItemType | null = await Training.findOne({
|
||||
@@ -56,9 +60,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
}
|
||||
|
||||
// 删除对应训练记录
|
||||
await Training.deleteMany({
|
||||
modelId
|
||||
});
|
||||
requestQueue.push(
|
||||
Training.deleteMany({
|
||||
modelId
|
||||
})
|
||||
);
|
||||
|
||||
// 删除模型
|
||||
requestQueue.push(
|
||||
Model.deleteOne({
|
||||
_id: modelId,
|
||||
userId
|
||||
})
|
||||
);
|
||||
await requestQueue;
|
||||
|
||||
jsonRes(res);
|
||||
} catch (err) {
|
||||
|
@@ -37,7 +37,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
||||
systemPrompt,
|
||||
intro,
|
||||
temperature,
|
||||
service,
|
||||
// service,
|
||||
security
|
||||
}
|
||||
);
|
||||
|
@@ -119,6 +119,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
|
||||
async (prompts: ChatSiteItemType) => {
|
||||
const urlMap: Record<string, string> = {
|
||||
[ChatModelNameEnum.GPT35]: '/api/chat/chatGpt',
|
||||
[ChatModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt',
|
||||
[ChatModelNameEnum.GPT3]: '/api/chat/gpt3'
|
||||
};
|
||||
|
||||
|
@@ -184,7 +184,7 @@ const DataList = () => {
|
||||
>
|
||||
导入
|
||||
</Button>
|
||||
<Menu>
|
||||
{/* <Menu>
|
||||
<MenuButton as={Button} mr={2} size={'sm'} isLoading={isExporting}>
|
||||
导出
|
||||
</MenuButton>
|
||||
@@ -200,7 +200,7 @@ const DataList = () => {
|
||||
</MenuItem>
|
||||
)}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
</Menu> */}
|
||||
|
||||
<Button
|
||||
size={'sm'}
|
||||
|
141
src/pages/model/detail/components/InputDataModal.tsx
Normal file
141
src/pages/model/detail/components/InputDataModal.tsx
Normal file
@@ -0,0 +1,141 @@
|
||||
import React, { useState, useCallback } from 'react';
|
||||
import {
|
||||
Box,
|
||||
IconButton,
|
||||
Flex,
|
||||
Button,
|
||||
Modal,
|
||||
ModalOverlay,
|
||||
ModalContent,
|
||||
ModalHeader,
|
||||
ModalCloseButton,
|
||||
Input,
|
||||
Textarea
|
||||
} from '@chakra-ui/react';
|
||||
import { useForm, useFieldArray } from 'react-hook-form';
|
||||
import { postModelDataInput } from '@/api/model';
|
||||
import { useToast } from '@/hooks/useToast';
|
||||
import { DeleteIcon } from '@chakra-ui/icons';
|
||||
import { customAlphabet } from 'nanoid';
|
||||
// Generates 12-character ids drawn from lowercase ASCII letters and digits.
// Used below to give every question entry in the import payload its own id.
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
|
||||
|
||||
// Shape of the manual-import form: one knowledge text plus a list of question phrasings.
// NOTE(review): inside this module the name shadows the global DOM `FormData` type.
type FormData = { text: string; q: { val: string }[] };
|
||||
|
||||
/**
 * Modal for manually importing a single knowledge entry into a model.
 *
 * The form holds one knowledge text and one or more question phrasings. On
 * submit the entry is posted via `postModelDataInput`; on success a toast is
 * shown and `onClose` then `onSuccess` are called. On failure an error toast
 * is shown and the modal stays open so the user can retry.
 *
 * @param onClose   Called to close the modal (cancel, close button, or after a successful import).
 * @param onSuccess Called after a successful import, following `onClose`.
 * @param modelId   Id of the model the data is imported into.
 */
const InputDataModal = ({
  onClose,
  onSuccess,
  modelId
}: {
  onClose: () => void;
  onSuccess: () => void;
  modelId: string;
}) => {
  const [importing, setImporting] = useState(false);
  const { toast } = useToast();

  const { register, handleSubmit, control } = useForm<FormData>({
    defaultValues: {
      text: '',
      q: [{ val: '' }]
    }
  });
  // Dynamic list of question phrasings attached to the knowledge text.
  const {
    fields: inputQ,
    append: appendQ,
    remove: removeQ
  } = useFieldArray({
    control,
    name: 'q'
  });

  /* Submit the validated form data to the server */
  const sureImportData = useCallback(
    async (e: FormData) => {
      setImporting(true);

      try {
        await postModelDataInput({
          modelId: modelId,
          data: [
            {
              text: e.text,
              // every question gets a freshly generated id
              q: e.q.map((item) => ({
                id: nanoid(),
                text: item.val
              }))
            }
          ]
        });

        toast({
          title: '导入数据成功,需要一段时间训练',
          status: 'success'
        });
        onClose();
        onSuccess();
      } catch (err: unknown) {
        // Surface the failure instead of only logging it, so the user knows
        // the import did not go through.
        console.error(err);
        toast({
          title: typeof err === 'string' ? err : '导入数据失败',
          status: 'error'
        });
      } finally {
        // Always release the loading state, whichever path was taken.
        setImporting(false);
      }
    },
    [modelId, onClose, onSuccess, toast]
  );

  return (
    <Modal isOpen={true} onClose={onClose}>
      <ModalOverlay />
      <ModalContent maxW={'min(900px, 90vw)'} maxH={'80vh'} position={'relative'}>
        <ModalHeader>手动导入</ModalHeader>
        <ModalCloseButton />
        <Box px={6} pb={2} overflowY={'auto'}>
          <Box mb={2}>知识点:</Box>
          <Textarea
            mb={4}
            placeholder="知识点"
            rows={3}
            maxH={'200px'}
            {...register(`text`, {
              required: '知识点'
            })}
          />
          {inputQ.map((item, index) => (
            <Box key={item.id} mb={5}>
              <Box mb={2}>问法{index + 1}:</Box>
              <Flex>
                <Input
                  placeholder="问法"
                  {...register(`q.${index}.val`, {
                    required: '问法不能为空'
                  })}
                ></Input>
                {/* at least one question must remain, so hide delete on the last one */}
                {inputQ.length > 1 && (
                  <IconButton
                    icon={<DeleteIcon />}
                    aria-label={'delete'}
                    colorScheme={'gray'}
                    variant={'unstyled'}
                    onClick={() => removeQ(index)}
                  />
                )}
              </Flex>
            </Box>
          ))}
        </Box>

        <Flex px={6} pt={2} pb={4}>
          <Button alignSelf={'flex-start'} variant={'outline'} onClick={() => appendQ({ val: '' })}>
            增加问法
          </Button>
          <Box flex={1}></Box>
          <Button variant={'outline'} mr={3} onClick={onClose}>
            取消
          </Button>
          <Button isLoading={importing} onClick={handleSubmit(sureImportData)}>
            确认导入
          </Button>
        </Flex>
      </ModalContent>
    </Modal>
  );
};
|
||||
|
||||
export default InputDataModal;
|
202
src/pages/model/detail/components/ModelDataCard.tsx
Normal file
202
src/pages/model/detail/components/ModelDataCard.tsx
Normal file
@@ -0,0 +1,202 @@
|
||||
import React, { useCallback } from 'react';
|
||||
import {
|
||||
Box,
|
||||
TableContainer,
|
||||
Table,
|
||||
Thead,
|
||||
Tbody,
|
||||
Tr,
|
||||
Th,
|
||||
Td,
|
||||
IconButton,
|
||||
Flex,
|
||||
Button,
|
||||
useDisclosure,
|
||||
Textarea,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuList,
|
||||
MenuItem
|
||||
} from '@chakra-ui/react';
|
||||
import type { ModelSchema } from '@/types/mongoSchema';
|
||||
import { ModelDataSchema } from '@/types/mongoSchema';
|
||||
import { ModelDataStatusMap } from '@/constants/model';
|
||||
import { usePaging } from '@/hooks/usePaging';
|
||||
import ScrollData from '@/components/ScrollData';
|
||||
import { getModelDataList, delOneModelData, putModelDataById } from '@/api/model';
|
||||
import { DeleteIcon, RepeatIcon } from '@chakra-ui/icons';
|
||||
import { useToast } from '@/hooks/useToast';
|
||||
import { useLoading } from '@/hooks/useLoading';
|
||||
import dynamic from 'next/dynamic';
|
||||
|
||||
// Manual-input modal (default export of ./InputDataModal), loaded lazily via next/dynamic.
const InputModel = dynamic(() => import('./InputDataModal'));
|
||||
// File-import modal (default export of ./SelectFileModal), loaded lazily via next/dynamic.
const SelectModel = dynamic(() => import('./SelectFileModal'));
|
||||
|
||||
/**
 * Card listing the data entries (questions + answer text) of a model.
 *
 * Entries are loaded page by page through `usePaging`. Each row lets the user
 * edit the answer text in place (saved on blur) or delete the entry, and the
 * header offers manual and file import through lazily loaded modals.
 *
 * @param model The model whose data is listed; `model._id` is used as the query parameter.
 */
const ModelDataCard = ({ model }: { model: ModelSchema }) => {
  const { toast } = useToast();
  const { Loading } = useLoading();

  const {
    nextPage,
    isLoadAll,
    requesting,
    data: modelDataList,
    total,
    setData,
    getData
  } = usePaging<ModelDataSchema>({
    api: getModelDataList,
    pageSize: 20,
    params: {
      modelId: model._id
    }
  });

  /**
   * Persist a new answer text for one entry. Local state is only updated once
   * the server accepted the change, so a failed save is retried on the next
   * blur; failures are reported with a toast instead of an unhandled rejection.
   */
  const updateAnswer = useCallback(
    async (dataId: string, text: string) => {
      try {
        await putModelDataById({
          dataId,
          text
        });
        setData((state) =>
          state.map((data) => ({
            ...data,
            text: data._id === dataId ? text : data.text
          }))
        );
        toast({
          title: '修改回答成功',
          status: 'success'
        });
      } catch (err: unknown) {
        console.error(err);
        toast({
          title: typeof err === 'string' ? err : '修改回答失败',
          status: 'error'
        });
      }
    },
    [setData, toast]
  );

  /**
   * Delete one entry. The row is removed from the list only after the server
   * call succeeded; a failure leaves the row in place and shows a toast.
   */
  const removeData = useCallback(
    async (dataId: string) => {
      try {
        await delOneModelData(dataId);
        setData((state) => state.filter((data) => data._id !== dataId));
      } catch (err: unknown) {
        console.error(err);
        toast({
          title: typeof err === 'string' ? err : '删除数据失败',
          status: 'error'
        });
      }
    },
    [setData, toast]
  );

  const {
    isOpen: isOpenInputModal,
    onOpen: onOpenInputModal,
    onClose: onCloseInputModal
  } = useDisclosure();
  const {
    isOpen: isOpenSelectModal,
    onOpen: onOpenSelectModal,
    onClose: onCloseSelectModal
  } = useDisclosure();

  return (
    <>
      <Flex>
        <Box fontWeight={'bold'} fontSize={'lg'} flex={1}>
          模型数据: {total}组{' '}
          <Box as={'span'} fontSize={'sm'}>
            (测试版本)
          </Box>
        </Box>
        <IconButton
          icon={<RepeatIcon />}
          aria-label={'refresh'}
          variant={'outline'}
          mr={4}
          onClick={() => getData(1, true)}
        />
        <Menu>
          <MenuButton as={Button}>导入</MenuButton>
          <MenuList>
            <MenuItem onClick={onOpenInputModal}>手动输入</MenuItem>
            <MenuItem onClick={onOpenSelectModal}>文件导入</MenuItem>
          </MenuList>
        </Menu>
      </Flex>
      <ScrollData
        h={'100%'}
        px={6}
        mt={3}
        isLoadAll={isLoadAll}
        requesting={requesting}
        nextPage={nextPage}
        position={'relative'}
      >
        <TableContainer mt={4}>
          <Table variant={'simple'}>
            <Thead>
              <Tr>
                <Th>Question</Th>
                <Th>Text</Th>
                <Th>Status</Th>
                <Th></Th>
              </Tr>
            </Thead>
            <Tbody>
              {modelDataList.map((item) => (
                <Tr key={item._id}>
                  <Td w={'350px'}>
                    {item.q.map((question, i) => (
                      <Box
                        key={question.id}
                        fontSize={'xs'}
                        w={'100%'}
                        whiteSpace={'pre-wrap'}
                        _notLast={{ mb: 1 }}
                      >
                        Q{i + 1}:{' '}
                        <Box as={'span'} userSelect={'all'}>
                          {question.text}
                        </Box>
                      </Box>
                    ))}
                  </Td>
                  <Td minW={'200px'}>
                    <Textarea
                      w={'100%'}
                      h={'100%'}
                      defaultValue={item.text}
                      fontSize={'xs'}
                      resize={'both'}
                      onBlur={(e) => {
                        // Read the value once, before any async work, and
                        // only save when it differs from the stored text.
                        const newText = e.target.value;
                        if (newText !== item.text) {
                          updateAnswer(item._id, newText);
                        }
                      }}
                    ></Textarea>
                  </Td>
                  <Td w={'100px'}>{ModelDataStatusMap[item.status]}</Td>
                  <Td>
                    <IconButton
                      icon={<DeleteIcon />}
                      variant={'outline'}
                      colorScheme={'gray'}
                      aria-label={'delete'}
                      size={'sm'}
                      onClick={() => removeData(item._id)}
                    />
                  </Td>
                </Tr>
              ))}
            </Tbody>
          </Table>
        </TableContainer>
        <Loading loading={requesting} fixed={false} />
      </ScrollData>
      {isOpenInputModal && (
        <InputModel
          modelId={model._id}
          onClose={onCloseInputModal}
          onSuccess={() => getData(1, true)}
        />
      )}
      {isOpenSelectModal && (
        <SelectModel
          modelId={model._id}
          onClose={onCloseSelectModal}
          onSuccess={() => getData(1, true)}
        />
      )}
    </>
  );
};
|
||||
|
||||
export default ModelDataCard;
|
@@ -11,13 +11,28 @@ import {
|
||||
SliderFilledTrack,
|
||||
SliderThumb,
|
||||
SliderMark,
|
||||
Tooltip
|
||||
Tooltip,
|
||||
Button
|
||||
} from '@chakra-ui/react';
|
||||
import { QuestionOutlineIcon } from '@chakra-ui/icons';
|
||||
import type { ModelSchema } from '@/types/mongoSchema';
|
||||
import { UseFormReturn } from 'react-hook-form';
|
||||
import { modelList } from '@/constants/model';
|
||||
import { formatPrice } from '@/utils/user';
|
||||
import { useConfirm } from '@/hooks/useConfirm';
|
||||
|
||||
const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> }) => {
|
||||
const ModelEditForm = ({
|
||||
formHooks,
|
||||
canTrain,
|
||||
handleDelModel
|
||||
}: {
|
||||
formHooks: UseFormReturn<ModelSchema>;
|
||||
canTrain: boolean;
|
||||
handleDelModel: () => void;
|
||||
}) => {
|
||||
const { openConfirm, ConfirmChild } = useConfirm({
|
||||
content: '确认删除该模型?'
|
||||
});
|
||||
const { register, setValue, getValues } = formHooks;
|
||||
const [refresh, setRefresh] = useState(false);
|
||||
|
||||
@@ -29,7 +44,7 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
|
||||
</Flex>
|
||||
<FormControl mt={4}>
|
||||
<Flex alignItems={'center'}>
|
||||
<Box flex={'0 0 50px'} w={0}>
|
||||
<Box flex={'0 0 80px'} w={0}>
|
||||
名称:
|
||||
</Box>
|
||||
<Input
|
||||
@@ -39,7 +54,36 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
|
||||
></Input>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
<FormControl mt={4}>
|
||||
<Flex alignItems={'center'} mt={4}>
|
||||
<Box flex={'0 0 80px'} w={0}>
|
||||
底层模型:
|
||||
</Box>
|
||||
<Box>{getValues('service.modelName')}</Box>
|
||||
</Flex>
|
||||
<Flex alignItems={'center'} mt={4}>
|
||||
<Box flex={'0 0 80px'} w={0}>
|
||||
价格:
|
||||
</Box>
|
||||
<Box>
|
||||
{formatPrice(
|
||||
modelList.find((item) => item.model === getValues('service.modelName'))?.price || 0,
|
||||
1000
|
||||
)}
|
||||
元/1K tokens(包括上下文和回答)
|
||||
</Box>
|
||||
</Flex>
|
||||
<Flex mt={5} alignItems={'center'}>
|
||||
<Box flex={'0 0 80px'}>删除:</Box>
|
||||
<Button
|
||||
colorScheme={'gray'}
|
||||
variant={'outline'}
|
||||
size={'sm'}
|
||||
onClick={openConfirm(handleDelModel)}
|
||||
>
|
||||
删除模型
|
||||
</Button>
|
||||
</Flex>
|
||||
{/* <FormControl mt={4}>
|
||||
<Box mb={1}>介绍:</Box>
|
||||
<Textarea
|
||||
rows={5}
|
||||
@@ -47,7 +91,7 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
|
||||
{...register('intro')}
|
||||
placeholder={'模型的介绍,仅做展示,不影响模型的效果'}
|
||||
/>
|
||||
</FormControl>
|
||||
</FormControl> */}
|
||||
</Card>
|
||||
<Card p={4}>
|
||||
<Box fontWeight={'bold'}>模型效果</Box>
|
||||
@@ -94,15 +138,24 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
|
||||
</Flex>
|
||||
</FormControl>
|
||||
<Box mt={4}>
|
||||
<Box mb={1}>系统提示词</Box>
|
||||
<Textarea
|
||||
rows={6}
|
||||
maxLength={-1}
|
||||
{...register('systemPrompt')}
|
||||
placeholder={
|
||||
'模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n\n注意,改功能会影响对话的整体朝向!'
|
||||
}
|
||||
/>
|
||||
{canTrain ? (
|
||||
<Box fontWeight={'bold'}>
|
||||
训练的模型会自动根据知识库内容回答,无法设置系统prompt。注意:
|
||||
使用该模型,在对话时需要消耗等多的 tokens
|
||||
</Box>
|
||||
) : (
|
||||
<>
|
||||
<Box mb={1}>系统提示词</Box>
|
||||
<Textarea
|
||||
rows={6}
|
||||
maxLength={-1}
|
||||
{...register('systemPrompt')}
|
||||
placeholder={
|
||||
'模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n\n注意,改功能会影响对话的整体朝向!'
|
||||
}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</Box>
|
||||
</Card>
|
||||
{/* <Card p={4}>
|
||||
@@ -202,6 +255,7 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</Card> */}
|
||||
<ConfirmChild />
|
||||
</>
|
||||
);
|
||||
};
|
155
src/pages/model/detail/components/SelectFileModal.tsx
Normal file
155
src/pages/model/detail/components/SelectFileModal.tsx
Normal file
@@ -0,0 +1,155 @@
|
||||
import React, { useState, useCallback } from 'react';
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
Button,
|
||||
Modal,
|
||||
ModalOverlay,
|
||||
ModalContent,
|
||||
ModalHeader,
|
||||
ModalCloseButton,
|
||||
ModalBody
|
||||
} from '@chakra-ui/react';
|
||||
import { useToast } from '@/hooks/useToast';
|
||||
import { useSelectFile } from '@/hooks/useSelectFile';
|
||||
import { customAlphabet } from 'nanoid';
|
||||
import { encode } from 'gpt-token-utils';
|
||||
import { useConfirm } from '@/hooks/useConfirm';
|
||||
import { readTxtContent, readPdfContent, readDocContent } from '@/utils/tools';
|
||||
import { useMutation } from '@tanstack/react-query';
|
||||
import { postModelDataFileText } from '@/api/model';
|
||||
|
||||
// Generator for 12-character ids drawn from lowercase ASCII letters and digits.
// NOTE(review): nothing in this file calls it — candidate for removal together
// with the `customAlphabet` import; confirm before deleting.
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
|
||||
|
||||
// Comma-separated list of accepted file extensions. Passed to `useSelectFile`
// as `fileType` and rendered in the help text; keep it in sync with the
// extension switch in `onSelectFile`.
const fileExtension = '.txt,.doc,.docx,.pdf,.md';
|
||||
|
||||
/**
 * Modal for importing model data from files.
 *
 * The user picks one or more files (see `fileExtension`); their text content
 * is read in the browser, merged, previewed together with its character and
 * token counts, and — after a confirmation dialog — posted to the server via
 * `postModelDataFileText`.
 *
 * @param onClose   Called to close the modal (cancel, close button, or after a successful import).
 * @param onSuccess Called after a successful import, following `onClose`.
 * @param modelId   Id of the model the text is imported into.
 */
const SelectFileModal = ({
  onClose,
  onSuccess,
  modelId
}: {
  onClose: () => void;
  onSuccess: () => void;
  modelId: string;
}) => {
  const [selecting, setSelecting] = useState(false);
  const { toast } = useToast();
  const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
  const [fileText, setFileText] = useState('');
  const { openConfirm, ConfirmChild } = useConfirm({
    content: '确认导入该文件,需要一定时间进行拆解,该任务无法终止!'
  });

  // Tokenizing a large text is expensive; only redo it when the text changes
  // rather than on every render.
  const tokenCount = React.useMemo(() => encode(fileText).length, [fileText]);

  /* Read every selected file and merge the contents into one text */
  const onSelectFile = useCallback(
    async (e: File[]) => {
      setSelecting(true);
      try {
        const contents = await Promise.all(
          e.map((file) => {
            // The extension decides which reader is used; files with an
            // unsupported extension contribute an empty string.
            const extension = file?.name?.split('.').pop()?.toLowerCase();
            switch (extension) {
              case 'txt':
              case 'md':
                return readTxtContent(file);
              case 'pdf':
                return readPdfContent(file);
              case 'doc':
              case 'docx':
                return readDocContent(file);
              default:
                return '';
            }
          })
        );
        // Join the files with a newline and collapse runs of blank lines.
        setFileText(contents.join('\n').replace(/\n+/g, '\n'));
      } catch (error: unknown) {
        console.error(error);
        toast({
          title: typeof error === 'string' ? error : '解析文件失败',
          status: 'error'
        });
      } finally {
        setSelecting(false);
      }
    },
    [toast]
  );

  /* Upload the merged text; failures are reported through onError */
  const { mutate, isLoading } = useMutation({
    mutationFn: async () => {
      if (!fileText) return;
      await postModelDataFileText(modelId, fileText);
      toast({
        title: '导入数据成功,需要一段时间拆解和训练',
        status: 'success'
      });
      onClose();
      onSuccess();
    },
    onError() {
      toast({
        title: '导入文件失败',
        status: 'error'
      });
    }
  });

  return (
    <Modal isOpen={true} onClose={onClose}>
      <ModalOverlay />
      <ModalContent maxW={'min(900px, 90vw)'} position={'relative'}>
        <ModalHeader>文件导入</ModalHeader>
        <ModalCloseButton />

        <ModalBody>
          <Flex
            flexDirection={'column'}
            p={2}
            h={'100%'}
            alignItems={'center'}
            justifyContent={'center'}
            fontSize={'sm'}
          >
            <Button isLoading={selecting} onClick={onOpen}>
              选择文件
            </Button>
            <Box mt={2}>支持 {fileExtension} 文件. 会先对文本进行拆分,需要时间较长。</Box>
            <Box mt={2}>
              一共 {fileText.length} 个字,{tokenCount} 个tokens
            </Box>
            <Box
              h={'300px'}
              w={'100%'}
              overflow={'auto'}
              p={2}
              backgroundColor={'blackAlpha.50'}
              whiteSpace={'pre'}
              fontSize={'xs'}
            >
              {fileText}
            </Box>
          </Flex>
        </ModalBody>

        <Flex px={6} pt={2} pb={4}>
          <Box flex={1}></Box>
          <Button variant={'outline'} mr={3} onClick={onClose}>
            取消
          </Button>
          <Button isLoading={isLoading} isDisabled={fileText === ''} onClick={openConfirm(mutate)}>
            确认导入
          </Button>
        </Flex>
      </ModalContent>
      <ConfirmChild />
      <File onSelect={onSelectFile} />
    </Modal>
  );
};
|
||||
|
||||
export default SelectFileModal;
|
@@ -1,37 +1,27 @@
|
||||
import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react';
|
||||
import { useRouter } from 'next/router';
|
||||
import {
|
||||
getModelById,
|
||||
delModelById,
|
||||
postTrainModel,
|
||||
putModelTrainingStatus,
|
||||
putModelById
|
||||
} from '@/api/model';
|
||||
import { getModelById, delModelById, putModelTrainingStatus, putModelById } from '@/api/model';
|
||||
import { getChatSiteId } from '@/api/chat';
|
||||
import type { ModelSchema } from '@/types/mongoSchema';
|
||||
import { Card, Box, Flex, Button, Tag, Grid } from '@chakra-ui/react';
|
||||
import { useToast } from '@/hooks/useToast';
|
||||
import { useConfirm } from '@/hooks/useConfirm';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { formatModelStatus, ModelStatusEnum, modelList, defaultModel } from '@/constants/model';
|
||||
import { useGlobalStore } from '@/store/global';
|
||||
import { useScreen } from '@/hooks/useScreen';
|
||||
import ModelEditForm from './components/ModelEditForm';
|
||||
import Icon from '@/components/Iconfont';
|
||||
import { useQuery } from '@tanstack/react-query';
|
||||
import dynamic from 'next/dynamic';
|
||||
|
||||
const Training = dynamic(() => import('./components/Training'));
|
||||
const ModelDataCard = dynamic(() => import('./components/ModelDataCard'));
|
||||
|
||||
const ModelDetail = ({ modelId }: { modelId: string }) => {
|
||||
const { toast } = useToast();
|
||||
const router = useRouter();
|
||||
const { isPc, media } = useScreen();
|
||||
const { setLoading } = useGlobalStore();
|
||||
const { openConfirm, ConfirmChild } = useConfirm({
|
||||
content: '确认删除该模型?'
|
||||
});
|
||||
const SelectFileDom = useRef<HTMLInputElement>(null);
|
||||
|
||||
// const SelectFileDom = useRef<HTMLInputElement>(null);
|
||||
const [model, setModel] = useState<ModelSchema>(defaultModel);
|
||||
const formHooks = useForm<ModelSchema>({
|
||||
defaultValues: model
|
||||
@@ -39,7 +29,7 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
|
||||
|
||||
const canTrain = useMemo(() => {
|
||||
const openai = modelList.find((item) => item.model === model?.service.modelName);
|
||||
return openai && openai.trainName;
|
||||
return !!(openai && openai.trainName);
|
||||
}, [model]);
|
||||
|
||||
/* 加载模型数据 */
|
||||
@@ -91,34 +81,34 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
|
||||
}, [setLoading, model, router]);
|
||||
|
||||
/* 上传数据集,触发微调 */
|
||||
const startTraining = useCallback(
|
||||
async (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
if (!modelId || !e.target.files || e.target.files?.length === 0) return;
|
||||
setLoading(true);
|
||||
try {
|
||||
const file = e.target.files[0];
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
await postTrainModel(modelId, formData);
|
||||
// const startTraining = useCallback(
|
||||
// async (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
// if (!modelId || !e.target.files || e.target.files?.length === 0) return;
|
||||
// setLoading(true);
|
||||
// try {
|
||||
// const file = e.target.files[0];
|
||||
// const formData = new FormData();
|
||||
// formData.append('file', file);
|
||||
// await postTrainModel(modelId, formData);
|
||||
|
||||
toast({
|
||||
title: '开始训练...',
|
||||
status: 'success'
|
||||
});
|
||||
// toast({
|
||||
// title: '开始训练...',
|
||||
// status: 'success'
|
||||
// });
|
||||
|
||||
// 重新获取模型
|
||||
loadModel();
|
||||
} catch (err: any) {
|
||||
toast({
|
||||
title: err?.message || '上传文件失败',
|
||||
status: 'error'
|
||||
});
|
||||
console.log('error->', err);
|
||||
}
|
||||
setLoading(false);
|
||||
},
|
||||
[setLoading, loadModel, modelId, toast]
|
||||
);
|
||||
// // 重新获取模型
|
||||
// loadModel();
|
||||
// } catch (err: any) {
|
||||
// toast({
|
||||
// title: err?.message || '上传文件失败',
|
||||
// status: 'error'
|
||||
// });
|
||||
// console.log('error->', err);
|
||||
// }
|
||||
// setLoading(false);
|
||||
// },
|
||||
// [setLoading, loadModel, modelId, toast]
|
||||
// );
|
||||
|
||||
/* 点击更新模型状态 */
|
||||
const handleClickUpdateStatus = useCallback(async () => {
|
||||
@@ -250,87 +240,34 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
|
||||
)}
|
||||
</Card>
|
||||
<Grid mt={5} gridTemplateColumns={media('1fr 1fr', '1fr')} gridGap={5}>
|
||||
<ModelEditForm formHooks={formHooks} />
|
||||
<ModelEditForm formHooks={formHooks} handleDelModel={handleDelModel} canTrain={canTrain} />
|
||||
|
||||
{canTrain && (
|
||||
{/* {canTrain && (
|
||||
<Card p={4}>
|
||||
<Training model={model} />
|
||||
</Card>
|
||||
)} */}
|
||||
{canTrain && model._id && (
|
||||
<Card
|
||||
p={4}
|
||||
height={'700px'}
|
||||
{...media(
|
||||
{
|
||||
gridColumnStart: 1,
|
||||
gridColumnEnd: 3
|
||||
},
|
||||
{}
|
||||
)}
|
||||
>
|
||||
<ModelDataCard model={model} />
|
||||
</Card>
|
||||
)}
|
||||
|
||||
<Card p={4}>
|
||||
<Box fontWeight={'bold'} fontSize={'lg'}>
|
||||
神奇操作
|
||||
</Box>
|
||||
<Flex mt={5} alignItems={'center'}>
|
||||
<Box flex={'0 0 80px'}>模型微调:</Box>
|
||||
<Button
|
||||
size={'sm'}
|
||||
onClick={() => {
|
||||
SelectFileDom.current?.click();
|
||||
}}
|
||||
title={!canTrain ? '模型不支持微调' : ''}
|
||||
isDisabled={!canTrain}
|
||||
>
|
||||
上传数据集
|
||||
</Button>
|
||||
<Flex
|
||||
as={'a'}
|
||||
href="/TrainingTemplate.jsonl"
|
||||
download
|
||||
ml={5}
|
||||
cursor={'pointer'}
|
||||
alignItems={'center'}
|
||||
color={'blue.500'}
|
||||
>
|
||||
<Icon name={'icon-yunxiazai'} color={'#3182ce'} />
|
||||
下载模板
|
||||
</Flex>
|
||||
</Flex>
|
||||
{/* 提示 */}
|
||||
<Box mt={3} py={3} color={'blackAlpha.600'}>
|
||||
<Box as={'li'} lineHeight={1.9}>
|
||||
暂时需要使用自己的openai key
|
||||
</Box>
|
||||
<Box as={'li'} lineHeight={1.9}>
|
||||
可以使用
|
||||
<Box
|
||||
as={'span'}
|
||||
fontWeight={'bold'}
|
||||
textDecoration={'underline'}
|
||||
color={'blackAlpha.800'}
|
||||
mx={2}
|
||||
cursor={'pointer'}
|
||||
onClick={() => router.push('/data/list')}
|
||||
>
|
||||
数据拆分
|
||||
</Box>
|
||||
功能,从任意文本中提取数据集。
|
||||
</Box>
|
||||
<Box as={'li'} lineHeight={1.9}>
|
||||
每行包括一个 prompt 和一个 completion
|
||||
</Box>
|
||||
<Box as={'li'} lineHeight={1.9}>
|
||||
prompt 必须以 {'</s>'} 结尾
|
||||
</Box>
|
||||
<Box as={'li'} lineHeight={1.9}>
|
||||
completion 开头必须有一个空格,必须以 {'</s>'} 结尾
|
||||
</Box>
|
||||
</Box>
|
||||
<Flex mt={5} alignItems={'center'}>
|
||||
<Box flex={'0 0 80px'}>删除模型:</Box>
|
||||
<Button colorScheme={'red'} size={'sm'} onClick={openConfirm(handleDelModel)}>
|
||||
删除模型
|
||||
</Button>
|
||||
</Flex>
|
||||
</Card>
|
||||
</Grid>
|
||||
|
||||
{/* 文件选择 */}
|
||||
<Box position={'absolute'} w={0} h={0} overflow={'hidden'}>
|
||||
{/* <Box position={'absolute'} w={0} h={0} overflow={'hidden'}>
|
||||
<input ref={SelectFileDom} type="file" accept=".jsonl" onChange={startTraining} />
|
||||
</Box>
|
||||
<ConfirmChild />
|
||||
</Box> */}
|
||||
</>
|
||||
);
|
||||
};
|
@@ -1,29 +1,26 @@
|
||||
import { DataItem } from '@/service/mongo';
|
||||
import { SplitData, ModelData } from '@/service/mongo';
|
||||
import { getOpenAIApi } from '@/service/utils/chat';
|
||||
import { httpsAgent, getOpenApiKey } from '@/service/utils/tools';
|
||||
import type { ChatCompletionRequestMessage } from 'openai';
|
||||
import { DataItemSchema } from '@/types/mongoSchema';
|
||||
import { ChatModelNameEnum } from '@/constants/model';
|
||||
import { pushSplitDataBill } from '@/service/events/pushBill';
|
||||
import { generateVector } from './generateVector';
|
||||
import { customAlphabet } from 'nanoid';
|
||||
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
|
||||
|
||||
export async function generateQA(next = false): Promise<any> {
|
||||
if (process.env.NODE_ENV === 'development') return;
|
||||
|
||||
if (global.generatingQA && !next) return;
|
||||
global.generatingQA = true;
|
||||
|
||||
const systemPrompt: ChatCompletionRequestMessage = {
|
||||
role: 'system',
|
||||
content: `总结助手。我会向你发送一段长文本,请从中总结出5至15个问题和答案,答案请尽量详细,请按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n`
|
||||
content: `总结助手。我会向你发送一段长文本,请从中总结出5至15个问题和答案,答案请尽量详细,并按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n`
|
||||
};
|
||||
let dataItem: DataItemSchema | null = null;
|
||||
|
||||
try {
|
||||
// 找出一个需要生成的 dataItem
|
||||
dataItem = await DataItem.findOne({
|
||||
status: { $ne: 0 },
|
||||
times: { $gt: 0 },
|
||||
type: 'QA'
|
||||
const dataItem = await SplitData.findOne({
|
||||
textList: { $exists: true, $ne: [] }
|
||||
});
|
||||
|
||||
if (!dataItem) {
|
||||
@@ -32,10 +29,13 @@ export async function generateQA(next = false): Promise<any> {
|
||||
return;
|
||||
}
|
||||
|
||||
// 更新状态为生成中
|
||||
await DataItem.findByIdAndUpdate(dataItem._id, {
|
||||
status: 2
|
||||
});
|
||||
// 弹出文本
|
||||
await SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } });
|
||||
|
||||
const text = dataItem.textList[dataItem.textList.length - 1];
|
||||
if (!text) {
|
||||
throw new Error('无文本');
|
||||
}
|
||||
|
||||
// 获取 openapi Key
|
||||
let userApiKey, systemKey;
|
||||
@@ -44,10 +44,10 @@ export async function generateQA(next = false): Promise<any> {
|
||||
userApiKey = key.userApiKey;
|
||||
systemKey = key.systemKey;
|
||||
} catch (error) {
|
||||
// 余额不够了, 把用户所有记录改成闲置
|
||||
await DataItem.updateMany({
|
||||
userId: dataItem.userId,
|
||||
status: 0
|
||||
// 余额不够了, 清空该记录
|
||||
await SplitData.findByIdAndUpdate(dataItem._id, {
|
||||
textList: [],
|
||||
errorText: '余额不足,生成数据集任务终止'
|
||||
});
|
||||
throw new Error('获取 openai key 失败');
|
||||
}
|
||||
@@ -59,84 +59,71 @@ export async function generateQA(next = false): Promise<any> {
|
||||
// 获取 openai 请求实例
|
||||
const chatAPI = getOpenAIApi(userApiKey || systemKey);
|
||||
// 请求 chatgpt 获取回答
|
||||
const response = await Promise.allSettled(
|
||||
[0.2, 0.8].map(
|
||||
(temperature) =>
|
||||
chatAPI
|
||||
.createChatCompletion(
|
||||
{
|
||||
model: ChatModelNameEnum.GPT35,
|
||||
temperature: temperature,
|
||||
n: 1,
|
||||
messages: [
|
||||
systemPrompt,
|
||||
{
|
||||
role: 'user',
|
||||
content: dataItem?.text || ''
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
timeout: 120000,
|
||||
httpsAgent
|
||||
}
|
||||
)
|
||||
.then((res) => ({
|
||||
rawContent: res?.data.choices[0].message?.content || '',
|
||||
result: splitText(res?.data.choices[0].message?.content || '')
|
||||
})) // 从 content 中提取 QA
|
||||
)
|
||||
);
|
||||
// 过滤出成功的响应
|
||||
const successResponse: {
|
||||
rawContent: string;
|
||||
result: { q: string; a: string }[];
|
||||
}[] = response.filter((item) => item.status === 'fulfilled').map((item: any) => item.value);
|
||||
|
||||
const rawContents = successResponse.map((item) => item.rawContent);
|
||||
const results = successResponse.map((item) => item.result).flat();
|
||||
|
||||
// 插入数据库,并修改状态
|
||||
await DataItem.findByIdAndUpdate(dataItem._id, {
|
||||
status: 0,
|
||||
$push: {
|
||||
rawResponse: {
|
||||
$each: successResponse.map((item) => item.rawContent)
|
||||
const response = await chatAPI
|
||||
.createChatCompletion(
|
||||
{
|
||||
model: ChatModelNameEnum.GPT35,
|
||||
temperature: 0.2,
|
||||
n: 1,
|
||||
messages: [
|
||||
systemPrompt,
|
||||
{
|
||||
role: 'user',
|
||||
content: text
|
||||
}
|
||||
]
|
||||
},
|
||||
result: {
|
||||
$each: results
|
||||
{
|
||||
timeout: 120000,
|
||||
httpsAgent
|
||||
}
|
||||
}
|
||||
});
|
||||
)
|
||||
.then((res) => ({
|
||||
rawContent: res?.data.choices[0].message?.content || '',
|
||||
result: splitText(res?.data.choices[0].message?.content || '')
|
||||
})); // 从 content 中提取 QA
|
||||
|
||||
// 插入 modelData 表,生成向量
|
||||
await ModelData.insertMany(
|
||||
response.result.map((item) => ({
|
||||
modelId: dataItem.modelId,
|
||||
userId: dataItem.userId,
|
||||
text: item.a,
|
||||
q: [
|
||||
{
|
||||
id: nanoid(),
|
||||
text: item.q
|
||||
}
|
||||
],
|
||||
status: 1
|
||||
}))
|
||||
);
|
||||
|
||||
console.log(
|
||||
'生成QA成功,time:',
|
||||
`${(Date.now() - startTime) / 1000}s`,
|
||||
'QA数量:',
|
||||
results.length
|
||||
response.result.length
|
||||
);
|
||||
|
||||
// 计费
|
||||
pushSplitDataBill({
|
||||
isPay: !userApiKey && results.length > 0,
|
||||
isPay: !userApiKey && response.result.length > 0,
|
||||
userId: dataItem.userId,
|
||||
type: 'QA',
|
||||
text: systemPrompt.content + dataItem.text + rawContents.join('')
|
||||
text: systemPrompt.content + text + response.rawContent
|
||||
});
|
||||
} catch (error: any) {
|
||||
console.log('error: 生成QA错误', dataItem?._id);
|
||||
console.log('response:', error?.response);
|
||||
if (dataItem?._id) {
|
||||
await DataItem.findByIdAndUpdate(dataItem._id, {
|
||||
status: dataItem.times > 0 ? 1 : 0, // 还有重试次数则可以继续进行
|
||||
$inc: {
|
||||
// 剩余尝试次数-1
|
||||
times: -1
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
generateQA(true);
|
||||
generateQA(true);
|
||||
generateVector(true);
|
||||
} catch (error: any) {
|
||||
console.log(error);
|
||||
console.log('生成QA错误:', error?.response);
|
||||
|
||||
setTimeout(() => {
|
||||
generateQA(true);
|
||||
}, 10000);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
88
src/service/events/generateVector.ts
Normal file
88
src/service/events/generateVector.ts
Normal file
@@ -0,0 +1,88 @@
|
||||
import { getOpenAIApi } from '@/service/utils/chat';
|
||||
import { httpsAgent } from '@/service/utils/tools';
|
||||
import { ModelData } from '../models/modelData';
|
||||
import { connectRedis } from '../redis';
|
||||
import { VecModelDataIndex } from '@/constants/redis';
|
||||
|
||||
/**
 * Background worker that turns the questions of one pending ModelData document
 * into embeddings, stores each vector in Redis, and then schedules itself for
 * the next document.
 *
 * Only one worker chain runs at a time; `global.generatingVector` is the guard.
 *
 * @param next - `true` when the running chain re-invokes itself (bypasses the
 *   guard). External callers should omit it.
 * @returns A promise that resolves with no value once this round has finished.
 */
export async function generateVector(next = false): Promise<any> {
  // A chain is already running: drop external triggers.
  if (global.generatingVector && !next) return;
  global.generatingVector = true;

  try {
    const redis = await connectRedis();

    // Pick one document whose status is not 0 (status 0 is written below once
    // its vectors have been stored).
    const dataItem = await ModelData.findOne({
      status: { $ne: 0 }
    });

    if (!dataItem) {
      console.log('没有需要生成 【向量】 的数据');
      // Nothing left to do: release the guard so a later trigger can start a new chain.
      global.generatingVector = false;
      return;
    }

    // Embeddings are requested with the server-wide key from the environment.
    const openAiKey = process.env.OPENAIKEY as string;
    const chatAPI = getOpenAIApi(openAiKey);

    const dataId = String(dataItem._id);

    // One embedding request per question. Each vector is written to Redis as a
    // JSON document keyed by data id and question index.
    const response = await Promise.allSettled(
      dataItem.q.map(async (item, i) => {
        const res = await chatAPI.createEmbedding(
          {
            model: 'text-embedding-ada-002',
            input: item.text
          },
          {
            timeout: 120000,
            httpsAgent
          }
        );
        const vector = res?.data?.data?.[0]?.embedding || [];

        return redis.sendCommand([
          'JSON.SET',
          `${VecModelDataIndex}:${dataId}:${i}`,
          '$',
          JSON.stringify({
            dataId,
            modelId: String(dataItem.modelId),
            vector
          })
        ]);
      })
    );

    // Treat the round as failed only when no question could be stored at all.
    if (!response.some((item) => item.status === 'fulfilled')) {
      throw new Error(JSON.stringify(response));
    }

    // Mark the document as processed.
    await ModelData.findByIdAndUpdate(dataItem._id, {
      status: 0
    });

    console.log(`生成向量成功: ${dataItem._id}`);

    // Continue with the next document after a short pause.
    setTimeout(() => {
      generateVector(true);
    }, 3000);
  } catch (error: any) {
    console.log(error);
    console.log('error: 生成向量错误', error?.response?.data);

    if (error?.response?.statusText === 'Too Many Requests') {
      console.log('次数限制,1分钟后尝试');
      // Rate limited: keep the guard held and retry in one minute.
      setTimeout(() => {
        generateVector(true);
      }, 60000);
      return;
    }

    // Any other failure ends this chain. Release the guard, otherwise every
    // later external call would return at the guard above and vector
    // generation would stay stalled until the process restarts.
    global.generatingVector = false;
  }
}
|
@@ -34,7 +34,7 @@ export const pushChatBill = async ({
|
||||
// 计算价格
|
||||
const unitPrice = modelItem?.price || 5;
|
||||
const price = unitPrice * tokens.length;
|
||||
console.log(`chat bill, price: ${formatPrice(price)}元`);
|
||||
console.log(`chat bill, unit price: ${unitPrice}, price: ${formatPrice(price)}元`);
|
||||
|
||||
try {
|
||||
// 插入 Bill 记录
|
||||
|
@@ -13,22 +13,23 @@ const ModelDataSchema = new Schema({
|
||||
ref: 'user',
|
||||
required: true
|
||||
},
|
||||
q: {
|
||||
text: {
|
||||
type: String,
|
||||
required: true
|
||||
},
|
||||
a: {
|
||||
type: String,
|
||||
default: ''
|
||||
q: {
|
||||
type: [
|
||||
{
|
||||
id: String, // 对应redis的key
|
||||
text: String
|
||||
}
|
||||
],
|
||||
default: []
|
||||
},
|
||||
status: {
|
||||
type: Number,
|
||||
enum: [0, 1, 2],
|
||||
enum: [0, 1], // 1 训练ing
|
||||
default: 1
|
||||
},
|
||||
createTime: {
|
||||
type: Date,
|
||||
default: () => new Date()
|
||||
}
|
||||
});
|
||||
|
||||
|
31
src/service/models/splitData.ts
Normal file
31
src/service/models/splitData.ts
Normal file
@@ -0,0 +1,31 @@
|
||||
/* 模型的知识库 */
|
||||
import { Schema, model, models, Model as MongoModel } from 'mongoose';
|
||||
import { ModelSplitDataSchema as SplitDataType } from '@/types/mongoSchema';
|
||||
|
||||
// Mongoose schema for the `splitData` collection: an owner, a target model,
// the raw text, a list of text pieces and an error-text field.
const SplitDataSchema = new Schema({
  // Reference to the owning user document.
  userId: { type: Schema.Types.ObjectId, ref: 'user', required: true },
  // Reference to the model the text belongs to.
  modelId: { type: Schema.Types.ObjectId, ref: 'model', required: true },
  // Original, unsplit text.
  rawText: { type: String, required: true },
  // Text pieces; empty by default.
  textList: { type: [String], default: [] },
  // Error text; empty by default.
  errorText: { type: String, default: '' }
});

// Reuse the already-compiled model when the module is evaluated again;
// compile it only when it is not registered yet.
export const SplitData: MongoModel<SplitDataType> =
  models['splitData'] ?? model('splitData', SplitDataSchema);
|
@@ -1,6 +1,7 @@
|
||||
import mongoose from 'mongoose';
|
||||
import { generateQA } from './events/generateQA';
|
||||
import { generateAbstract } from './events/generateAbstract';
|
||||
import { generateVector } from './events/generateVector';
|
||||
|
||||
/**
|
||||
* 连接 MongoDB 数据库
|
||||
@@ -27,7 +28,8 @@ export async function connectToDatabase(): Promise<void> {
|
||||
}
|
||||
|
||||
generateQA();
|
||||
generateAbstract();
|
||||
// generateAbstract();
|
||||
generateVector();
|
||||
}
|
||||
|
||||
export * from './models/authCode';
|
||||
@@ -40,3 +42,4 @@ export * from './models/bill';
|
||||
export * from './models/pay';
|
||||
export * from './models/data';
|
||||
export * from './models/dataItem';
|
||||
export * from './models/splitData';
|
||||
|
45
src/service/redis.ts
Normal file
45
src/service/redis.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { createClient } from 'redis';
|
||||
import { customAlphabet } from 'nanoid';
|
||||
// Id generator: 10-character ids drawn from lowercase letters and digits.
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 10);
|
||||
|
||||
/**
 * Returns the process-wide Redis client, creating and connecting it on demand.
 *
 * The client is cached on `global.redisClient` and reused while it is open.
 * A cached client that is no longer open is dropped and replaced.
 *
 * @returns The connected client, with database 0 selected.
 * Rejects with the string 'redis 连接失败' when connecting or selecting the
 * database fails (kept as a string for existing callers).
 */
export const connectRedis = async () => {
  const cached = global.redisClient;

  // Still open: reuse it instead of connecting again.
  if (cached?.isOpen) {
    return cached;
  }

  // A cached client that is already closed cannot be disconnected again
  // (node-redis rejects `disconnect()` on a closed client), so only the stale
  // reference is dropped before a fresh client is created.
  global.redisClient = null;

  try {
    global.redisClient = createClient({
      url: process.env.REDIS_URL
    });
    // Work on a local reference: the listeners below may clear the global
    // while `connect()` is still pending.
    const client = global.redisClient;

    // Clear the global only if it still points at this client, so a stale
    // client's events never discard a newer connection.
    const release = () => {
      if (global.redisClient === client) {
        global.redisClient = null;
      }
    };

    client.on('error', (err) => {
      console.log('Redis Client Error', err);
      release();
    });
    client.on('end', release);
    client.on('ready', () => {
      console.log('redis connected');
    });

    await client.connect();

    // 0 - test database, 1 - production
    await client.select(0);

    // An 'error' event during the initial connect may have cleared the global;
    // publish the now-connected client again.
    global.redisClient = client;

    return client;
  } catch (error) {
    console.log(error, '==');
    global.redisClient = null;
    return Promise.reject('redis 连接失败');
  }
};
|
||||
|
||||
/**
 * Builds a key of the form `<prefix>:<random id>`.
 *
 * @param prefix - Text placed before the colon; defaults to an empty string.
 * @returns The prefix, a colon, and a freshly generated id.
 */
export const getKey = (prefix = '') => {
  const suffix = nanoid();
  return `${prefix}:${suffix}`;
};
|
@@ -119,3 +119,21 @@ export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) =>
|
||||
|
||||
return systemPrompt ? [systemPrompt, ...res] : res;
|
||||
};
|
||||
|
||||
/**
 * Truncates system content: joins prompts front to back, each followed by a
 * newline, and stops once the accumulated text reaches `maxTokens` tokens.
 *
 * The prompt that crosses the limit is kept, so the result can exceed
 * `maxTokens`.
 *
 * @param prompts - Prompt strings in priority order.
 * @param maxTokens - Token count at which accumulation stops.
 * @returns The concatenated text (empty string when `prompts` is empty).
 */
export const systemPromptFilter = (prompts: string[], maxTokens: number) => {
  let joined = '';

  for (const prompt of prompts) {
    joined += `${prompt}\n`;

    // Re-encode the whole accumulated text to get its current token count.
    if (encode(joined).length >= maxTokens) {
      break;
    }
  }

  return joined;
};
|
||||
|
3
src/types/index.d.ts
vendored
3
src/types/index.d.ts
vendored
@@ -1,9 +1,12 @@
|
||||
import type { Mongoose } from 'mongoose';
|
||||
import type { RedisClientType } from 'redis';
|
||||
|
||||
declare global {
|
||||
var mongodb: Mongoose | string | null;
|
||||
var redisClient: RedisClientType | null;
|
||||
var generatingQA: boolean;
|
||||
var generatingAbstract: boolean;
|
||||
var generatingVector: boolean;
|
||||
var QRCode: any;
|
||||
interface Window {
|
||||
['pdfjs-dist/build/pdf']: any;
|
||||
|
9
src/types/model.d.ts
vendored
9
src/types/model.d.ts
vendored
@@ -8,3 +8,12 @@ export interface ModelUpdateParams {
|
||||
service: ModelSchema.service;
|
||||
security: ModelSchema.security;
|
||||
}
|
||||
|
||||
/** A single model data item: one question/answer pair plus its owner ids. */
export interface ModelDataItemType {
|
||||
id: string;
|
||||
status: 0 | 1; // original note: 1 = vector generation finished. NOTE(review): generateVector writes 0 after storing vectors and the schema comment marks 1 as "training" - confirm which is intended
|
||||
q: string; // question text
|
||||
a: string; // source text
|
||||
modelId: string;
|
||||
userId: string;
|
||||
}
|
||||
|
22
src/types/mongoSchema.d.ts
vendored
22
src/types/mongoSchema.d.ts
vendored
@@ -51,12 +51,26 @@ export interface ModelPopulate extends ModelSchema {
|
||||
userId: UserModelSchema;
|
||||
}
|
||||
|
||||
export type ModelDataType = 0 | 1;
|
||||
export interface ModelDataSchema {
|
||||
_id: string;
|
||||
q: string;
|
||||
a: string;
|
||||
status: 0 | 1 | 2;
|
||||
createTime: Date;
|
||||
modelId: string;
|
||||
userId: string;
|
||||
text: string;
|
||||
q: {
|
||||
id: string;
|
||||
text: string;
|
||||
}[];
|
||||
status: ModelDataType;
|
||||
}
|
||||
|
||||
/**
 * Document shape for the `splitData` collection; used as the model type in
 * src/service/models/splitData.ts. Field names match that schema.
 */
export interface ModelSplitDataSchema {
|
||||
_id: string;
|
||||
userId: string;
|
||||
modelId: string;
|
||||
rawText: string;
|
||||
errorText: string;
|
||||
textList: string[];
|
||||
}
|
||||
|
||||
export interface TrainingSchema {
|
||||
|
6
src/types/redis.d.ts
vendored
Normal file
6
src/types/redis.d.ts
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
/**
 * A model-data vector record as read from Redis.
 * NOTE(review): generateVector writes only `dataId`, `modelId` and `vector`
 * into the JSON payload - presumably `id` is filled from the Redis key; confirm.
 */
export interface RedisModelDataItemType {
|
||||
id: string;
|
||||
vector: number[];
|
||||
dataId: string;
|
||||
modelId: string;
|
||||
}
|
@@ -124,3 +124,15 @@ export const readDocContent = (file: File) =>
|
||||
reject('读取 doc 文件失败');
|
||||
};
|
||||
});
|
||||
|
||||
/**
 * Serialises a numeric vector as consecutive 32-bit floats in little-endian
 * byte order.
 *
 * @param vector - Values to encode; each is rounded to float32 precision.
 * @returns A Buffer of `vector.length * 4` bytes.
 */
export const vectorToBuffer = (vector: number[]) => {
  const stride = Float32Array.BYTES_PER_ELEMENT;
  const floats = new Float32Array(vector);
  const bytes = new ArrayBuffer(floats.length * stride);
  const view = new DataView(bytes);

  // Write each value explicitly as little-endian, independent of the host byte order.
  floats.forEach((value, index) => {
    view.setFloat32(index * stride, value, true);
  });

  return Buffer.from(bytes);
};
|
||||
|
Reference in New Issue
Block a user