This commit is contained in:
archer
2023-03-30 22:33:58 +08:00
36 changed files with 1187 additions and 232 deletions

View File

@@ -3,4 +3,6 @@ AXIOS_PROXY_PORT=33210
MONGODB_URI= MONGODB_URI=
MY_MAIL= MY_MAIL=
MAILE_CODE= MAILE_CODE=
TOKEN_KEY= TOKEN_KEY=
OPENAIKEY=
REDIS_URL=

View File

@@ -44,8 +44,8 @@ export const postModelDataInput = (data: {
data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[]; data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
}) => POST(`/model/data/pushModelDataInput`, data); }) => POST(`/model/data/pushModelDataInput`, data);
export const postModelDataSelect = (modelId: string, dataIds: string[]) => export const postModelDataFileText = (modelId: string, text: string) =>
POST(`/model/data/pushModelDataSelectData`, { modelId, dataIds }); POST(`/model/data/splitData`, { modelId, text });
export const putModelDataById = (data: { dataId: string; text: string }) => export const putModelDataById = (data: { dataId: string; text: string }) =>
PUT('/model/data/putModelData', data); PUT('/model/data/putModelData', data);

View File

@@ -26,12 +26,12 @@ const navbarList = [
link: '/model/list', link: '/model/list',
activeLink: ['/model/list', '/model/detail'] activeLink: ['/model/list', '/model/detail']
}, },
{ // {
label: '数据', // label: '数据',
icon: 'icon-datafull', // icon: 'icon-datafull',
link: '/data/list', // link: '/data/list',
activeLink: ['/data/list', '/data/detail'] // activeLink: ['/data/list', '/data/detail']
}, // },
{ {
label: '账号', label: '账号',
icon: 'icon-yonghu-yuan', icon: 'icon-yonghu-yuan',

View File

@@ -2,9 +2,16 @@ import type { ServiceName, ModelDataType, ModelSchema } from '@/types/mongoSchem
export enum ChatModelNameEnum { export enum ChatModelNameEnum {
GPT35 = 'gpt-3.5-turbo', GPT35 = 'gpt-3.5-turbo',
VECTOR_GPT = 'VECTOR_GPT',
GPT3 = 'text-davinci-003' GPT3 = 'text-davinci-003'
} }
export const ChatModelNameMap = {
[ChatModelNameEnum.GPT35]: 'gpt-3.5-turbo',
[ChatModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo',
[ChatModelNameEnum.GPT3]: 'text-davinci-003'
};
export type ModelConstantsData = { export type ModelConstantsData = {
serviceCompany: `${ServiceName}`; serviceCompany: `${ServiceName}`;
name: string; name: string;
@@ -28,6 +35,17 @@ export const modelList: ModelConstantsData[] = [
trainedMaxToken: 2000, trainedMaxToken: 2000,
maxTemperature: 2, maxTemperature: 2,
price: 3 price: 3
},
{
serviceCompany: 'openai',
name: '知识库',
model: ChatModelNameEnum.VECTOR_GPT,
trainName: 'vector',
maxToken: 4000,
contextMaxToken: 7500,
trainedMaxToken: 2000,
maxTemperature: 1,
price: 3
} }
// { // {
// serviceCompany: 'openai', // serviceCompany: 'openai',

View File

@@ -1,2 +1 @@
export const ModelDataIndex = 'model:data'; export const VecModelDataIndex = 'model:data';
export const VecModelDataIndex = 'vec:model:data';

View File

@@ -46,7 +46,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const model: ModelSchema = chat.modelId; const model: ModelSchema = chat.modelId;
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
if (!modelConstantsData) { if (!modelConstantsData) {
throw new Error('模型异常,请用 chatgpt 模型'); throw new Error('模型加载异常');
} }
// 读取对话内容 // 读取对话内容

View File

@@ -0,0 +1,241 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { connectToDatabase, ModelData } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/chat';
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import type { ModelSchema } from '@/types/mongoSchema';
import { PassThrough } from 'stream';
import { modelList } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataIndex } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
// NOTE(review): appears to be leftover mock/test embedding data — it is never
// read by the handler below (the real vector comes from createEmbedding).
// Candidate for removal; verify nothing else in this module references it.
let vectorData = [
  -0.025028639, -0.010407282, 0.026523087, -0.0107438695, -0.006967359, 0.010043768, -0.012043097,
  0.008724345, -0.028919589, -0.0117738275, 0.0050690062, 0.02961969
].concat(new Array(1524).fill(0));
/* Knowledge-base (vector) chat endpoint.
 * Flow: embed the latest user prompt -> vector-search Redis for similar
 * knowledge entries belonging to this model -> load their source text from
 * Mongo and prepend it as a system prompt -> stream the OpenAI chat
 * completion back to the client as SSE. Billing is pushed only when the
 * platform key (not the user's own key) was used. */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  let step = 0; // step === 1 means the SSE response stream has already started
  const stream = new PassThrough();
  stream.on('error', () => {
    console.log('error: ', 'stream error');
    stream.destroy();
  });
  res.on('close', () => {
    stream.destroy();
  });
  res.on('error', () => {
    console.log('error: ', 'request error');
    stream.destroy();
  });

  try {
    const { chatId, prompt } = req.body as {
      prompt: ChatItemType;
      chatId: string;
    };
    const { authorization } = req.headers;
    if (!chatId || !prompt) {
      throw new Error('缺少参数');
    }

    await connectToDatabase();
    const redis = await connectRedis();
    const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization);

    const model: ModelSchema = chat.modelId;
    const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
    if (!modelConstantsData) {
      throw new Error('模型加载异常');
    }

    // Full conversation history plus the new prompt
    const prompts = [...chat.content, prompt];

    // Prefer the user's own key; fall back to the platform key
    const chatAPI = getOpenAIApi(userApiKey || systemKey);

    // Embed the incoming prompt so we can run a vector similarity search
    const promptVector = await chatAPI
      .createEmbedding(
        {
          model: 'text-embedding-ada-002',
          input: prompt.value
        },
        {
          timeout: 120000,
          httpsAgent
        }
      )
      .then((res) => res?.data?.data?.[0]?.embedding || []);

    const binary = vectorToBuffer(promptVector);

    // Vector-range search limited to this model's entries. The raw RESP reply
    // is a flat array: [count, key1, fieldPairs1, key2, fieldPairs2, ...], so
    // the dataId field values live at the even offsets 2, 4, 6, ...
    const redisData: any[] = await redis.sendCommand([
      'FT.SEARCH',
      `idx:${VecModelDataIndex}`,
      `@modelId:{${String(chat.modelId._id)}} @vector:[VECTOR_RANGE 0.2 $blob]`,
      // `@modelId:{${String(chat.modelId._id)}}=>[KNN 10 @vector $blob AS score]`,
      'RETURN',
      '1',
      'dataId',
      // 'SORTBY',
      // 'score',
      'PARAMS',
      '2',
      'blob',
      binary,
      'DIALECT',
      '2'
    ]);

    // Extract dataIds from the reply and deduplicate them
    let formatIds = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
      .map((i) => {
        if (!redisData[i] || !redisData[i][1]) return '';
        return redisData[i][1];
      })
      .filter((item) => item);
    formatIds = Array.from(new Set(formatIds));

    if (formatIds.length === 0) {
      throw new Error('对不起,我没有找到你的问题');
    }

    // Load the original knowledge text from Mongo for the deduplicated ids
    // (previously this re-walked the raw reply and could fetch duplicates)
    const textArr = (
      await Promise.all(
        formatIds.map((dataId) =>
          ModelData.findById(dataId)
            .select('text')
            .then((res) => res?.text || '')
        )
      )
    ).filter((item) => item);

    // Trim the knowledge text down to at most ~2800 tokens
    const systemPrompt = systemPromptFilter(textArr, 2800);

    prompts.unshift({
      obj: 'SYSTEM',
      value: `请根据下面的知识回答问题: ${systemPrompt}`
    });

    // Trim the whole conversation to the model's context window
    const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);

    // Map internal roles onto the OpenAI chat roles
    const map = {
      Human: ChatCompletionRequestMessageRoleEnum.User,
      AI: ChatCompletionRequestMessageRoleEnum.Assistant,
      SYSTEM: ChatCompletionRequestMessageRoleEnum.System
    };
    const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
      (item: ChatItemType) => ({
        role: map[item.obj],
        content: item.value
      })
    );
    // console.log(formatPrompts);

    // Scale the user-chosen temperature (0-10) onto the model's allowed range
    const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);

    let startTime = Date.now();
    // Request the streamed completion
    const chatResponse = await chatAPI.createChatCompletion(
      {
        model: model.service.chatModel,
        temperature: temperature,
        // max_tokens: modelConstantsData.maxToken,
        messages: formatPrompts,
        frequency_penalty: 0.5, // higher -> less repetition
        presence_penalty: -0.5, // higher -> more new topics
        stream: true
      },
      {
        timeout: 40000,
        responseType: 'stream',
        httpsAgent
      }
    );

    console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);

    // Start the SSE response (fixed: 'charset=utf-8', was 'charset-utf-8')
    res.setHeader('Content-Type', 'text/event-stream;charset=utf-8');
    res.setHeader('Access-Control-Allow-Origin', '*');
    res.setHeader('X-Accel-Buffering', 'no');
    res.setHeader('Cache-Control', 'no-cache, no-transform');
    step = 1;

    let responseContent = '';
    stream.pipe(res);

    const onParse = async (event: ParsedEvent | ReconnectInterval) => {
      if (event.type !== 'event') return;
      const data = event.data;
      if (data === '[DONE]') return;
      try {
        const json = JSON.parse(data);
        // guard every level: an empty choices array must not throw
        const content: string = json?.choices?.[0]?.delta?.content || '';
        if (!content || (responseContent === '' && content === '\n')) return;
        responseContent += content;
        // console.log('content:', content)
        !stream.destroyed && stream.push(content.replace(/\n/g, '<br/>'));
      } catch (error) {
        // ignore malformed SSE json fragments
      }
    };

    const decoder = new TextDecoder();
    // Create the parser ONCE and feed it incrementally: recreating it per
    // chunk (as before) loses events that are split across chunk boundaries.
    const parser = createParser(onParse);
    try {
      for await (const chunk of chatResponse.data as any) {
        if (stream.destroyed) {
          // client disconnected; ignore the rest of the upstream response
          break;
        }
        parser.feed(decoder.decode(chunk));
      }
    } catch (error) {
      console.log('pipe error', error);
    }
    // close stream
    !stream.destroyed && stream.push(null);
    stream.destroy();

    const promptsContent = formatPrompts.map((item) => item.content).join('');
    // Only bill usage made with the platform key
    pushChatBill({
      isPay: !userApiKey,
      modelName: model.service.modelName,
      userId,
      chatId,
      text: promptsContent + responseContent
    });
    // jsonRes(res);
  } catch (err: any) {
    if (step === 1) {
      // headers already sent — just terminate the stream
      console.log('error结束');
      stream.destroy();
    } else {
      res.status(500);
      jsonRes(res, {
        code: 500,
        error: err
      });
    }
  }
}

