feat: 模型数据导入

This commit is contained in:
archer
2023-03-30 01:04:52 +08:00
parent f32c557bdd
commit 2b2c70e53d
16 changed files with 415 additions and 76 deletions

View File

@@ -1,5 +1,5 @@
import { GET, POST, DELETE, PUT } from './request'; import { GET, POST, DELETE, PUT } from './request';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema, ModelDataSchema } from '@/types/mongoSchema';
import { ModelUpdateParams } from '@/types/model'; import { ModelUpdateParams } from '@/types/model';
import { TrainingItemType } from '../types/training'; import { TrainingItemType } from '../types/training';
import { PagingData } from '@/types'; import { PagingData } from '@/types';
@@ -39,10 +39,15 @@ type GetModelDataListProps = RequestPaging & {
export const getModelDataList = (props: GetModelDataListProps) => export const getModelDataList = (props: GetModelDataListProps) =>
GET(`/model/data/getModelData?${Obj2Query(props)}`); GET(`/model/data/getModelData?${Obj2Query(props)}`);
export const postModelData = (data: { modelId: string; data: { q: string; a: string }[] }) => export const postModelDataInput = (data: {
POST(`/model/data/pushModelData`, data); modelId: string;
data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
}) => POST(`/model/data/pushModelDataInput`, data);
export const putModelDataById = (data: { modelId: string; answer: string }) => export const postModelDataSelect = (modelId: string, dataIds: string[]) =>
POST(`/model/data/pushModelDataSelectData`, { modelId, dataIds });
export const putModelDataById = (data: { dataId: string; text: string }) =>
PUT('/model/data/putModelData', data); PUT('/model/data/putModelData', data);
export const DelOneModelData = (modelId: string) => export const delOneModelData = (dataId: string) =>
DELETE(`/model/data/delModelDataById?modelId=${modelId}`); DELETE(`/model/data/delModelDataById?dataId=${dataId}`);

View File

@@ -75,10 +75,9 @@ export const formatModelStatus = {
} }
}; };
export const ModelDataStatusMap: Record<ModelDataType, string> = { export const ModelDataStatusMap = {
0: '训练完成', 0: '训练完成',
1: '等待训练', 1: '训练'
2: '训练中'
}; };
export const defaultModel: ModelSchema = { export const defaultModel: ModelSchema = {

2
src/constants/redis.ts Normal file
View File

@@ -0,0 +1,2 @@
export const ModelDataIndex = 'model:data';
export const VecModelDataIndex = 'vec:model:data';

View File

@@ -75,6 +75,7 @@ export const usePaging = <T = any>({
requesting, requesting,
isLoadAll, isLoadAll,
nextPage, nextPage,
initRequesting initRequesting,
setData
}; };
}; };

View File

