This commit is contained in:
archer
2023-03-30 22:33:58 +08:00
36 changed files with 1187 additions and 232 deletions

View File

@@ -3,4 +3,6 @@ AXIOS_PROXY_PORT=33210
MONGODB_URI=
MY_MAIL=
MAILE_CODE=
TOKEN_KEY=
TOKEN_KEY=
OPENAIKEY=
REDIS_URL=

View File

@@ -44,8 +44,8 @@ export const postModelDataInput = (data: {
data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
}) => POST(`/model/data/pushModelDataInput`, data);
export const postModelDataSelect = (modelId: string, dataIds: string[]) =>
POST(`/model/data/pushModelDataSelectData`, { modelId, dataIds });
export const postModelDataFileText = (modelId: string, text: string) =>
POST(`/model/data/splitData`, { modelId, text });
export const putModelDataById = (data: { dataId: string; text: string }) =>
PUT('/model/data/putModelData', data);

View File

@@ -26,12 +26,12 @@ const navbarList = [
link: '/model/list',
activeLink: ['/model/list', '/model/detail']
},
{
label: '数据',
icon: 'icon-datafull',
link: '/data/list',
activeLink: ['/data/list', '/data/detail']
},
// {
// label: '数据',
// icon: 'icon-datafull',
// link: '/data/list',
// activeLink: ['/data/list', '/data/detail']
// },
{
label: '账号',
icon: 'icon-yonghu-yuan',

View File

@@ -2,9 +2,16 @@ import type { ServiceName, ModelDataType, ModelSchema } from '@/types/mongoSchem
export enum ChatModelNameEnum {
GPT35 = 'gpt-3.5-turbo',
VECTOR_GPT = 'VECTOR_GPT',
GPT3 = 'text-davinci-003'
}
export const ChatModelNameMap = {
[ChatModelNameEnum.GPT35]: 'gpt-3.5-turbo',
[ChatModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo',
[ChatModelNameEnum.GPT3]: 'text-davinci-003'
};
export type ModelConstantsData = {
serviceCompany: `${ServiceName}`;
name: string;
@@ -28,6 +35,17 @@ export const modelList: ModelConstantsData[] = [
trainedMaxToken: 2000,
maxTemperature: 2,
price: 3
},
{
serviceCompany: 'openai',
name: '知识库',
model: ChatModelNameEnum.VECTOR_GPT,
trainName: 'vector',
maxToken: 4000,
contextMaxToken: 7500,
trainedMaxToken: 2000,
maxTemperature: 1,
price: 3
}
// {
// serviceCompany: 'openai',

View File

@@ -1,2 +1 @@
export const ModelDataIndex = 'model:data';
export const VecModelDataIndex = 'vec:model:data';
export const VecModelDataIndex = 'model:data';

View File

@@ -46,7 +46,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const model: ModelSchema = chat.modelId;
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
if (!modelConstantsData) {
throw new Error('模型异常,请用 chatgpt 模型');
throw new Error('模型加载异常');
}
// 读取对话内容

View File

@@ -0,0 +1,241 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { connectToDatabase, ModelData } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/chat';
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import type { ModelSchema } from '@/types/mongoSchema';
import { PassThrough } from 'stream';
import { modelList } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataIndex } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
let vectorData = [
-0.025028639, -0.010407282, 0.026523087, -0.0107438695, -0.006967359, 0.010043768, -0.012043097,
0.008724345, -0.028919589, -0.0117738275, 0.0050690062, 0.02961969
].concat(new Array(1524).fill(0));
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
let step = 0; // step=1时表示开始了流响应
const stream = new PassThrough();
stream.on('error', () => {
console.log('error: ', 'stream error');
stream.destroy();
});
res.on('close', () => {
stream.destroy();
});
res.on('error', () => {
console.log('error: ', 'request error');
stream.destroy();
});
try {
const { chatId, prompt } = req.body as {
prompt: ChatItemType;
chatId: string;
};
const { authorization } = req.headers;
if (!chatId || !prompt) {
throw new Error('缺少参数');
}
await connectToDatabase();
const redis = await connectRedis();
const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization);
const model: ModelSchema = chat.modelId;
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
if (!modelConstantsData) {
throw new Error('模型加载异常');
}
// 读取对话内容
const prompts = [...chat.content, prompt];
// 获取 chatAPI
const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 把输入的内容转成向量
const promptVector = await chatAPI
.createEmbedding(
{
model: 'text-embedding-ada-002',
input: prompt.value
},
{
timeout: 120000,
httpsAgent
}
)
.then((res) => res?.data?.data?.[0]?.embedding || []);
const binary = vectorToBuffer(promptVector);
// 搜索系统提示词, 按相似度从 redis 中搜出前3条不同 dataId 的数据
const redisData: any[] = await redis.sendCommand([
'FT.SEARCH',
`idx:${VecModelDataIndex}`,
`@modelId:{${String(chat.modelId._id)}} @vector:[VECTOR_RANGE 0.2 $blob]`,
// `@modelId:{${String(chat.modelId._id)}}=>[KNN 10 @vector $blob AS score]`,
'RETURN',
'1',
'dataId',
// 'SORTBY',
// 'score',
'PARAMS',
'2',
'blob',
binary,
'DIALECT',
'2'
]);
// 格式化响应值获取去重后的id
let formatIds = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
.map((i) => {
if (!redisData[i] || !redisData[i][1]) return '';
return redisData[i][1];
})
.filter((item) => item);
formatIds = Array.from(new Set(formatIds));
if (formatIds.length === 0) {
throw new Error('对不起,我没有找到你的问题');
}
// 从 mongo 中取出原文作为提示词
const textArr = (
await Promise.all(
[2, 4, 6, 8, 10, 12, 14, 16, 18, 20].map((i) => {
if (!redisData[i] || !redisData[i][1]) return '';
return ModelData.findById(redisData[i][1])
.select('text')
.then((res) => res?.text || '');
})
)
).filter((item) => item);
// textArr 筛选,最多 3000 tokens
const systemPrompt = systemPromptFilter(textArr, 2800);
prompts.unshift({
obj: 'SYSTEM',
value: `请根据下面的知识回答问题: ${systemPrompt}`
});
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
// 格式化文本内容成 chatgpt 格式
const map = {
Human: ChatCompletionRequestMessageRoleEnum.User,
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
};
const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
(item: ChatItemType) => ({
role: map[item.obj],
content: item.value
})
);
// console.log(formatPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
let startTime = Date.now();
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature: temperature,
// max_tokens: modelConstantsData.maxToken,
messages: formatPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: true
},
{
timeout: 40000,
responseType: 'stream',
httpsAgent
}
);
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
// 创建响应流
res.setHeader('Content-Type', 'text/event-stream;charset-utf-8');
res.setHeader('Access-Control-Allow-Origin', '*');
res.setHeader('X-Accel-Buffering', 'no');
res.setHeader('Cache-Control', 'no-cache, no-transform');
step = 1;
let responseContent = '';
stream.pipe(res);
const onParse = async (event: ParsedEvent | ReconnectInterval) => {
if (event.type !== 'event') return;
const data = event.data;
if (data === '[DONE]') return;
try {
const json = JSON.parse(data);
const content: string = json?.choices?.[0].delta.content || '';
if (!content || (responseContent === '' && content === '\n')) return;
responseContent += content;
// console.log('content:', content)
!stream.destroyed && stream.push(content.replace(/\n/g, '<br/>'));
} catch (error) {
error;
}
};
const decoder = new TextDecoder();
try {
for await (const chunk of chatResponse.data as any) {
if (stream.destroyed) {
// 流被中断了,直接忽略后面的内容
break;
}
const parser = createParser(onParse);
parser.feed(decoder.decode(chunk));
}
} catch (error) {
console.log('pipe error', error);
}
// close stream
!stream.destroyed && stream.push(null);
stream.destroy();
const promptsContent = formatPrompts.map((item) => item.content).join('');
// 只有使用平台的 key 才计费
pushChatBill({
isPay: !userApiKey,
modelName: model.service.modelName,
userId,
chatId,
text: promptsContent + responseContent
});
// jsonRes(res);
} catch (err: any) {
if (step === 1) {
// 直接结束流
console.log('error结束');
stream.destroy();
} else {
res.status(500);
jsonRes(res, {
code: 500,
error: err
});
}
}
}