View File

@@ -24,7 +24,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
if (!DataRecord) { if (!DataRecord) {
throw new Error('找不到数据集'); throw new Error('找不到数据集');
} }
const replaceText = text.replace(/[\r\n\\n]+/g, ' '); const replaceText = text.replace(/[\\n]+/g, ' ');
// 文本拆分成 chunk // 文本拆分成 chunk
let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || []; let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || [];
@@ -35,7 +35,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
chunks.forEach((chunk) => { chunks.forEach((chunk) => {
splitText += chunk; splitText += chunk;
const tokens = encode(splitText).length; const tokens = encode(splitText).length;
if (tokens >= 980) { if (tokens >= 780) {
dataItems.push({ dataItems.push({
userId, userId,
dataId, dataId,

View File

@@ -3,7 +3,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo'; import { connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { ModelStatusEnum, modelList, ChatModelNameEnum } from '@/constants/model'; import { ModelStatusEnum, modelList, ChatModelNameEnum, ChatModelNameMap } from '@/constants/model';
import { Model } from '@/service/models/model'; import { Model } from '@/service/models/model';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
@@ -33,15 +33,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase(); await connectToDatabase();
// 重名校验
const authRepeatName = await Model.findOne({
name,
userId
});
if (authRepeatName) {
throw new Error('模型名重复');
}
// 上限校验 // 上限校验
const authCount = await Model.countDocuments({ const authCount = await Model.countDocuments({
userId userId
@@ -57,9 +48,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
status: ModelStatusEnum.running, status: ModelStatusEnum.running,
service: { service: {
company: modelItem.serviceCompany, company: modelItem.serviceCompany,
trainId: modelItem.trainName, trainId: '',
chatModel: modelItem.model, chatModel: ChatModelNameMap[modelItem.model], // 聊天时用的模型
modelName: modelItem.model modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作
} }
}); });

View File

@@ -3,6 +3,7 @@ import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData, Model } from '@/service/mongo'; import { connectToDatabase, ModelData, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { ModelDataSchema } from '@/types/mongoSchema'; import { ModelDataSchema } from '@/types/mongoSchema';
import { generateVector } from '@/service/events/generateVector';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
@@ -44,6 +45,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
})) }))
); );
generateVector(true);
jsonRes(res, { jsonRes(res, {
data: model data: model
}); });

View File

@@ -0,0 +1,67 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, SplitData, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { generateQA } from '@/service/events/generateQA';
import { encode } from 'gpt-token-utils';
/* Split raw text into ~980-token chunks and persist them as a SplitData job
 * (QA generation is triggered elsewhere). Verifies model ownership first. */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  try {
    const { text, modelId } = req.body as { text: string; modelId: string };
    if (!text || !modelId) {
      throw new Error('参数错误');
    }
    await connectToDatabase();

    const { authorization } = req.headers;
    const userId = await authToken(authorization);

    // Only the model's owner may push data to it
    const model = await Model.findOne({
      _id: modelId,
      userId
    });

    if (!model) {
      throw new Error('无权操作该模型');
    }

    // Collapse escaped ("\n") and literal newlines into single spaces
    const replaceText = text.replace(/(\\n|\n)+/g, ' ');

    // Split into sentence-terminated pieces (Chinese and Latin punctuation)
    const chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || [];

    const textList: string[] = [];
    let splitText = '';

    chunks.forEach((chunk) => {
      splitText += chunk;
      const tokens = encode(splitText).length;
      if (tokens >= 980) {
        textList.push(splitText);
        splitText = '';
      }
    });

    // Fix: keep the trailing piece that never reached the token threshold —
    // previously the tail of the document was silently dropped.
    if (splitText) {
      textList.push(splitText);
    }

    // Persist the split job
    await SplitData.create({
      userId,
      modelId,
      rawText: text,
      textList
    });

    // generateQA();

    jsonRes(res, {
      data: { chunks, replaceText }
    });
  } catch (err) {
    jsonRes(res, {
      code: 500,
      error: err
    });
  }
}

View File

@@ -1,6 +1,6 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { Chat, Model, Training, connectToDatabase } from '@/service/mongo'; import { Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo';
import { authToken, getUserOpenaiKey } from '@/service/utils/tools'; import { authToken, getUserOpenaiKey } from '@/service/utils/tools';
import { TrainingStatusEnum } from '@/constants/model'; import { TrainingStatusEnum } from '@/constants/model';
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/chat';
@@ -26,16 +26,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase(); await connectToDatabase();
// 删除模型 let requestQueue: any[] = [];
await Model.deleteOne({
_id: modelId,
userId
});
// 删除对应的聊天 // 删除对应的聊天
await Chat.deleteMany({ requestQueue.push(
modelId Chat.deleteMany({
}); modelId
})
);
// 删除数据集
requestQueue.push(
ModelData.deleteMany({
modelId
})
);
// 查看是否正在训练 // 查看是否正在训练
const training: TrainingItemType | null = await Training.findOne({ const training: TrainingItemType | null = await Training.findOne({
@@ -56,9 +60,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
} }
// 删除对应训练记录 // 删除对应训练记录
await Training.deleteMany({ requestQueue.push(
modelId Training.deleteMany({
}); modelId
})
);
// 删除模型
requestQueue.push(
Model.deleteOne({
_id: modelId,
userId
})
);
await requestQueue;
jsonRes(res); jsonRes(res);
} catch (err) { } catch (err) {

View File

@@ -37,7 +37,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
systemPrompt, systemPrompt,
intro, intro,
temperature, temperature,
service, // service,
security security
} }
); );