@@ -5,8 +5,8 @@ import { authToken } from '@/service/utils/tools';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
let { modelId } = req.query as { let { dataId } = req.query as {
modelId: string; dataId: string;
}; };
const { authorization } = req.headers; const { authorization } = req.headers;
@@ -14,7 +14,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作'); throw new Error('无权操作');
} }
if (!modelId) { if (!dataId) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
@@ -24,7 +24,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase(); await connectToDatabase();
await ModelData.deleteOne({ await ModelData.deleteOne({
modelId, _id: dataId,
userId userId
}); });

View File

@@ -2,12 +2,13 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData, Model } from '@/service/mongo'; import { connectToDatabase, ModelData, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { ModelDataSchema } from '@/types/mongoSchema';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
const { modelId, data } = req.body as { const { modelId, data } = req.body as {
modelId: string; modelId: string;
data: { q: string; a: string }[]; data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
}; };
const { authorization } = req.headers; const { authorization } = req.headers;

View File

@@ -0,0 +1,57 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, DataItem, ModelData } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
let { dataIds, modelId } = req.body as { dataIds: string[]; modelId: string };
if (!dataIds) {
throw new Error('参数错误');
}
await connectToDatabase();
const { authorization } = req.headers;
const userId = await authToken(authorization);
const dataItems = (
await Promise.all(
dataIds.map((dataId) =>
DataItem.find<{ _id: string; result: { q: string }[]; text: string }>(
{
userId,
dataId
},
'result text'
)
)
)
).flat();
// push data
await ModelData.insertMany(
dataItems.map((item) => ({
modelId: modelId,
userId,
text: item.text,
q: item.result.map((item) => ({
id: nanoid(),
text: item.q
}))
}))
);
jsonRes(res, {
data: dataItems
});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@@ -5,9 +5,9 @@ import { authToken } from '@/service/utils/tools';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
let { modelId, answer } = req.body as { let { dataId, text } = req.body as {
modelId: string; dataId: string;
answer: string; text: string;
}; };
const { authorization } = req.headers; const { authorization } = req.headers;
@@ -15,7 +15,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作'); throw new Error('无权操作');
} }
if (!modelId) { if (!dataId) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
@@ -26,11 +26,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await ModelData.updateOne( await ModelData.updateOne(
{ {
modelId, _id: dataId,
userId userId
}, },
{ {
a: answer text
} }
); );

View File

@@ -10,30 +10,134 @@ import {
Td, Td,
IconButton, IconButton,
Flex, Flex,
Button Button,
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
Checkbox,
CheckboxGroup,
ModalCloseButton,
useDisclosure,
Input,
Textarea,
Stack
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { ModelDataSchema } from '@/types/mongoSchema'; import { ModelDataSchema } from '@/types/mongoSchema';
import { ModelDataStatusMap } from '@/constants/model'; import { ModelDataStatusMap } from '@/constants/model';
import { usePaging } from '@/hooks/usePaging'; import { usePaging } from '@/hooks/usePaging';
import ScrollData from '@/components/ScrollData'; import ScrollData from '@/components/ScrollData';
import { getModelDataList } from '@/api/model'; import {
getModelDataList,
postModelDataInput,
postModelDataSelect,
delOneModelData,
putModelDataById
} from '@/api/model';
import { getDataList } from '@/api/data';
import { DeleteIcon } from '@chakra-ui/icons'; import { DeleteIcon } from '@chakra-ui/icons';
import { useForm, useFieldArray } from 'react-hook-form';
import { useToast } from '@/hooks/useToast';
import { useQuery } from '@tanstack/react-query';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
type FormData = { text: string; q: { val: string }[] };
type TabType = 'input' | 'select';
const defaultValues = {
text: '',
q: [{ val: '' }]
};
const ModelDataCard = ({ model }: { model: ModelSchema }) => { const ModelDataCard = ({ model }: { model: ModelSchema }) => {
const { const {
nextPage, nextPage,
isLoadAll, isLoadAll,
requesting, requesting,
data: dataList, data: modelDataList,
total total,
setData,
getData
} = usePaging<ModelDataSchema>({ } = usePaging<ModelDataSchema>({
api: getModelDataList, api: getModelDataList,
pageSize: 10, pageSize: 20,
params: { params: {
modelId: model._id modelId: model._id
} }
}); });
const { toast } = useToast();
const {
isOpen: isOpenImportModal,
onOpen: onOpenImportModal,
onClose: onCloseImportModal
} = useDisclosure();
const { register, handleSubmit, reset, control } = useForm<FormData>({
defaultValues
});
const {
fields: inputQ,
append: appendQ,
remove: removeQ
} = useFieldArray({
control,
name: 'q'
});
const importDataTypes: { id: TabType; label: string }[] = [
{ id: 'input', label: '手动输入' },
{ id: 'select', label: '数据集导入' }
];
const [importDataType, setImportDataType] = useState<TabType>(importDataTypes[0].id);
const [importing, setImporting] = useState(false);
const updateAnswer = useCallback(async (dataId: string, text: string) => {
putModelDataById({
dataId,
text
});
}, []);
const { data: dataList = [] } = useQuery(['getDataList'], getDataList);
const [selectDataId, setSelectDataId] = useState<string[]>([]);
const sureImportData = useCallback(
async (e: FormData) => {
setImporting(true);
try {
if (importDataType === 'input') {
await postModelDataInput({
modelId: model._id,
data: [
{
text: e.text,
q: e.q.map((item) => ({
id: nanoid(),
text: item.val
}))
}
]
});
} else if (importDataType === 'select') {
const res = await postModelDataSelect(model._id, selectDataId);
console.log(res);
}
toast({
title: '导入数据成功,需要一段时间训练',
status: 'success'
});
onCloseImportModal();
getData(1, true);
reset(defaultValues);
} catch (err) {
console.log(err);
}
setImporting(false);
},
[getData, importDataType, model._id, onCloseImportModal, reset, toast]
);
return ( return (
<> <>
@@ -41,18 +145,17 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
<Box fontWeight={'bold'} fontSize={'lg'} flex={1}> <Box fontWeight={'bold'} fontSize={'lg'} flex={1}>
: {total} : {total}
</Box> </Box>
<Button size={'sm'}></Button> <Button size={'sm'} onClick={onOpenImportModal}>
</Button>
</Flex> </Flex>
<ScrollData <ScrollData
flex={'1 0 0'} h={'100%'}
h={0}
px={6} px={6}
mt={3} mt={3}
isLoadAll={isLoadAll} isLoadAll={isLoadAll}
requesting={requesting} requesting={requesting}
nextPage={nextPage} nextPage={nextPage}
fontSize={'xs'}
whiteSpace={'pre-wrap'}
> >
<TableContainer mt={4}> <TableContainer mt={4}>
<Table variant={'simple'}> <Table variant={'simple'}>
@@ -65,13 +168,55 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
</Tr> </Tr>
</Thead> </Thead>
<Tbody> <Tbody>
{dataList.map((item) => ( {modelDataList.map((item) => (
<Tr key={item._id}> <Tr key={item._id}>
<Td>{item.q}</Td> <Td>
<Td>{item.a}</Td> {item.q.map((item, i) => (
<Box
key={item.id}
fontSize={'xs'}
maxW={'350px'}
whiteSpace={'pre-wrap'}
_notLast={{ mb: 1 }}
>
Q{i + 1}: {item.text}
</Box>
))}
</Td>
<Td w={'350px'}>
<Textarea
w={'100%'}
h={'100%'}
defaultValue={item.text}
fontSize={'xs'}
resize={'both'}
onBlur={(e) => {
const oldVal = modelDataList.find((data) => item._id === data._id)?.text;
if (oldVal !== e.target.value) {
updateAnswer(item._id, e.target.value);
setData((state) =>
state.map((data) => ({
...data,
text: data._id === item._id ? e.target.value : data.text
}))
);
}
}}
></Textarea>
</Td>
<Td>{ModelDataStatusMap[item.status]}</Td> <Td>{ModelDataStatusMap[item.status]}</Td>
<Td> <Td>
<IconButton icon={<DeleteIcon />} aria-label={'delete'} /> <IconButton
icon={<DeleteIcon />}
variant={'outline'}
colorScheme={'gray'}
aria-label={'delete'}
size={'sm'}
onClick={async () => {
delOneModelData(item._id);
setData((state) => state.filter((data) => data._id !== item._id));
}}
/>
</Td> </Td>
</Tr> </Tr>
))} ))}
@@ -79,6 +224,101 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
</Table> </Table>
</TableContainer> </TableContainer>
</ScrollData> </ScrollData>
<Modal isOpen={isOpenImportModal} onClose={onCloseImportModal}>
<ModalOverlay />
<ModalContent maxW={'min(900px, 90vw)'} maxH={'80vh'} position={'relative'}>
<Flex alignItems={'center'}>
<ModalHeader whiteSpace={'nowrap'}></ModalHeader>
<Box>
{importDataTypes.map((item) => (
<Button
key={item.id}
size={'sm'}
mr={5}
variant={item.id === importDataType ? 'solid' : 'outline'}
onClick={() => setImportDataType(item.id)}
>
{item.label}
</Button>
))}
</Box>
</Flex>
<ModalCloseButton />
<Box px={6} pb={2} overflowY={'auto'}>
{importDataType === 'input' && (
<>
<Box mb={2}>:</Box>
<Textarea
mb={4}
placeholder="知识点"
rows={3}
maxH={'200px'}
{...register(`text`, {
required: '知识点'
})}
/>
{inputQ.map((item, index) => (
<Box key={item.id} mb={5}>
<Box mb={2}>{index + 1}:</Box>
<Flex>
<Input
placeholder="问法"
{...register(`q.${index}.val`, {
required: '问法不能为空'
})}
></Input>
{inputQ.length > 1 && (
<IconButton
icon={<DeleteIcon />}
aria-label={'delete'}
colorScheme={'gray'}
variant={'unstyled'}
onClick={() => removeQ(index)}
/>
)}
</Flex>
</Box>
))}
</>
)}
{importDataType === 'select' && (
<CheckboxGroup colorScheme="blue" onChange={(e) => setSelectDataId(e as string[])}>
{dataList.map((item) => (
<Box mb={2} key={item._id}>
<Checkbox value={item._id}>
<Box fontWeight={'bold'} as={'span'}>
{item.name}
</Box>
<Box as={'span'} ml={2} fontSize={'sm'}>
({item.totalData})
</Box>
</Checkbox>
</Box>
))}
</CheckboxGroup>
)}
</Box>
<Flex px={6} pt={2} pb={4}>
{importDataType === 'input' && (
<Button
alignSelf={'flex-start'}
variant={'outline'}
onClick={() => appendQ({ val: '' })}
>
</Button>
)}
<Box flex={1}></Box>
<Button variant={'outline'} mr={3} onClick={onCloseImportModal}>
</Button>
<Button isLoading={importing} onClick={handleSubmit(sureImportData)}>
</Button>
</Flex>
</ModalContent>
</Modal>
</> </>
); );
}; };

View File

@@ -255,7 +255,7 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
<Training model={model} /> <Training model={model} />
</Card> </Card>
)} */} )} */}
<Card p={4} height={'400px'} gridColumnStart={1} gridColumnEnd={3}> <Card p={4} height={'500px'} gridColumnStart={1} gridColumnEnd={3}>
{model._id && <ModelDataCard model={model} />} {model._id && <ModelDataCard model={model} />}
</Card> </Card>
</Grid> </Grid>