View File

@@ -24,7 +24,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
if (!DataRecord) {
throw new Error('找不到数据集');
}
const replaceText = text.replace(/[\r\n\\n]+/g, ' ');
const replaceText = text.replace(/[\\n]+/g, ' ');
// 文本拆分成 chunk
let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || [];
@@ -35,7 +35,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
chunks.forEach((chunk) => {
splitText += chunk;
const tokens = encode(splitText).length;
if (tokens >= 980) {
if (tokens >= 780) {
dataItems.push({
userId,
dataId,

View File

@@ -3,7 +3,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { ModelStatusEnum, modelList, ChatModelNameEnum } from '@/constants/model';
import { ModelStatusEnum, modelList, ChatModelNameEnum, ChatModelNameMap } from '@/constants/model';
import { Model } from '@/service/models/model';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
@@ -33,15 +33,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase();
// 重名校验
const authRepeatName = await Model.findOne({
name,
userId
});
if (authRepeatName) {
throw new Error('模型名重复');
}
// 上限校验
const authCount = await Model.countDocuments({
userId
@@ -57,9 +48,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
status: ModelStatusEnum.running,
service: {
company: modelItem.serviceCompany,
trainId: modelItem.trainName,
chatModel: modelItem.model,
modelName: modelItem.model
trainId: '',
chatModel: ChatModelNameMap[modelItem.model], // 聊天时用的模型
modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作
}
});

View File

@@ -3,6 +3,7 @@ import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { ModelDataSchema } from '@/types/mongoSchema';
import { generateVector } from '@/service/events/generateVector';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -44,6 +45,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
}))
);
generateVector(true);
jsonRes(res, {
data: model
});

View File

@@ -0,0 +1,67 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, SplitData, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { generateQA } from '@/service/events/generateQA';
import { encode } from 'gpt-token-utils';
/* 拆分数据成QA */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
const { text, modelId } = req.body as { text: string; modelId: string };
if (!text || !modelId) {
throw new Error('参数错误');
}
await connectToDatabase();
const { authorization } = req.headers;
const userId = await authToken(authorization);
// 验证是否是该用户的 model
const model = await Model.findOne({
_id: modelId,
userId
});
if (!model) {
throw new Error('无权操作该模型');
}
const replaceText = text.replace(/(\\n|\n)+/g, ' ');
// 文本拆分成 chunk
let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || [];
const textList: string[] = [];
let splitText = '';
chunks.forEach((chunk) => {
splitText += chunk;
const tokens = encode(splitText).length;
if (tokens >= 980) {
textList.push(splitText);
splitText = '';
}
});
// 批量插入数据
await SplitData.create({
userId,
modelId,
rawText: text,
textList
});
// generateQA();
jsonRes(res, {
data: { chunks, replaceText }
});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@@ -1,6 +1,6 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { Chat, Model, Training, connectToDatabase } from '@/service/mongo';
import { Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo';
import { authToken, getUserOpenaiKey } from '@/service/utils/tools';
import { TrainingStatusEnum } from '@/constants/model';
import { getOpenAIApi } from '@/service/utils/chat';
@@ -26,16 +26,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase();
// 删除模型
await Model.deleteOne({
_id: modelId,
userId
});
let requestQueue: any[] = [];
// 删除对应的聊天
await Chat.deleteMany({
modelId
});
requestQueue.push(
Chat.deleteMany({
modelId
})
);
// 删除数据集
requestQueue.push(
ModelData.deleteMany({
modelId
})
);
// 查看是否正在训练
const training: TrainingItemType | null = await Training.findOne({
@@ -56,9 +60,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
}
// 删除对应训练记录
await Training.deleteMany({
modelId
});
requestQueue.push(
Training.deleteMany({
modelId
})
);
// 删除模型
requestQueue.push(
Model.deleteOne({
_id: modelId,
userId
})
);
await requestQueue;
jsonRes(res);
} catch (err) {

View File

@@ -37,7 +37,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
systemPrompt,
intro,
temperature,
service,
// service,
security
}
);

View File

@@ -119,6 +119,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
async (prompts: ChatSiteItemType) => {
const urlMap: Record<string, string> = {
[ChatModelNameEnum.GPT35]: '/api/chat/chatGpt',
[ChatModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt',
[ChatModelNameEnum.GPT3]: '/api/chat/gpt3'
};

View File

@@ -184,7 +184,7 @@ const DataList = () => {
>
</Button>
<Menu>
{/* <Menu>
<MenuButton as={Button} mr={2} size={'sm'} isLoading={isExporting}>
导出
</MenuButton>
@@ -200,7 +200,7 @@ const DataList = () => {
</MenuItem>
)}
</MenuList>
</Menu>
</Menu> */}
<Button
size={'sm'}

View File

@@ -0,0 +1,141 @@
import React, { useState, useCallback } from 'react';
import {
Box,
IconButton,
Flex,
Button,
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
Input,
Textarea
} from '@chakra-ui/react';
import { useForm, useFieldArray } from 'react-hook-form';
import { postModelDataInput } from '@/api/model';
import { useToast } from '@/hooks/useToast';
import { DeleteIcon } from '@chakra-ui/icons';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
type FormData = { text: string; q: { val: string }[] };
const InputDataModal = ({
onClose,
onSuccess,
modelId
}: {
onClose: () => void;
onSuccess: () => void;
modelId: string;
}) => {
const [importing, setImporting] = useState(false);
const { toast } = useToast();
const { register, handleSubmit, control } = useForm<FormData>({
defaultValues: {
text: '',
q: [{ val: '' }]
}
});
const {
fields: inputQ,
append: appendQ,
remove: removeQ
} = useFieldArray({
control,
name: 'q'
});
const sureImportData = useCallback(
async (e: FormData) => {
setImporting(true);
try {
await postModelDataInput({
modelId: modelId,
data: [
{
text: e.text,
q: e.q.map((item) => ({
id: nanoid(),
text: item.val
}))
}
]
});
toast({
title: '导入数据成功,需要一段时间训练',
status: 'success'
});
onClose();
onSuccess();
} catch (err) {
console.log(err);
}
setImporting(false);
},
[modelId, onClose, onSuccess, toast]
);
return (
<Modal isOpen={true} onClose={onClose}>
<ModalOverlay />
<ModalContent maxW={'min(900px, 90vw)'} maxH={'80vh'} position={'relative'}>
<ModalHeader></ModalHeader>
<ModalCloseButton />
<Box px={6} pb={2} overflowY={'auto'}>
<Box mb={2}>:</Box>
<Textarea
mb={4}
placeholder="知识点"
rows={3}
maxH={'200px'}
{...register(`text`, {
required: '知识点'
})}
/>
{inputQ.map((item, index) => (
<Box key={item.id} mb={5}>
<Box mb={2}>{index + 1}:</Box>
<Flex>
<Input
placeholder="问法"
{...register(`q.${index}.val`, {
required: '问法不能为空'
})}
></Input>
{inputQ.length > 1 && (
<IconButton
icon={<DeleteIcon />}
aria-label={'delete'}
colorScheme={'gray'}
variant={'unstyled'}
onClick={() => removeQ(index)}
/>
)}
</Flex>
</Box>
))}
</Box>
<Flex px={6} pt={2} pb={4}>
<Button alignSelf={'flex-start'} variant={'outline'} onClick={() => appendQ({ val: '' })}>
</Button>
<Box flex={1}></Box>
<Button variant={'outline'} mr={3} onClick={onClose}>
</Button>
<Button isLoading={importing} onClick={handleSubmit(sureImportData)}>
</Button>
</Flex>
</ModalContent>
</Modal>
);
};
export default InputDataModal;

View File

@@ -0,0 +1,202 @@
import React, { useCallback } from 'react';
import {
Box,
TableContainer,
Table,
Thead,
Tbody,
Tr,
Th,
Td,
IconButton,
Flex,
Button,
useDisclosure,
Textarea,
Menu,
MenuButton,
MenuList,
MenuItem
} from '@chakra-ui/react';
import type { ModelSchema } from '@/types/mongoSchema';
import { ModelDataSchema } from '@/types/mongoSchema';
import { ModelDataStatusMap } from '@/constants/model';
import { usePaging } from '@/hooks/usePaging';
import ScrollData from '@/components/ScrollData';
import { getModelDataList, delOneModelData, putModelDataById } from '@/api/model';
import { DeleteIcon, RepeatIcon } from '@chakra-ui/icons';
import { useToast } from '@/hooks/useToast';
import { useLoading } from '@/hooks/useLoading';
import dynamic from 'next/dynamic';
const InputModel = dynamic(() => import('./InputDataModal'));
const SelectModel = dynamic(() => import('./SelectFileModal'));
const ModelDataCard = ({ model }: { model: ModelSchema }) => {
const { toast } = useToast();
const { Loading } = useLoading();
const {
nextPage,
isLoadAll,
requesting,
data: modelDataList,
total,
setData,
getData
} = usePaging<ModelDataSchema>({
api: getModelDataList,
pageSize: 20,
params: {
modelId: model._id
}
});
const updateAnswer = useCallback(
async (dataId: string, text: string) => {
await putModelDataById({
dataId,
text
});
toast({
title: '修改回答成功',
status: 'success'
});
},
[toast]
);
const {
isOpen: isOpenInputModal,
onOpen: onOpenInputModal,
onClose: onCloseInputModal
} = useDisclosure();
const {
isOpen: isOpenSelectModal,
onOpen: onOpenSelectModal,
onClose: onCloseSelectModal
} = useDisclosure();
return (
<>
<Flex>
<Box fontWeight={'bold'} fontSize={'lg'} flex={1}>
: {total}{' '}
<Box as={'span'} fontSize={'sm'}>
</Box>
</Box>
<IconButton
icon={<RepeatIcon />}
aria-label={'refresh'}
variant={'outline'}
mr={4}
onClick={() => getData(1, true)}
/>
<Menu>
<MenuButton as={Button}></MenuButton>
<MenuList>
<MenuItem onClick={onOpenInputModal}></MenuItem>
<MenuItem onClick={onOpenSelectModal}></MenuItem>
</MenuList>
</Menu>
</Flex>
<ScrollData
h={'100%'}
px={6}
mt={3}
isLoadAll={isLoadAll}
requesting={requesting}
nextPage={nextPage}
position={'relative'}
>
<TableContainer mt={4}>
<Table variant={'simple'}>
<Thead>
<Tr>
<Th>Question</Th>
<Th>Text</Th>
<Th>Status</Th>
<Th></Th>
</Tr>
</Thead>
<Tbody>
{modelDataList.map((item) => (
<Tr key={item._id}>
<Td w={'350px'}>
{item.q.map((item, i) => (
<Box
key={item.id}
fontSize={'xs'}
w={'100%'}
whiteSpace={'pre-wrap'}
_notLast={{ mb: 1 }}
>
Q{i + 1}:{' '}
<Box as={'span'} userSelect={'all'}>
{item.text}
</Box>
</Box>
))}
</Td>
<Td minW={'200px'}>
<Textarea
w={'100%'}
h={'100%'}
defaultValue={item.text}
fontSize={'xs'}
resize={'both'}
onBlur={(e) => {
const oldVal = modelDataList.find((data) => item._id === data._id)?.text;
if (oldVal !== e.target.value) {
updateAnswer(item._id, e.target.value);
setData((state) =>
state.map((data) => ({
...data,
text: data._id === item._id ? e.target.value : data.text
}))
);
}
}}
></Textarea>
</Td>
<Td w={'100px'}>{ModelDataStatusMap[item.status]}</Td>
<Td>
<IconButton
icon={<DeleteIcon />}
variant={'outline'}
colorScheme={'gray'}
aria-label={'delete'}
size={'sm'}
onClick={async () => {
delOneModelData(item._id);
setData((state) => state.filter((data) => data._id !== item._id));
}}
/>
</Td>
</Tr>
))}
</Tbody>
</Table>
</TableContainer>
<Loading loading={requesting} fixed={false} />
</ScrollData>
{isOpenInputModal && (
<InputModel
modelId={model._id}
onClose={onCloseInputModal}
onSuccess={() => getData(1, true)}
/>
)}
{isOpenSelectModal && (
<SelectModel
modelId={model._id}
onClose={onCloseSelectModal}
onSuccess={() => getData(1, true)}
/>
)}
</>
);
};
export default ModelDataCard;

View File

@@ -23,9 +23,11 @@ import { useConfirm } from '@/hooks/useConfirm';
const ModelEditForm = ({
formHooks,
canTrain,
handleDelModel
}: {
formHooks: UseFormReturn<ModelSchema>;
canTrain: boolean;
handleDelModel: () => void;
}) => {
const { openConfirm, ConfirmChild } = useConfirm({
@@ -136,15 +138,24 @@ const ModelEditForm = ({
</Flex>
</FormControl>
<Box mt={4}>
<Box mb={1}></Box>
<Textarea
rows={6}
maxLength={-1}
{...register('systemPrompt')}
placeholder={
'模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n\n注意改功能会影响对话的整体朝向'
}
/>
{canTrain ? (
<Box fontWeight={'bold'}>
prompt
使 tokens
</Box>
) : (
<>
<Box mb={1}></Box>
<Textarea
rows={6}
maxLength={-1}
{...register('systemPrompt')}
placeholder={
'模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n\n注意改功能会影响对话的整体朝向'
}
/>
</>
)}
</Box>
</Card>
{/* <Card p={4}>

View File

@@ -0,0 +1,155 @@
import React, { useState, useCallback } from 'react';
import {
Box,
Flex,
Button,
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody
} from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast';
import { useSelectFile } from '@/hooks/useSelectFile';
import { customAlphabet } from 'nanoid';
import { encode } from 'gpt-token-utils';
import { useConfirm } from '@/hooks/useConfirm';
import { readTxtContent, readPdfContent, readDocContent } from '@/utils/tools';
import { useMutation } from '@tanstack/react-query';
import { postModelDataFileText } from '@/api/model';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
const fileExtension = '.txt,.doc,.docx,.pdf,.md';
const SelectFileModal = ({
onClose,
onSuccess,
modelId
}: {
onClose: () => void;
onSuccess: () => void;
modelId: string;
}) => {
const [selecting, setSelecting] = useState(false);
const { toast } = useToast();
const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
const [fileText, setFileText] = useState('');
const { openConfirm, ConfirmChild } = useConfirm({
content: '确认导入该文件,需要一定时间进行拆解,该任务无法终止!'
});
const onSelectFile = useCallback(
async (e: File[]) => {
setSelecting(true);
try {
const fileTexts = (
await Promise.all(
e.map((file) => {
// @ts-ignore
const extension = file?.name?.split('.').pop().toLowerCase();
switch (extension) {
case 'txt':
case 'md':
return readTxtContent(file);
case 'pdf':
return readPdfContent(file);
case 'doc':
case 'docx':
return readDocContent(file);
default:
return '';
}
})
)
)
.join('\n')
.replace(/\n+/g, '\n');
setFileText(fileTexts);
console.log(encode(fileTexts));
} catch (error: any) {
console.log(error);
toast({
title: typeof error === 'string' ? error : '解析文件失败',
status: 'error'
});
}
setSelecting(false);
},
[setSelecting, toast]
);
const { mutate, isLoading } = useMutation({
mutationFn: async () => {
if (!fileText) return;
await postModelDataFileText(modelId, fileText);
toast({
title: '导入数据成功,需要一段拆解和训练',
status: 'success'
});
onClose();
onSuccess();
},
onError() {
toast({
title: '导入文件失败',
status: 'error'
});
}
});
return (
<Modal isOpen={true} onClose={onClose}>
<ModalOverlay />
<ModalContent maxW={'min(900px, 90vw)'} position={'relative'}>
<ModalHeader></ModalHeader>
<ModalCloseButton />
<ModalBody>
<Flex
flexDirection={'column'}
p={2}
h={'100%'}
alignItems={'center'}
justifyContent={'center'}
fontSize={'sm'}
>
<Button isLoading={selecting} onClick={onOpen}>
</Button>
<Box mt={2}> {fileExtension} . </Box>
<Box mt={2}>
{fileText.length} {encode(fileText).length} tokens
</Box>
<Box
h={'300px'}
w={'100%'}
overflow={'auto'}
p={2}
backgroundColor={'blackAlpha.50'}
whiteSpace={'pre'}
fontSize={'xs'}
>
{fileText}
</Box>
</Flex>
</ModalBody>
<Flex px={6} pt={2} pb={4}>
<Box flex={1}></Box>
<Button variant={'outline'} mr={3} onClick={onClose}>
</Button>
<Button isLoading={isLoading} isDisabled={fileText === ''} onClick={openConfirm(mutate)}>
</Button>
</Flex>
</ModalContent>
<ConfirmChild />
<File onSelect={onSelectFile} />
</Modal>
);
};
export default SelectFileModal;

View File

@@ -1,12 +1,6 @@
import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react';
import { useRouter } from 'next/router';
import {
getModelById,
delModelById,
postTrainModel,
putModelTrainingStatus,
putModelById
} from '@/api/model';
import { getModelById, delModelById, putModelTrainingStatus, putModelById } from '@/api/model';
import { getChatSiteId } from '@/api/chat';
import type { ModelSchema } from '@/types/mongoSchema';
import { Card, Box, Flex, Button, Tag, Grid } from '@chakra-ui/react';
@@ -16,12 +10,11 @@ import { formatModelStatus, ModelStatusEnum, modelList, defaultModel } from '@/c
import { useGlobalStore } from '@/store/global';
import { useScreen } from '@/hooks/useScreen';
import ModelEditForm from './components/ModelEditForm';
// import Icon from '@/components/Iconfont';
import { useQuery } from '@tanstack/react-query';
// import dynamic from 'next/dynamic';
import ModelDataCard from './components/ModelDataCard';
// const Training = dynamic(() => import('./components/Training'));
const ModelDataCard = dynamic(() => import('./components/ModelDataCard'));
const ModelDetail = ({ modelId }: { modelId: string }) => {
const { toast } = useToast();
@@ -29,16 +22,16 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
const { isPc, media } = useScreen();
const { setLoading } = useGlobalStore();
const SelectFileDom = useRef<HTMLInputElement>(null);
// const SelectFileDom = useRef<HTMLInputElement>(null);
const [model, setModel] = useState<ModelSchema>(defaultModel);
const formHooks = useForm<ModelSchema>({
defaultValues: model
});
// const canTrain = useMemo(() => {
// const openai = modelList.find((item) => item.model === model?.service.modelName);
// return openai && openai.trainName;
// }, [model]);
const canTrain = useMemo(() => {
const openai = modelList.find((item) => item.model === model?.service.modelName);
return !!(openai && openai.trainName);
}, [model]);
/* 加载模型数据 */
const loadModel = useCallback(async () => {
@@ -89,34 +82,34 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
}, [setLoading, model, router]);
/* 上传数据集,触发微调 */
const startTraining = useCallback(
async (e: React.ChangeEvent<HTMLInputElement>) => {
if (!modelId || !e.target.files || e.target.files?.length === 0) return;
setLoading(true);
try {
const file = e.target.files[0];
const formData = new FormData();
formData.append('file', file);
await postTrainModel(modelId, formData);
// const startTraining = useCallback(
// async (e: React.ChangeEvent<HTMLInputElement>) => {
// if (!modelId || !e.target.files || e.target.files?.length === 0) return;
// setLoading(true);
// try {
// const file = e.target.files[0];
// const formData = new FormData();
// formData.append('file', file);
// await postTrainModel(modelId, formData);
toast({
title: '开始训练...',
status: 'success'
});
// toast({
// title: '开始训练...',
// status: 'success'
// });
// 重新获取模型
loadModel();
} catch (err: any) {
toast({
title: err?.message || '上传文件失败',
status: 'error'
});
console.log('error->', err);
}
setLoading(false);
},
[setLoading, loadModel, modelId, toast]
);
// // 重新获取模型
// loadModel();
// } catch (err: any) {
// toast({
// title: err?.message || '上传文件失败',
// status: 'error'
// });
// console.log('error->', err);
// }
// setLoading(false);
// },
// [setLoading, loadModel, modelId, toast]
// );
/* 点击更新模型状态 */
const handleClickUpdateStatus = useCallback(async () => {
@@ -248,22 +241,34 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
)}
</Card>
<Grid mt={5} gridTemplateColumns={media('1fr 1fr', '1fr')} gridGap={5}>
<ModelEditForm formHooks={formHooks} handleDelModel={handleDelModel} />
<ModelEditForm formHooks={formHooks} handleDelModel={handleDelModel} canTrain={canTrain} />
{/* {canTrain && (
<Card p={4}>
<Training model={model} />
</Card>
)} */}
<Card p={4} height={'500px'} gridColumnStart={1} gridColumnEnd={3}>
{model._id && <ModelDataCard model={model} />}
</Card>
{canTrain && model._id && (
<Card
p={4}
height={'700px'}
{...media(
{
gridColumnStart: 1,
gridColumnEnd: 3
},
{}
)}
>
<ModelDataCard model={model} />
</Card>
)}
</Grid>
{/* 文件选择 */}
<Box position={'absolute'} w={0} h={0} overflow={'hidden'}>
{/* <Box position={'absolute'} w={0} h={0} overflow={'hidden'}>
<input ref={SelectFileDom} type="file" accept=".jsonl" onChange={startTraining} />
</Box>
</Box> */}
</>
);
};

View File

@@ -1,29 +1,26 @@
import { DataItem } from '@/service/mongo';
import { SplitData, ModelData } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent, getOpenApiKey } from '@/service/utils/tools';
import type { ChatCompletionRequestMessage } from 'openai';
import { DataItemSchema } from '@/types/mongoSchema';
import { ChatModelNameEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill';
import { generateVector } from './generateVector';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
export async function generateQA(next = false): Promise<any> {
if (process.env.NODE_ENV === 'development') return;
if (global.generatingQA && !next) return;
global.generatingQA = true;
const systemPrompt: ChatCompletionRequestMessage = {
role: 'system',
content: `总结助手。我会向你发送一段长文本,请从中总结出5至15个问题和答案,答案请尽量详细,按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n`
content: `总结助手。我会向你发送一段长文本,请从中总结出5至15个问题和答案,答案请尽量详细,按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n`
};
let dataItem: DataItemSchema | null = null;
try {
// 找出一个需要生成的 dataItem
dataItem = await DataItem.findOne({
status: { $ne: 0 },
times: { $gt: 0 },
type: 'QA'
const dataItem = await SplitData.findOne({
textList: { $exists: true, $ne: [] }
});
if (!dataItem) {
@@ -32,10 +29,13 @@ export async function generateQA(next = false): Promise<any> {
return;
}
// 更新状态为生成中
await DataItem.findByIdAndUpdate(dataItem._id, {
status: 2
});
// 弹出文本
await SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } });
const text = dataItem.textList[dataItem.textList.length - 1];
if (!text) {
throw new Error('无文本');
}
// 获取 openapi Key
let userApiKey, systemKey;
@@ -44,10 +44,10 @@ export async function generateQA(next = false): Promise<any> {
userApiKey = key.userApiKey;
systemKey = key.systemKey;
} catch (error) {
// 余额不够了, 把用户所有记录改成闲置
await DataItem.updateMany({
userId: dataItem.userId,
status: 0
// 余额不够了, 清空该记录
await SplitData.findByIdAndUpdate(dataItem._id, {
textList: [],
errorText: '余额不足,生成数据集任务终止'
});
throw new Error('获取 openai key 失败');
}
@@ -59,84 +59,71 @@ export async function generateQA(next = false): Promise<any> {
// 获取 openai 请求实例
const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 请求 chatgpt 获取回答
const response = await Promise.allSettled(
[0.2, 0.8].map(
(temperature) =>
chatAPI
.createChatCompletion(
{
model: ChatModelNameEnum.GPT35,
temperature: temperature,
n: 1,
messages: [
systemPrompt,
{
role: 'user',
content: dataItem?.text || ''
}
]
},
{
timeout: 120000,
httpsAgent
}
)
.then((res) => ({
rawContent: res?.data.choices[0].message?.content || '',
result: splitText(res?.data.choices[0].message?.content || '')
})) // 从 content 中提取 QA
)
);
// 过滤出成功的响应
const successResponse: {
rawContent: string;
result: { q: string; a: string }[];
}[] = response.filter((item) => item.status === 'fulfilled').map((item: any) => item.value);
const rawContents = successResponse.map((item) => item.rawContent);
const results = successResponse.map((item) => item.result).flat();
// 插入数据库,并修改状态
await DataItem.findByIdAndUpdate(dataItem._id, {
status: 0,
$push: {
rawResponse: {
$each: successResponse.map((item) => item.rawContent)
const response = await chatAPI
.createChatCompletion(
{
model: ChatModelNameEnum.GPT35,
temperature: 0.2,
n: 1,
messages: [
systemPrompt,
{
role: 'user',
content: text
}
]
},
result: {
$each: results
{
timeout: 120000,
httpsAgent
}
}
});
)
.then((res) => ({
rawContent: res?.data.choices[0].message?.content || '',
result: splitText(res?.data.choices[0].message?.content || '')
})); // 从 content 中提取 QA
// 插入 modelData 表,生成向量
await ModelData.insertMany(
response.result.map((item) => ({
modelId: dataItem.modelId,
userId: dataItem.userId,
text: item.a,
q: [
{
id: nanoid(),
text: item.q
}
],
status: 1
}))
);
console.log(
'生成QA成功time:',
`${(Date.now() - startTime) / 1000}s`,
'QA数量',
results.length
response.result.length
);
// 计费
pushSplitDataBill({
isPay: !userApiKey && results.length > 0,
isPay: !userApiKey && response.result.length > 0,
userId: dataItem.userId,
type: 'QA',
text: systemPrompt.content + dataItem.text + rawContents.join('')
text: systemPrompt.content + text + response.rawContent
});
} catch (error: any) {
console.log('error: 生成QA错误', dataItem?._id);
console.log('response:', error?.response);
if (dataItem?._id) {
await DataItem.findByIdAndUpdate(dataItem._id, {
status: dataItem.times > 0 ? 1 : 0, // 还有重试次数则可以继续进行
$inc: {
// 剩余尝试次数-1
times: -1
}
});
}
}
generateQA(true);
generateQA(true);
generateVector(true);
} catch (error: any) {
console.log(error);
console.log('生成QA错误:', error?.response);
setTimeout(() => {
generateQA(true);
}, 10000);
}
}
/**

View File

@@ -0,0 +1,88 @@
import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent } from '@/service/utils/tools';
import { ModelData } from '../models/modelData';
import { connectRedis } from '../redis';
import { VecModelDataIndex } from '@/constants/redis';
export async function generateVector(next = false): Promise<any> {
if (global.generatingVector && !next) return;
global.generatingVector = true;
try {
const redis = await connectRedis();
// 找出一个需要生成的 dataItem
const dataItem = await ModelData.findOne({
status: { $ne: 0 }
});
if (!dataItem) {
console.log('没有需要生成 【向量】 的数据');
global.generatingVector = false;
return;
}
// 获取 openapi Key
const openAiKey = process.env.OPENAIKEY as string;
// 获取 openai 请求实例
const chatAPI = getOpenAIApi(openAiKey);
const dataId = String(dataItem._id);
// 生成词向量
const response = await Promise.allSettled(
dataItem.q.map((item, i) =>
chatAPI
.createEmbedding(
{
model: 'text-embedding-ada-002',
input: item.text
},
{
timeout: 120000,
httpsAgent
}
)
.then((res) => res?.data?.data?.[0]?.embedding || [])
.then((vector) =>
redis.sendCommand([
'JSON.SET',
`${VecModelDataIndex}:${dataId}:${i}`,
'$',
JSON.stringify({
dataId,
modelId: String(dataItem.modelId),
vector
})
])
)
)
);
if (response.filter((item) => item.status === 'fulfilled').length === 0) {
throw new Error(JSON.stringify(response));
}
// 修改该数据状态
await ModelData.findByIdAndUpdate(dataItem._id, {
status: 0
});
console.log(`生成向量成功: ${dataItem._id}`);
setTimeout(() => {
generateVector(true);
}, 3000);
} catch (error: any) {
console.log(error);
console.log('error: 生成向量错误', error?.response?.data);
if (error?.response?.statusText === 'Too Many Requests') {
console.log('次数限制1分钟后尝试');
// 限制次数1分钟后再试
setTimeout(() => {
generateVector(true);
}, 60000);
}
}
}

View File

@@ -34,7 +34,7 @@ export const pushChatBill = async ({
// 计算价格
const unitPrice = modelItem?.price || 5;
const price = unitPrice * tokens.length;
console.log(`chat bill, price: ${formatPrice(price)}`);
console.log(`chat bill, unit price: ${unitPrice}, price: ${formatPrice(price)}`);
try {
// 插入 Bill 记录

View File

@@ -0,0 +1,31 @@
/* 模型的知识库 */
import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { ModelSplitDataSchema as SplitDataType } from '@/types/mongoSchema';
const SplitDataSchema = new Schema({
userId: {
type: Schema.Types.ObjectId,
ref: 'user',
required: true
},
modelId: {
type: Schema.Types.ObjectId,
ref: 'model',
required: true
},
rawText: {
type: String,
required: true
},
textList: {
type: [String],
default: []
},
errorText: {
type: String,
default: ''
}
});
export const SplitData: MongoModel<SplitDataType> =
models['splitData'] || model('splitData', SplitDataSchema);

View File

@@ -1,6 +1,7 @@
import mongoose from 'mongoose';
import { generateQA } from './events/generateQA';
import { generateAbstract } from './events/generateAbstract';
import { generateVector } from './events/generateVector';
/**
* 连接 MongoDB 数据库
@@ -27,7 +28,8 @@ export async function connectToDatabase(): Promise<void> {
}
generateQA();
generateAbstract();
// generateAbstract();
generateVector();
}
export * from './models/authCode';
@@ -40,3 +42,4 @@ export * from './models/bill';
export * from './models/pay';
export * from './models/data';
export * from './models/dataItem';
export * from './models/splitData';

View File

@@ -1,5 +1,4 @@
import { createClient, SchemaFieldTypes } from 'redis';
import { ModelDataIndex } from '@/constants/redis';
import { createClient } from 'redis';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 10);
@@ -14,7 +13,7 @@ export const connectRedis = async () => {
try {
global.redisClient = createClient({
url: 'redis://default:121914yu@120.76.193.200:8100'
url: process.env.REDIS_URL
});
global.redisClient.on('error', (err) => {
@@ -33,46 +32,6 @@ export const connectRedis = async () => {
// 0 - 测试库1 - 正式
await global.redisClient.select(0);
// 创建索引
try {
await global.redisClient.ft.create(
ModelDataIndex,
{
// '$.vector': SchemaFieldTypes.VECTOR,
'$.modelId': {
type: SchemaFieldTypes.TEXT,
AS: 'modelId'
},
'$.userId': {
type: SchemaFieldTypes.TEXT,
AS: 'userId'
},
'$.status': {
type: SchemaFieldTypes.NUMERIC,
AS: 'status'
}
},
{
ON: 'JSON',
PREFIX: 'model:data'
}
);
} catch (error) {
console.log('创建索引失败', error);
}
// await global.redisClient.json.set('fastgpt:modeldata:2', '$', {
// vector: [124, 214, 412, 4, 124, 1, 4, 1, 4, 3, 423],
// modelId: 'daf',
// userId: 'adfd',
// q: 'fasf',
// a: 'afasf',
// status: 0,
// createTime: new Date()
// });
// const value = await global.redisClient.json.get('fastgpt:modeldata:2');
// console.log(value);
return global.redisClient;
} catch (error) {
console.log(error, '==');

View File

@@ -119,3 +119,21 @@ export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) =>
return systemPrompt ? [systemPrompt, ...res] : res;
};
/* system 内容截断 */
export const systemPromptFilter = (prompts: string[], maxTokens: number) => {
let splitText = '';
// 从前往前截取
for (let i = 0; i < prompts.length; i++) {
const prompt = prompts[i];
splitText += `${prompt}\n`;
const tokens = encode(splitText).length;
if (tokens >= maxTokens) {
break;
}
}
return splitText;
};

View File

@@ -6,6 +6,7 @@ declare global {
var redisClient: RedisClientType | null;
var generatingQA: boolean;
var generatingAbstract: boolean;
var generatingVector: boolean;
var QRCode: any;
interface Window {
['pdfjs-dist/build/pdf']: any;

View File

@@ -64,6 +64,15 @@ export interface ModelDataSchema {
status: ModelDataType;
}
export interface ModelSplitDataSchema {
_id: string;
userId: string;
modelId: string;
rawText: string;
errorText: string;
textList: string[];
}
export interface TrainingSchema {
_id: string;
serviceName: ServiceName;

10
src/types/redis.d.ts vendored
View File

@@ -1,10 +1,6 @@
export interface RedisModelDataItemType {
id: string;
value: {
vector: number[];
q: string; // 提问词
a: string; // 原文
modelId: string;
userId: string;
};
vector: number[];
dataId: string;
modelId: string;
}

View File

@@ -124,3 +124,15 @@ export const readDocContent = (file: File) =>
reject('读取 doc 文件失败');
};
});
export const vectorToBuffer = (vector: number[]) => {
const float32Arr = new Float32Array(vector);
const myBuffer = new ArrayBuffer(float32Arr.length * Float32Array.BYTES_PER_ELEMENT);
const myView = new DataView(myBuffer);
for (let i = 0; i < float32Arr.length; i++) {
myView.setFloat32(i * Float32Array.BYTES_PER_ELEMENT, float32Arr[i], true);
}
return Buffer.from(myBuffer);
};