View File

@@ -119,6 +119,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
async (prompts: ChatSiteItemType) => { async (prompts: ChatSiteItemType) => {
const urlMap: Record<string, string> = { const urlMap: Record<string, string> = {
[ChatModelNameEnum.GPT35]: '/api/chat/chatGpt', [ChatModelNameEnum.GPT35]: '/api/chat/chatGpt',
[ChatModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt',
[ChatModelNameEnum.GPT3]: '/api/chat/gpt3' [ChatModelNameEnum.GPT3]: '/api/chat/gpt3'
}; };

View File

@@ -184,7 +184,7 @@ const DataList = () => {
> >
</Button> </Button>
<Menu> {/* <Menu>
<MenuButton as={Button} mr={2} size={'sm'} isLoading={isExporting}> <MenuButton as={Button} mr={2} size={'sm'} isLoading={isExporting}>
导出 导出
</MenuButton> </MenuButton>
@@ -200,7 +200,7 @@ const DataList = () => {
</MenuItem> </MenuItem>
)} )}
</MenuList> </MenuList>
</Menu> </Menu> */}
<Button <Button
size={'sm'} size={'sm'}

View File

@@ -0,0 +1,141 @@
import React, { useState, useCallback } from 'react';
import {
Box,
IconButton,
Flex,
Button,
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
Input,
Textarea
} from '@chakra-ui/react';
import { useForm, useFieldArray } from 'react-hook-form';
import { postModelDataInput } from '@/api/model';
import { useToast } from '@/hooks/useToast';
import { DeleteIcon } from '@chakra-ui/icons';
import { customAlphabet } from 'nanoid';
// 12-char lowercase alphanumeric id generator for question entries
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);

// Form shape: one knowledge text plus a dynamic list of question phrasings
type FormData = { text: string; q: { val: string }[] };

/**
 * Modal for manually entering one knowledge-base item — a knowledge text plus
 * one or more question phrasings — and pushing it to the model's dataset via
 * postModelDataInput.
 * NOTE(review): several JSX labels render as empty/`:` text — this looks like
 * text stripped during extraction; confirm against the original file.
 */
const InputDataModal = ({
  onClose,
  onSuccess,
  modelId
}: {
  onClose: () => void;   // close the modal
  onSuccess: () => void; // invoked after a successful import (refresh list)
  modelId: string;       // target model id
}) => {
  const [importing, setImporting] = useState(false);
  const { toast } = useToast();

  const { register, handleSubmit, control } = useForm<FormData>({
    defaultValues: {
      text: '',
      q: [{ val: '' }]
    }
  });
  // Dynamic array of question inputs
  const {
    fields: inputQ,
    append: appendQ,
    remove: removeQ
  } = useFieldArray({
    control,
    name: 'q'
  });

  // Submit: push the knowledge text with a generated id per question phrasing
  const sureImportData = useCallback(
    async (e: FormData) => {
      setImporting(true);

      try {
        await postModelDataInput({
          modelId: modelId,
          data: [
            {
              text: e.text,
              q: e.q.map((item) => ({
                id: nanoid(),
                text: item.val
              }))
            }
          ]
        });
        toast({
          title: '导入数据成功,需要一段时间训练',
          status: 'success'
        });
        onClose();
        onSuccess();
      } catch (err) {
        // NOTE(review): failure is only logged — no error toast; confirm intended
        console.log(err);
      }
      setImporting(false);
    },
    [modelId, onClose, onSuccess, toast]
  );

  return (
    <Modal isOpen={true} onClose={onClose}>
      <ModalOverlay />
      <ModalContent maxW={'min(900px, 90vw)'} maxH={'80vh'} position={'relative'}>
        <ModalHeader></ModalHeader>
        <ModalCloseButton />

        <Box px={6} pb={2} overflowY={'auto'}>
          <Box mb={2}>:</Box>
          <Textarea
            mb={4}
            placeholder="知识点"
            rows={3}
            maxH={'200px'}
            {...register(`text`, {
              required: '知识点'
            })}
          />
          {inputQ.map((item, index) => (
            <Box key={item.id} mb={5}>
              <Box mb={2}>{index + 1}:</Box>
              <Flex>
                <Input
                  placeholder="问法"
                  {...register(`q.${index}.val`, {
                    required: '问法不能为空'
                  })}
                ></Input>
                {inputQ.length > 1 && (
                  <IconButton
                    icon={<DeleteIcon />}
                    aria-label={'delete'}
                    colorScheme={'gray'}
                    variant={'unstyled'}
                    onClick={() => removeQ(index)}
                  />
                )}
              </Flex>
            </Box>
          ))}
        </Box>
        <Flex px={6} pt={2} pb={4}>
          <Button alignSelf={'flex-start'} variant={'outline'} onClick={() => appendQ({ val: '' })}>
          </Button>
          <Box flex={1}></Box>
          <Button variant={'outline'} mr={3} onClick={onClose}>
          </Button>
          <Button isLoading={importing} onClick={handleSubmit(sureImportData)}>
          </Button>
        </Flex>
      </ModalContent>
    </Modal>
  );
};

export default InputDataModal;

View File

@@ -0,0 +1,202 @@
import React, { useCallback } from 'react';
import {
Box,
TableContainer,
Table,
Thead,
Tbody,
Tr,
Th,
Td,
IconButton,
Flex,
Button,
useDisclosure,
Textarea,
Menu,
MenuButton,
MenuList,
MenuItem
} from '@chakra-ui/react';
import type { ModelSchema } from '@/types/mongoSchema';
import { ModelDataSchema } from '@/types/mongoSchema';
import { ModelDataStatusMap } from '@/constants/model';
import { usePaging } from '@/hooks/usePaging';
import ScrollData from '@/components/ScrollData';
import { getModelDataList, delOneModelData, putModelDataById } from '@/api/model';
import { DeleteIcon, RepeatIcon } from '@chakra-ui/icons';
import { useToast } from '@/hooks/useToast';
import { useLoading } from '@/hooks/useLoading';
import dynamic from 'next/dynamic';
// Lazily-loaded modals: manual entry and file import
const InputModel = dynamic(() => import('./InputDataModal'));
const SelectModel = dynamic(() => import('./SelectFileModal'));

/**
 * Paginated table of a model's knowledge-base entries (question list, answer
 * text, processing status) with inline answer editing, row deletion, and
 * entry points to the manual-input and file-import modals.
 * NOTE(review): several JSX labels render as empty/`:` text — this looks like
 * text stripped during extraction; confirm against the original file.
 */
const ModelDataCard = ({ model }: { model: ModelSchema }) => {
  const { toast } = useToast();
  const { Loading } = useLoading();

  // Infinite-scroll pagination over this model's data entries
  const {
    nextPage,
    isLoadAll,
    requesting,
    data: modelDataList,
    total,
    setData,
    getData
  } = usePaging<ModelDataSchema>({
    api: getModelDataList,
    pageSize: 20,
    params: {
      modelId: model._id
    }
  });

  // Persist an edited answer; the caller updates local state separately.
  // NOTE(review): no try/catch — a failed PUT rejects unhandled and the
  // success toast never shows; confirm intended.
  const updateAnswer = useCallback(
    async (dataId: string, text: string) => {
      await putModelDataById({
        dataId,
        text
      });
      toast({
        title: '修改回答成功',
        status: 'success'
      });
    },
    [toast]
  );

  const {
    isOpen: isOpenInputModal,
    onOpen: onOpenInputModal,
    onClose: onCloseInputModal
  } = useDisclosure();
  const {
    isOpen: isOpenSelectModal,
    onOpen: onOpenSelectModal,
    onClose: onCloseSelectModal
  } = useDisclosure();

  return (
    <>
      <Flex>
        <Box fontWeight={'bold'} fontSize={'lg'} flex={1}>
          : {total}{' '}
          <Box as={'span'} fontSize={'sm'}>
          </Box>
        </Box>
        <IconButton
          icon={<RepeatIcon />}
          aria-label={'refresh'}
          variant={'outline'}
          mr={4}
          onClick={() => getData(1, true)}
        />
        <Menu>
          <MenuButton as={Button}></MenuButton>
          <MenuList>
            <MenuItem onClick={onOpenInputModal}></MenuItem>
            <MenuItem onClick={onOpenSelectModal}></MenuItem>
          </MenuList>
        </Menu>
      </Flex>
      <ScrollData
        h={'100%'}
        px={6}
        mt={3}
        isLoadAll={isLoadAll}
        requesting={requesting}
        nextPage={nextPage}
        position={'relative'}
      >
        <TableContainer mt={4}>
          <Table variant={'simple'}>
            <Thead>
              <Tr>
                <Th>Question</Th>
                <Th>Text</Th>
                <Th>Status</Th>
                <Th></Th>
              </Tr>
            </Thead>
            <Tbody>
              {modelDataList.map((item) => (
                <Tr key={item._id}>
                  <Td w={'350px'}>
                    {item.q.map((item, i) => (
                      <Box
                        key={item.id}
                        fontSize={'xs'}
                        w={'100%'}
                        whiteSpace={'pre-wrap'}
                        _notLast={{ mb: 1 }}
                      >
                        Q{i + 1}:{' '}
                        <Box as={'span'} userSelect={'all'}>
                          {item.text}
                        </Box>
                      </Box>
                    ))}
                  </Td>
                  <Td minW={'200px'}>
                    <Textarea
                      w={'100%'}
                      h={'100%'}
                      defaultValue={item.text}
                      fontSize={'xs'}
                      resize={'both'}
                      onBlur={(e) => {
                        // only persist when the answer actually changed
                        const oldVal = modelDataList.find((data) => item._id === data._id)?.text;
                        if (oldVal !== e.target.value) {
                          updateAnswer(item._id, e.target.value);
                          setData((state) =>
                            state.map((data) => ({
                              ...data,
                              text: data._id === item._id ? e.target.value : data.text
                            }))
                          );
                        }
                      }}
                    ></Textarea>
                  </Td>
                  <Td w={'100px'}>{ModelDataStatusMap[item.status]}</Td>
                  <Td>
                    <IconButton
                      icon={<DeleteIcon />}
                      variant={'outline'}
                      colorScheme={'gray'}
                      aria-label={'delete'}
                      size={'sm'}
                      onClick={async () => {
                        // NOTE(review): delete is fire-and-forget (not awaited);
                        // a server failure leaves the row removed locally only
                        delOneModelData(item._id);
                        setData((state) => state.filter((data) => data._id !== item._id));
                      }}
                    />
                  </Td>
                </Tr>
              ))}
            </Tbody>
          </Table>
        </TableContainer>
        <Loading loading={requesting} fixed={false} />
      </ScrollData>
      {isOpenInputModal && (
        <InputModel
          modelId={model._id}
          onClose={onCloseInputModal}
          onSuccess={() => getData(1, true)}
        />
      )}
      {isOpenSelectModal && (
        <SelectModel
          modelId={model._id}
          onClose={onCloseSelectModal}
          onSuccess={() => getData(1, true)}
        />
      )}
    </>
  );
};

export default ModelDataCard;

View File

@@ -23,9 +23,11 @@ import { useConfirm } from '@/hooks/useConfirm';
const ModelEditForm = ({ const ModelEditForm = ({
formHooks, formHooks,
canTrain,
handleDelModel handleDelModel
}: { }: {
formHooks: UseFormReturn<ModelSchema>; formHooks: UseFormReturn<ModelSchema>;
canTrain: boolean;
handleDelModel: () => void; handleDelModel: () => void;
}) => { }) => {
const { openConfirm, ConfirmChild } = useConfirm({ const { openConfirm, ConfirmChild } = useConfirm({
@@ -136,15 +138,24 @@ const ModelEditForm = ({
</Flex> </Flex>
</FormControl> </FormControl>
<Box mt={4}> <Box mt={4}>
<Box mb={1}></Box> {canTrain ? (
<Textarea <Box fontWeight={'bold'}>
rows={6} prompt
maxLength={-1} 使 tokens
{...register('systemPrompt')} </Box>
placeholder={ ) : (
'模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n\n注意改功能会影响对话的整体朝向' <>
} <Box mb={1}></Box>
/> <Textarea
rows={6}
maxLength={-1}
{...register('systemPrompt')}
placeholder={
'模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n\n注意改功能会影响对话的整体朝向'
}
/>
</>
)}
</Box> </Box>
</Card> </Card>
{/* <Card p={4}> {/* <Card p={4}>

View File

@@ -0,0 +1,155 @@
import React, { useState, useCallback } from 'react';
import {
Box,
Flex,
Button,
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody
} from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast';
import { useSelectFile } from '@/hooks/useSelectFile';
import { customAlphabet } from 'nanoid';
import { encode } from 'gpt-token-utils';
import { useConfirm } from '@/hooks/useConfirm';
import { readTxtContent, readPdfContent, readDocContent } from '@/utils/tools';
import { useMutation } from '@tanstack/react-query';
import { postModelDataFileText } from '@/api/model';
// NOTE(review): nanoid is declared but not referenced in this component —
// possibly leftover from a copy of InputDataModal; verify before removing.
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);

// Accepted upload extensions
const fileExtension = '.txt,.doc,.docx,.pdf,.md';

/**
 * Modal for importing knowledge-base content from files: reads txt/md, pdf and
 * doc/docx files client-side, shows the combined text with a token count, then
 * posts it to the split endpoint after user confirmation.
 * NOTE(review): several JSX labels render as empty text — this looks like
 * text stripped during extraction; confirm against the original file.
 */
const SelectFileModal = ({
  onClose,
  onSuccess,
  modelId
}: {
  onClose: () => void;   // close the modal
  onSuccess: () => void; // invoked after a successful import (refresh list)
  modelId: string;       // target model id
}) => {
  const [selecting, setSelecting] = useState(false);
  const { toast } = useToast();
  const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
  const [fileText, setFileText] = useState('');
  const { openConfirm, ConfirmChild } = useConfirm({
    content: '确认导入该文件,需要一定时间进行拆解,该任务无法终止!'
  });

  // Read every selected file according to its extension, join the texts and
  // collapse repeated newlines. Unknown extensions contribute ''.
  const onSelectFile = useCallback(
    async (e: File[]) => {
      setSelecting(true);
      try {
        const fileTexts = (
          await Promise.all(
            e.map((file) => {
              // @ts-ignore
              const extension = file?.name?.split('.').pop().toLowerCase();
              switch (extension) {
                case 'txt':
                case 'md':
                  return readTxtContent(file);
                case 'pdf':
                  return readPdfContent(file);
                case 'doc':
                case 'docx':
                  return readDocContent(file);
                default:
                  return '';
              }
            })
          )
        )
          .join('\n')
          .replace(/\n+/g, '\n');
        setFileText(fileTexts);
        // NOTE(review): leftover debug log of the token encoding — consider removing
        console.log(encode(fileTexts));
      } catch (error: any) {
        console.log(error);
        toast({
          title: typeof error === 'string' ? error : '解析文件失败',
          status: 'error'
        });
      }
      setSelecting(false);
    },
    [setSelecting, toast]
  );

  // Confirmed import: POST the combined text for server-side splitting
  const { mutate, isLoading } = useMutation({
    mutationFn: async () => {
      if (!fileText) return;
      await postModelDataFileText(modelId, fileText);
      toast({
        title: '导入数据成功,需要一段拆解和训练',
        status: 'success'
      });
      onClose();
      onSuccess();
    },
    onError() {
      toast({
        title: '导入文件失败',
        status: 'error'
      });
    }
  });

  return (
    <Modal isOpen={true} onClose={onClose}>
      <ModalOverlay />
      <ModalContent maxW={'min(900px, 90vw)'} position={'relative'}>
        <ModalHeader></ModalHeader>
        <ModalCloseButton />

        <ModalBody>
          <Flex
            flexDirection={'column'}
            p={2}
            h={'100%'}
            alignItems={'center'}
            justifyContent={'center'}
            fontSize={'sm'}
          >
            <Button isLoading={selecting} onClick={onOpen}>
            </Button>
            <Box mt={2}> {fileExtension} . </Box>
            <Box mt={2}>
              {fileText.length} {encode(fileText).length} tokens
            </Box>
            <Box
              h={'300px'}
              w={'100%'}
              overflow={'auto'}
              p={2}
              backgroundColor={'blackAlpha.50'}
              whiteSpace={'pre'}
              fontSize={'xs'}
            >
              {fileText}
            </Box>
          </Flex>
        </ModalBody>

        <Flex px={6} pt={2} pb={4}>
          <Box flex={1}></Box>
          <Button variant={'outline'} mr={3} onClick={onClose}>
          </Button>
          <Button isLoading={isLoading} isDisabled={fileText === ''} onClick={openConfirm(mutate)}>
          </Button>
        </Flex>
      </ModalContent>
      <ConfirmChild />
      <File onSelect={onSelectFile} />
    </Modal>
  );
};

export default SelectFileModal;

View File

@@ -1,12 +1,6 @@
import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react'; import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
import { import { getModelById, delModelById, putModelTrainingStatus, putModelById } from '@/api/model';
getModelById,
delModelById,
postTrainModel,
putModelTrainingStatus,
putModelById
} from '@/api/model';
import { getChatSiteId } from '@/api/chat'; import { getChatSiteId } from '@/api/chat';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { Card, Box, Flex, Button, Tag, Grid } from '@chakra-ui/react'; import { Card, Box, Flex, Button, Tag, Grid } from '@chakra-ui/react';
@@ -16,12 +10,11 @@ import { formatModelStatus, ModelStatusEnum, modelList, defaultModel } from '@/c
import { useGlobalStore } from '@/store/global'; import { useGlobalStore } from '@/store/global';
import { useScreen } from '@/hooks/useScreen'; import { useScreen } from '@/hooks/useScreen';
import ModelEditForm from './components/ModelEditForm'; import ModelEditForm from './components/ModelEditForm';
// import Icon from '@/components/Iconfont';
import { useQuery } from '@tanstack/react-query'; import { useQuery } from '@tanstack/react-query';
// import dynamic from 'next/dynamic'; // import dynamic from 'next/dynamic';
import ModelDataCard from './components/ModelDataCard'; import ModelDataCard from './components/ModelDataCard';
// const Training = dynamic(() => import('./components/Training')); const ModelDataCard = dynamic(() => import('./components/ModelDataCard'));
const ModelDetail = ({ modelId }: { modelId: string }) => { const ModelDetail = ({ modelId }: { modelId: string }) => {
const { toast } = useToast(); const { toast } = useToast();
@@ -29,16 +22,16 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
const { isPc, media } = useScreen(); const { isPc, media } = useScreen();
const { setLoading } = useGlobalStore(); const { setLoading } = useGlobalStore();
const SelectFileDom = useRef<HTMLInputElement>(null); // const SelectFileDom = useRef<HTMLInputElement>(null);
const [model, setModel] = useState<ModelSchema>(defaultModel); const [model, setModel] = useState<ModelSchema>(defaultModel);
const formHooks = useForm<ModelSchema>({ const formHooks = useForm<ModelSchema>({
defaultValues: model defaultValues: model
}); });
// const canTrain = useMemo(() => { const canTrain = useMemo(() => {
// const openai = modelList.find((item) => item.model === model?.service.modelName); const openai = modelList.find((item) => item.model === model?.service.modelName);
// return openai && openai.trainName; return !!(openai && openai.trainName);
// }, [model]); }, [model]);
/* 加载模型数据 */ /* 加载模型数据 */
const loadModel = useCallback(async () => { const loadModel = useCallback(async () => {
@@ -89,34 +82,34 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
}, [setLoading, model, router]); }, [setLoading, model, router]);
/* 上传数据集,触发微调 */ /* 上传数据集,触发微调 */
const startTraining = useCallback( // const startTraining = useCallback(
async (e: React.ChangeEvent<HTMLInputElement>) => { // async (e: React.ChangeEvent<HTMLInputElement>) => {
if (!modelId || !e.target.files || e.target.files?.length === 0) return; // if (!modelId || !e.target.files || e.target.files?.length === 0) return;
setLoading(true); // setLoading(true);
try { // try {
const file = e.target.files[0]; // const file = e.target.files[0];
const formData = new FormData(); // const formData = new FormData();
formData.append('file', file); // formData.append('file', file);
await postTrainModel(modelId, formData); // await postTrainModel(modelId, formData);
toast({ // toast({
title: '开始训练...', // title: '开始训练...',
status: 'success' // status: 'success'
}); // });
// 重新获取模型 // // 重新获取模型
loadModel(); // loadModel();
} catch (err: any) { // } catch (err: any) {
toast({ // toast({
title: err?.message || '上传文件失败', // title: err?.message || '上传文件失败',
status: 'error' // status: 'error'
}); // });
console.log('error->', err); // console.log('error->', err);
} // }
setLoading(false); // setLoading(false);
}, // },
[setLoading, loadModel, modelId, toast] // [setLoading, loadModel, modelId, toast]
); // );
/* 点击更新模型状态 */ /* 点击更新模型状态 */
const handleClickUpdateStatus = useCallback(async () => { const handleClickUpdateStatus = useCallback(async () => {
@@ -248,22 +241,34 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
)} )}
</Card> </Card>
<Grid mt={5} gridTemplateColumns={media('1fr 1fr', '1fr')} gridGap={5}> <Grid mt={5} gridTemplateColumns={media('1fr 1fr', '1fr')} gridGap={5}>
<ModelEditForm formHooks={formHooks} handleDelModel={handleDelModel} /> <ModelEditForm formHooks={formHooks} handleDelModel={handleDelModel} canTrain={canTrain} />
{/* {canTrain && ( {/* {canTrain && (
<Card p={4}> <Card p={4}>
<Training model={model} /> <Training model={model} />
</Card> </Card>
)} */} )} */}
<Card p={4} height={'500px'} gridColumnStart={1} gridColumnEnd={3}> {canTrain && model._id && (
{model._id && <ModelDataCard model={model} />} <Card
</Card> p={4}
height={'700px'}
{...media(
{
gridColumnStart: 1,
gridColumnEnd: 3
},
{}
)}
>
<ModelDataCard model={model} />
</Card>
)}
</Grid> </Grid>
{/* 文件选择 */} {/* 文件选择 */}
<Box position={'absolute'} w={0} h={0} overflow={'hidden'}> {/* <Box position={'absolute'} w={0} h={0} overflow={'hidden'}>
<input ref={SelectFileDom} type="file" accept=".jsonl" onChange={startTraining} /> <input ref={SelectFileDom} type="file" accept=".jsonl" onChange={startTraining} />
</Box> </Box> */}
</> </>
); );
}; };

View File

@@ -1,29 +1,26 @@
import { DataItem } from '@/service/mongo'; import { SplitData, ModelData } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent, getOpenApiKey } from '@/service/utils/tools'; import { httpsAgent, getOpenApiKey } from '@/service/utils/tools';
import type { ChatCompletionRequestMessage } from 'openai'; import type { ChatCompletionRequestMessage } from 'openai';
import { DataItemSchema } from '@/types/mongoSchema';
import { ChatModelNameEnum } from '@/constants/model'; import { ChatModelNameEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill'; import { pushSplitDataBill } from '@/service/events/pushBill';
import { generateVector } from './generateVector';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
export async function generateQA(next = false): Promise<any> { export async function generateQA(next = false): Promise<any> {
if (process.env.NODE_ENV === 'development') return;
if (global.generatingQA && !next) return; if (global.generatingQA && !next) return;
global.generatingQA = true; global.generatingQA = true;
const systemPrompt: ChatCompletionRequestMessage = { const systemPrompt: ChatCompletionRequestMessage = {
role: 'system', role: 'system',
content: `总结助手。我会向你发送一段长文本,请从中总结出5至15个问题和答案,答案请尽量详细,按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n` content: `总结助手。我会向你发送一段长文本,请从中总结出5至15个问题和答案,答案请尽量详细,按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n`
}; };
let dataItem: DataItemSchema | null = null;
try { try {
// 找出一个需要生成的 dataItem // 找出一个需要生成的 dataItem
dataItem = await DataItem.findOne({ const dataItem = await SplitData.findOne({
status: { $ne: 0 }, textList: { $exists: true, $ne: [] }
times: { $gt: 0 },
type: 'QA'
}); });
if (!dataItem) { if (!dataItem) {
@@ -32,10 +29,13 @@ export async function generateQA(next = false): Promise<any> {
return; return;
} }
// 更新状态为生成中 // 弹出文本
await DataItem.findByIdAndUpdate(dataItem._id, { await SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } });
status: 2
}); const text = dataItem.textList[dataItem.textList.length - 1];
if (!text) {
throw new Error('无文本');
}
// 获取 openapi Key // 获取 openapi Key
let userApiKey, systemKey; let userApiKey, systemKey;
@@ -44,10 +44,10 @@ export async function generateQA(next = false): Promise<any> {
userApiKey = key.userApiKey; userApiKey = key.userApiKey;
systemKey = key.systemKey; systemKey = key.systemKey;
} catch (error) { } catch (error) {
// 余额不够了, 把用户所有记录改成闲置 // 余额不够了, 清空该记录
await DataItem.updateMany({ await SplitData.findByIdAndUpdate(dataItem._id, {
userId: dataItem.userId, textList: [],
status: 0 errorText: '余额不足,生成数据集任务终止'
}); });
throw new Error('获取 openai key 失败'); throw new Error('获取 openai key 失败');
} }
@@ -59,84 +59,71 @@ export async function generateQA(next = false): Promise<any> {
// 获取 openai 请求实例 // 获取 openai 请求实例
const chatAPI = getOpenAIApi(userApiKey || systemKey); const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 请求 chatgpt 获取回答 // 请求 chatgpt 获取回答
const response = await Promise.allSettled( const response = await chatAPI
[0.2, 0.8].map( .createChatCompletion(
(temperature) => {
chatAPI model: ChatModelNameEnum.GPT35,
.createChatCompletion( temperature: 0.2,
{ n: 1,
model: ChatModelNameEnum.GPT35, messages: [
temperature: temperature, systemPrompt,
n: 1, {
messages: [ role: 'user',
systemPrompt, content: text
{ }
role: 'user', ]
content: dataItem?.text || ''
}
]
},
{
timeout: 120000,
httpsAgent
}
)
.then((res) => ({
rawContent: res?.data.choices[0].message?.content || '',
result: splitText(res?.data.choices[0].message?.content || '')
})) // 从 content 中提取 QA
)
);
// 过滤出成功的响应
const successResponse: {
rawContent: string;
result: { q: string; a: string }[];
}[] = response.filter((item) => item.status === 'fulfilled').map((item: any) => item.value);
const rawContents = successResponse.map((item) => item.rawContent);
const results = successResponse.map((item) => item.result).flat();
// 插入数据库,并修改状态
await DataItem.findByIdAndUpdate(dataItem._id, {
status: 0,
$push: {
rawResponse: {
$each: successResponse.map((item) => item.rawContent)
}, },
result: { {
$each: results timeout: 120000,
httpsAgent
} }
} )
}); .then((res) => ({
rawContent: res?.data.choices[0].message?.content || '',
result: splitText(res?.data.choices[0].message?.content || '')
})); // 从 content 中提取 QA
// 插入 modelData 表,生成向量
await ModelData.insertMany(
response.result.map((item) => ({
modelId: dataItem.modelId,
userId: dataItem.userId,
text: item.a,
q: [
{
id: nanoid(),
text: item.q
}
],
status: 1
}))
);
console.log( console.log(
'生成QA成功time:', '生成QA成功time:',
`${(Date.now() - startTime) / 1000}s`, `${(Date.now() - startTime) / 1000}s`,
'QA数量', 'QA数量',
results.length response.result.length
); );
// 计费 // 计费
pushSplitDataBill({ pushSplitDataBill({
isPay: !userApiKey && results.length > 0, isPay: !userApiKey && response.result.length > 0,
userId: dataItem.userId, userId: dataItem.userId,
type: 'QA', type: 'QA',
text: systemPrompt.content + dataItem.text + rawContents.join('') text: systemPrompt.content + text + response.rawContent
}); });
} catch (error: any) {
console.log('error: 生成QA错误', dataItem?._id);
console.log('response:', error?.response);
if (dataItem?._id) {
await DataItem.findByIdAndUpdate(dataItem._id, {
status: dataItem.times > 0 ? 1 : 0, // 还有重试次数则可以继续进行
$inc: {
// 剩余尝试次数-1
times: -1
}
});
}
}
generateQA(true); generateQA(true);
generateVector(true);
} catch (error: any) {
console.log(error);
console.log('生成QA错误:', error?.response);
setTimeout(() => {
generateQA(true);
}, 10000);
}
} }
/** /**

View File

@@ -0,0 +1,88 @@
import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent } from '@/service/utils/tools';
import { ModelData } from '../models/modelData';
import { connectRedis } from '../redis';
import { VecModelDataIndex } from '@/constants/redis';
/**
 * Background worker: generate embedding vectors for pending ModelData rows.
 *
 * Finds one ModelData record whose `status != 0` (non-zero = pending, 0 = done,
 * per the update at the end of this function), requests an OpenAI
 * `text-embedding-ada-002` embedding for each question in `dataItem.q`, stores
 * each vector into Redis as a JSON document keyed
 * `${VecModelDataIndex}:${dataId}:${i}`, then marks the record done.
 *
 * Re-entrancy is guarded by the `global.generatingVector` flag; `next = true`
 * bypasses the guard so the function can schedule itself recursively.
 *
 * @param next - pass `true` for self-scheduled continuation calls; external
 *               callers use the default `false`.
 * @returns resolves when this round finishes (follow-up rounds are scheduled
 *          via setTimeout, not awaited).
 */
export async function generateVector(next = false): Promise<any> {
  // Only one generation loop may run at a time.
  if (global.generatingVector && !next) return;
  global.generatingVector = true;

  try {
    const redis = await connectRedis();
    // Find one record that still needs vectors (status != 0).
    const dataItem = await ModelData.findOne({
      status: { $ne: 0 }
    });

    if (!dataItem) {
      console.log('没有需要生成 【向量】 的数据');
      // Nothing pending: release the guard and stop the loop.
      global.generatingVector = false;
      return;
    }

    // System-level OpenAI key (user keys are not used for embeddings here).
    const openAiKey = process.env.OPENAIKEY as string;

    // OpenAI client instance.
    const chatAPI = getOpenAIApi(openAiKey);

    const dataId = String(dataItem._id);

    // Generate one embedding per question and write each into Redis.
    // allSettled: a single failed question does not abort the others.
    const response = await Promise.allSettled(
      dataItem.q.map((item, i) =>
        chatAPI
          .createEmbedding(
            {
              model: 'text-embedding-ada-002',
              input: item.text
            },
            {
              timeout: 120000,
              httpsAgent
            }
          )
          // Fall back to an empty vector if the API returned no data.
          .then((res) => res?.data?.data?.[0]?.embedding || [])
          .then((vector) =>
            // RedisJSON: JSON.SET <key> $ <json>
            redis.sendCommand([
              'JSON.SET',
              `${VecModelDataIndex}:${dataId}:${i}`,
              '$',
              JSON.stringify({
                dataId,
                modelId: String(dataItem.modelId),
                vector
              })
            ])
          )
      )
    );

    // If every question failed, surface the raw results as the error.
    if (response.filter((item) => item.status === 'fulfilled').length === 0) {
      throw new Error(JSON.stringify(response));
    }

    // Mark the record as done (status 0).
    await ModelData.findByIdAndUpdate(dataItem._id, {
      status: 0
    });

    console.log(`生成向量成功: ${dataItem._id}`);

    // Schedule the next round after a short pause.
    setTimeout(() => {
      generateVector(true);
    }, 3000);
  } catch (error: any) {
    console.log(error);
    console.log('error: 生成向量错误', error?.response?.data);

    // On rate limiting, retry after one minute.
    // NOTE(review): other errors stop the loop until the next external call
    // — presumably intentional; confirm.
    if (error?.response?.statusText === 'Too Many Requests') {
      console.log('次数限制,1分钟后尝试');
      // Rate limited: retry in 60s.
      setTimeout(() => {
        generateVector(true);
      }, 60000);
    }
  }
}

View File

@@ -34,7 +34,7 @@ export const pushChatBill = async ({
// 计算价格 // 计算价格
const unitPrice = modelItem?.price || 5; const unitPrice = modelItem?.price || 5;
const price = unitPrice * tokens.length; const price = unitPrice * tokens.length;
console.log(`chat bill, price: ${formatPrice(price)}`); console.log(`chat bill, unit price: ${unitPrice}, price: ${formatPrice(price)}`);
try { try {
// 插入 Bill 记录 // 插入 Bill 记录

View File

@@ -0,0 +1,31 @@
/* 模型的知识库 */
import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { ModelSplitDataSchema as SplitDataType } from '@/types/mongoSchema';
// Mongoose schema for a text-splitting task: raw text pushed by a user is
// split into chunks (textList) that background workers consume (generateQA
// pops entries from textList until it is empty).
const SplitDataSchema = new Schema({
  // Owner of the task.
  userId: {
    type: Schema.Types.ObjectId,
    ref: 'user',
    required: true
  },
  // Model this dataset belongs to.
  modelId: {
    type: Schema.Types.ObjectId,
    ref: 'model',
    required: true
  },
  // Original full text as submitted.
  rawText: {
    type: String,
    required: true
  },
  // Remaining text chunks still to be processed; workers pop from this list.
  textList: {
    type: [String],
    default: []
  },
  // Human-readable failure reason (e.g. insufficient balance); empty on success.
  errorText: {
    type: String,
    default: ''
  }
});

// Reuse a previously compiled model in dev hot-reload, otherwise compile it.
export const SplitData: MongoModel<SplitDataType> =
  models['splitData'] || model('splitData', SplitDataSchema);

View File

@@ -1,6 +1,7 @@
import mongoose from 'mongoose'; import mongoose from 'mongoose';
import { generateQA } from './events/generateQA'; import { generateQA } from './events/generateQA';
import { generateAbstract } from './events/generateAbstract'; import { generateAbstract } from './events/generateAbstract';
import { generateVector } from './events/generateVector';
/** /**
* 连接 MongoDB 数据库 * 连接 MongoDB 数据库
@@ -27,7 +28,8 @@ export async function connectToDatabase(): Promise<void> {
} }
generateQA(); generateQA();
generateAbstract(); // generateAbstract();
generateVector();
} }
export * from './models/authCode'; export * from './models/authCode';
@@ -40,3 +42,4 @@ export * from './models/bill';
export * from './models/pay'; export * from './models/pay';
export * from './models/data'; export * from './models/data';
export * from './models/dataItem'; export * from './models/dataItem';
export * from './models/splitData';

View File

@@ -1,5 +1,4 @@
import { createClient, SchemaFieldTypes } from 'redis'; import { createClient } from 'redis';
import { ModelDataIndex } from '@/constants/redis';
import { customAlphabet } from 'nanoid'; import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 10); const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 10);
@@ -14,7 +13,7 @@ export const connectRedis = async () => {
try { try {
global.redisClient = createClient({ global.redisClient = createClient({
url: 'redis://default:121914yu@120.76.193.200:8100' url: process.env.REDIS_URL
}); });
global.redisClient.on('error', (err) => { global.redisClient.on('error', (err) => {
@@ -33,46 +32,6 @@ export const connectRedis = async () => {
// 0 - 测试库1 - 正式 // 0 - 测试库1 - 正式
await global.redisClient.select(0); await global.redisClient.select(0);
// 创建索引
try {
await global.redisClient.ft.create(
ModelDataIndex,
{
// '$.vector': SchemaFieldTypes.VECTOR,
'$.modelId': {
type: SchemaFieldTypes.TEXT,
AS: 'modelId'
},
'$.userId': {
type: SchemaFieldTypes.TEXT,
AS: 'userId'
},
'$.status': {
type: SchemaFieldTypes.NUMERIC,
AS: 'status'
}
},
{
ON: 'JSON',
PREFIX: 'model:data'
}
);
} catch (error) {
console.log('创建索引失败', error);
}
// await global.redisClient.json.set('fastgpt:modeldata:2', '$', {
// vector: [124, 214, 412, 4, 124, 1, 4, 1, 4, 3, 423],
// modelId: 'daf',
// userId: 'adfd',
// q: 'fasf',
// a: 'afasf',
// status: 0,
// createTime: new Date()
// });
// const value = await global.redisClient.json.get('fastgpt:modeldata:2');
// console.log(value);
return global.redisClient; return global.redisClient;
} catch (error) { } catch (error) {
console.log(error, '=='); console.log(error, '==');

View File

@@ -119,3 +119,21 @@ export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) =>
return systemPrompt ? [systemPrompt, ...res] : res; return systemPrompt ? [systemPrompt, ...res] : res;
}; };
/* system 内容截断 */
/**
 * Concatenate system-prompt fragments until the token budget is reached.
 *
 * Fragments are taken in order from the front of `prompts`, each followed by
 * a newline. Accumulation stops once the encoded token count of the joined
 * text reaches `maxTokens`.
 *
 * NOTE(review): the fragment that crosses the limit is still included, so the
 * result may exceed `maxTokens` by up to one fragment — presumably acceptable
 * for prompt truncation; confirm.
 *
 * @param prompts - ordered prompt fragments.
 * @param maxTokens - token budget for the joined text.
 * @returns the newline-joined prefix of `prompts` within (roughly) the budget.
 */
export const systemPromptFilter = (prompts: string[], maxTokens: number) => {
  let joined = '';

  for (const fragment of prompts) {
    joined += `${fragment}\n`;

    // Re-encode the accumulated text and stop once the budget is hit.
    if (encode(joined).length >= maxTokens) {
      break;
    }
  }

  return joined;
};

View File

@@ -6,6 +6,7 @@ declare global {
var redisClient: RedisClientType | null; var redisClient: RedisClientType | null;
var generatingQA: boolean; var generatingQA: boolean;
var generatingAbstract: boolean; var generatingAbstract: boolean;
var generatingVector: boolean;
var QRCode: any; var QRCode: any;
interface Window { interface Window {
['pdfjs-dist/build/pdf']: any; ['pdfjs-dist/build/pdf']: any;

View File

@@ -64,6 +64,15 @@ export interface ModelDataSchema {
status: ModelDataType; status: ModelDataType;
} }
// Shape of a 'splitData' document: a user-submitted text-splitting task whose
// chunks (textList) are consumed by background QA-generation workers.
export interface ModelSplitDataSchema {
  _id: string;
  // Owning user id.
  userId: string;
  // Model the generated dataset belongs to.
  modelId: string;
  // Original full text as submitted.
  rawText: string;
  // Failure reason if the task was aborted; empty string otherwise.
  errorText: string;
  // Remaining unprocessed text chunks; workers pop from this list.
  textList: string[];
}
export interface TrainingSchema { export interface TrainingSchema {
_id: string; _id: string;
serviceName: ServiceName; serviceName: ServiceName;

10
src/types/redis.d.ts vendored
View File

@@ -1,10 +1,6 @@
export interface RedisModelDataItemType { export interface RedisModelDataItemType {
id: string; id: string;
value: { vector: number[];
vector: number[]; dataId: string;
q: string; // 提问词 modelId: string;
a: string; // 原文
modelId: string;
userId: string;
};
} }

View File

@@ -124,3 +124,15 @@ export const readDocContent = (file: File) =>
reject('读取 doc 文件失败'); reject('读取 doc 文件失败');
}; };
}); });
/**
 * Serialize an embedding vector as a packed little-endian float32 buffer.
 *
 * Each number is narrowed to 32-bit float precision and written in
 * little-endian byte order via DataView, independent of host endianness.
 *
 * @param vector - embedding values.
 * @returns a Buffer of `4 * vector.length` bytes.
 */
export const vectorToBuffer = (vector: number[]) => {
  const floats = new Float32Array(vector);
  const raw = new ArrayBuffer(floats.length * Float32Array.BYTES_PER_ELEMENT);
  const view = new DataView(raw);

  floats.forEach((value, index) => {
    // true => little-endian, regardless of platform byte order.
    view.setFloat32(index * Float32Array.BYTES_PER_ELEMENT, value, true);
  });

  return Buffer.from(raw);
};