View File

@@ -13,22 +13,23 @@ const ModelDataSchema = new Schema({
ref: 'user', ref: 'user',
required: true required: true
}, },
q: { text: {
type: String, type: String,
required: true required: true
}, },
a: { q: {
type: String, type: [
default: '' {
id: String, // 对应redis的key
text: String
}
],
default: []
}, },
status: { status: {
type: Number, type: Number,
enum: [0, 1, 2], enum: [0, 1], // 1 训练ing
default: 1 default: 1
},
createTime: {
type: Date,
default: () => new Date()
} }
}); });

View File

@@ -1,7 +1,6 @@
import mongoose from 'mongoose'; import mongoose from 'mongoose';
import { generateQA } from './events/generateQA'; import { generateQA } from './events/generateQA';
import { generateAbstract } from './events/generateAbstract'; import { generateAbstract } from './events/generateAbstract';
import { connectRedis } from './redis';
/** /**
* 连接 MongoDB 数据库 * 连接 MongoDB 数据库
@@ -29,7 +28,6 @@ export async function connectToDatabase(): Promise<void> {
generateQA(); generateQA();
generateAbstract(); generateAbstract();
// connectRedis();
} }
export * from './models/authCode'; export * from './models/authCode';

View File

@@ -1,4 +1,7 @@
import { createClient, SchemaFieldTypes } from 'redis'; import { createClient, SchemaFieldTypes } from 'redis';
import { ModelDataIndex } from '@/constants/redis';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 10);
export const connectRedis = async () => { export const connectRedis = async () => {
// 断开了,重连 // 断开了,重连
@@ -6,12 +9,12 @@ export const connectRedis = async () => {
await global.redisClient.disconnect(); await global.redisClient.disconnect();
} else if (global.redisClient) { } else if (global.redisClient) {
// 没断开,不再连接 // 没断开,不再连接
return; return global.redisClient;
} }
try { try {
global.redisClient = createClient({ global.redisClient = createClient({
url: 'redis://:121914yu@120.76.193.200:8100' url: 'redis://default:121914yu@120.76.193.200:8100'
}); });
global.redisClient.on('error', (err) => { global.redisClient.on('error', (err) => {
@@ -31,31 +34,35 @@ export const connectRedis = async () => {
await global.redisClient.select(0); await global.redisClient.select(0);
// 创建索引 // 创建索引
await global.redisClient.ft.create( try {
'vec:question', await global.redisClient.ft.create(
{ ModelDataIndex,
'$.vector': SchemaFieldTypes.VECTOR, {
'$.modelId': { // '$.vector': SchemaFieldTypes.VECTOR,
type: SchemaFieldTypes.TEXT, '$.modelId': {
AS: 'modelId' type: SchemaFieldTypes.TEXT,
AS: 'modelId'
},
'$.userId': {
type: SchemaFieldTypes.TEXT,
AS: 'userId'
},
'$.status': {
type: SchemaFieldTypes.NUMERIC,
AS: 'status'
}
}, },
'$.userId': { {
type: SchemaFieldTypes.TEXT, ON: 'JSON',
AS: 'userId' PREFIX: 'model:data'
},
'$.status': {
type: SchemaFieldTypes.NUMERIC,
AS: 'status'
} }
}, );
{ } catch (error) {
ON: 'JSON', console.log('创建索引失败', error);
PREFIX: 'fastgpt:modeldata' }
}
);
// await global.redisClient.json.set('fastgpt:modeldata:1', '$', { // await global.redisClient.json.set('fastgpt:modeldata:2', '$', {
// vector: [], // vector: [124, 214, 412, 4, 124, 1, 4, 1, 4, 3, 423],
// modelId: 'daf', // modelId: 'daf',
// userId: 'adfd', // userId: 'adfd',
// q: 'fasf', // q: 'fasf',
@@ -63,12 +70,17 @@ export const connectRedis = async () => {
// status: 0, // status: 0,
// createTime: new Date() // createTime: new Date()
// }); // });
// const value = await global.redisClient.get('fastgpt:modeldata:1'); // const value = await global.redisClient.json.get('fastgpt:modeldata:2');
// console.log(value); // console.log(value);
return global.redisClient;
} catch (error) { } catch (error) {
console.log(error, '=='); console.log(error, '==');
global.redisClient = null; global.redisClient = null;
return Promise.reject('redis 连接失败'); return Promise.reject('redis 连接失败');
} }
}; };
export const getKey = (prefix = '') => {
return `${prefix}:${nanoid()}`;
};

View File

@@ -8,3 +8,12 @@ export interface ModelUpdateParams {
service: ModelSchema.service; service: ModelSchema.service;
security: ModelSchema.security; security: ModelSchema.security;
} }
export interface ModelDataItemType {
id: string;
status: 0 | 1; // 1代表向量生成完毕
q: string; // 提问词
a: string; // 原文
modelId: string;
userId: string;
}

View File

@@ -51,13 +51,17 @@ export interface ModelPopulate extends ModelSchema {
userId: UserModelSchema; userId: UserModelSchema;
} }
export type ModelDataType = 0 | 1 | 2; export type ModelDataType = 0 | 1;
export interface ModelDataSchema { export interface ModelDataSchema {
_id: string; _id: string;
q: string; modelId: string;
a: string; userId: string;
text: string;
q: {
id: string;
text: string;
}[];
status: ModelDataType; status: ModelDataType;
createTime: Date;
} }
export interface TrainingSchema { export interface TrainingSchema {

10
src/types/redis.d.ts vendored Normal file
View File

@@ -0,0 +1,10 @@
export interface RedisModelDataItemType {
id: string;
value: {
vector: number[];
q: string; // 提问词
a: string; // 原文
modelId: string;
userId: string;
};
}