feat: 摘要拆分

This commit is contained in:
archer
2023-03-26 22:09:59 +08:00
parent 888642f154
commit 3e4487ad9a
20 changed files with 397 additions and 83 deletions

View File

@@ -3,12 +3,13 @@ import { RequestPaging } from '../types/index';
import { Obj2Query } from '@/utils/tools';
import type { DataListItem } from '@/types/data';
import type { PagingData } from '../types/index';
import { DataItemSchema } from '@/types/mongoSchema';
import type { DataItemSchema } from '@/types/mongoSchema';
import type { CreateDataProps } from '@/pages/data/components/CreateDataModal';
export const getDataList = (data: RequestPaging) =>
GET<PagingData<DataListItem>>(`/data/getDataList?${Obj2Query(data)}`);
export const postData = (name: string) => POST<string>(`/data/postData?name=${name}`);
export const postData = (data: CreateDataProps) => POST<string>(`/data/postData`, data);
export const postSplitData = (dataId: string, text: string) =>
POST(`/data/splitData`, { dataId, text });

6
src/constants/data.ts Normal file
View File

@@ -0,0 +1,6 @@
import type { DataType } from '@/types/data';
export const DataTypeTextMap: Record<DataType, string> = {
QA: '问答拆分',
abstract: '摘要总结'
};

View File

@@ -1,6 +1,8 @@
export enum BillTypeEnum {
chat = 'chat',
splitData = 'splitData',
QA = 'QA',
abstract = 'abstract',
return = 'return'
}
export enum PageTypeEnum {
@@ -11,6 +13,8 @@ export enum PageTypeEnum {
export const BillTypeMap: Record<`${BillTypeEnum}`, string> = {
[BillTypeEnum.chat]: '对话',
[BillTypeEnum.splitData]: '文本拆分',
[BillTypeEnum.splitData]: 'QA拆分',
[BillTypeEnum.QA]: 'QA拆分',
[BillTypeEnum.abstract]: '摘要总结',
[BillTypeEnum.return]: '退款'
};

View File

@@ -2,11 +2,12 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, Data } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import type { DataType } from '@/types/data';
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
let { name } = req.query as { name: string };
if (!name) {
let { name, type } = req.body as { name: string; type: DataType };
if (!name || !type) {
throw new Error('参数错误');
}
await connectToDatabase();
@@ -18,7 +19,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 生成 data 集合
const data = await Data.create({
userId,
name
name,
type
});
jsonRes(res, {

View File

@@ -1,9 +1,11 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, Data, DataItem } from '@/service/mongo';
import { connectToDatabase, DataItem, Data } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { generateQA } from '@/service/events/generateQA';
import { generateAbstract } from '@/service/events/generateAbstract';
/* 拆分数据成QA */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
let { text, dataId } = req.body as { text: string; dataId: string };
@@ -17,14 +19,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const userId = await authToken(authorization);
const DataRecord = await Data.findById(dataId);
if (!DataRecord) {
throw new Error('找不到数据集');
}
const dataItems: any[] = [];
// 格式化文本长度
// 每 1000 字符一组
for (let i = 0; i <= text.length / 1000; i++) {
dataItems.push({
temperature: 0,
userId,
dataId,
type: DataRecord.type,
text: text.slice(i * 1000, (i + 1) * 1000),
status: 1
});
@@ -33,10 +41,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 批量插入数据
await DataItem.insertMany(dataItems);
generateQA();
try {
generateQA();
generateAbstract();
} catch (error) {
error;
}
jsonRes(res, {
data: dataItems.length
data: ''
});
} catch (err) {
jsonRes(res, {

View File

@@ -13,14 +13,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// await DataItem.updateMany(
// {},
// {
// times: 2
// type: 'QA'
// // times: 2
// }
// );
await Data.updateMany(
{},
{
isDeleted: false
type: 'QA'
}
);

View File

@@ -8,10 +8,21 @@ import {
ModalBody,
ModalCloseButton,
Button,
Input
Input,
Select,
FormControl,
FormErrorMessage
} from '@chakra-ui/react';
import { postData } from '@/api/data';
import { useMutation } from '@tanstack/react-query';
import { useForm, SubmitHandler } from 'react-hook-form';
import { DataType } from '@/types/data';
import { DataTypeTextMap } from '@/constants/data';
export interface CreateDataProps {
name: string;
type: DataType;
}
const CreateDataModal = ({
onClose,
@@ -21,9 +32,20 @@ const CreateDataModal = ({
onSuccess: () => void;
}) => {
const [inputVal, setInputVal] = useState('');
const {
getValues,
register,
handleSubmit,
formState: { errors }
} = useForm<CreateDataProps>({
defaultValues: {
name: '',
type: 'abstract'
}
});
const { isLoading, mutate } = useMutation({
mutationFn: (name: string) => postData(name),
mutationFn: (e: CreateDataProps) => postData(e),
onSuccess() {
onSuccess();
onClose();
@@ -37,23 +59,33 @@ const CreateDataModal = ({
<ModalHeader></ModalHeader>
<ModalCloseButton />
<ModalBody display={'flex'}>
<Input
value={inputVal}
onChange={(e) => setInputVal(e.target.value)}
placeholder={'数据集名称'}
></Input>
<ModalBody>
<FormControl mb={8} isInvalid={!!errors.name}>
<Input
placeholder="数据集名称"
{...register('name', {
required: '数据集名称不能为空'
})}
/>
<FormErrorMessage position={'absolute'} fontSize="xs">
{!!errors.name && errors.name.message}
</FormErrorMessage>
</FormControl>
<FormControl>
<Select placeholder="数据集类型" {...register('type', {})}>
{Object.entries(DataTypeTextMap).map(([key, value]) => (
<option key={key} value={key}>
{value}
</option>
))}
</Select>
</FormControl>
</ModalBody>
<ModalFooter>
<Button colorScheme={'gray'} onClick={onClose}>
</Button>
<Button
ml={3}
isDisabled={inputVal === ''}
isLoading={isLoading}
onClick={() => mutate(inputVal)}
>
<Button ml={3} isLoading={isLoading} onClick={handleSubmit(mutate as any)}>
</Button>
</ModalFooter>

View File

@@ -22,6 +22,7 @@ import { useToast } from '@/hooks/useToast';
import { useLoading } from '@/hooks/useLoading';
import { formatPrice } from '@/utils/user';
import { modelList, ChatModelNameEnum } from '@/constants/model';
import { encode, decode } from 'gpt-token-utils';
const fileExtension = '.txt,.doc,.docx,.pdf,.md';
@@ -106,6 +107,7 @@ const ImportDataModal = ({
.join('\n')
.replace(/\n+/g, '\n');
setFileText(fileTexts);
console.log(encode(fileTexts));
} catch (error: any) {
console.log(error);
toast({
@@ -161,7 +163,9 @@ const ImportDataModal = ({
placeholder={'请粘贴或输入需要处理的文本'}
onChange={(e) => setTextInput(e.target.value)}
/>
<Box mt={2}> {textInput.length} </Box>
<Box mt={2}>
{textInput.length} {encode(textInput).length} tokens
</Box>
</>
)}
{activeTab === 'doc' && (
@@ -174,12 +178,15 @@ const ImportDataModal = ({
border={'1px solid '}
borderColor={'blackAlpha.200'}
borderRadius={'md'}
fontSize={'sm'}
>
<Button onClick={onOpen}></Button>
<Box mt={2}> {fileExtension} </Box>
{fileText && (
<>
<Box mt={2}> {fileText.length} </Box>
<Box mt={2}>
{fileText.length} {encode(fileText).length} tokens
</Box>
<Box
maxH={'300px'}
w={'100%'}

View File

@@ -22,7 +22,7 @@ const DataDetail = ({ dataName, dataId }: { dataName: string; dataId: string })
return (
<Card py={4} h={'100%'} display={'flex'} flexDirection={'column'}>
<Box px={6} fontSize={'xl'} fontWeight={'bold'}>
{dataName}
{dataName}
</Box>
<ScrollData
flex={'1 0 0'}
@@ -38,8 +38,13 @@ const DataDetail = ({ dataName, dataId }: { dataName: string; dataId: string })
<Box key={item._id}>
{item.result.map((result, i) => (
<Box key={i} mb={3}>
<Box fontWeight={'bold'}>Q: {result.q}</Box>
<Box>A: {result.a}</Box>
{item.type === 'QA' && (
<>
<Box fontWeight={'bold'}>Q: {result.q}</Box>
<Box>A: {result.a}</Box>
</>
)}
{item.type === 'abstract' && <Box fontSize={'sm'}>{result.abstract}</Box>}
</Box>
))}
</Box>

View File

@@ -28,13 +28,14 @@ import { useRouter } from 'next/router';
import { useConfirm } from '@/hooks/useConfirm';
import { useRequest } from '@/hooks/useRequest';
import { DataItemSchema } from '@/types/mongoSchema';
import { DataTypeTextMap } from '@/constants/data';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('.,', 1);
const CreateDataModal = dynamic(() => import('./components/CreateDataModal'));
const ImportDataModal = dynamic(() => import('./components/ImportDataModal'));
export type ExportDataType = 'jsonl';
export type ExportDataType = 'jsonl' | 'txt';
const DataList = () => {
const router = useRouter();
@@ -84,21 +85,26 @@ const DataList = () => {
let text = '';
// 生成 jsonl
data.forEach((item) => {
const result = JSON.stringify({
prompt: `${item.q.toLocaleLowerCase()}${nanoid()}</s>`,
completion: ` ${item.a}###`
});
text += `${result}\n`;
if (res.type === 'jsonl' && item.q && item.a) {
const result = JSON.stringify({
prompt: `${item.q.toLocaleLowerCase()}${nanoid()}</s>`,
completion: ` ${item.a}###`
});
text += `${result}\n`;
} else if (res.type === 'txt' && item.abstract) {
text += `${item.abstract}\n`;
}
});
// 去掉最后一个 \n
text = text.substring(0, text.length - 1);
// 导出为文件
const blob = new Blob([text], { type: 'application/json;charset=utf-8' });
// 创建下载链接
const downloadLink = document.createElement('a');
downloadLink.href = window.URL.createObjectURL(blob);
downloadLink.download = 'file.jsonl';
downloadLink.download = `data.${res.type}`;
// 添加链接到页面并触发下载
document.body.appendChild(downloadLink);
@@ -138,6 +144,7 @@ const DataList = () => {
<Thead>
<Tr>
<Th></Th>
<Th></Th>
<Th></Th>
<Th> / </Th>
<Th></Th>
@@ -158,6 +165,7 @@ const DataList = () => {
}}
/>
</Td>
<Td>{DataTypeTextMap[item.type || 'QA']}</Td>
<Td>{dayjs(item.createTime).format('YYYY/MM/DD HH:mm')}</Td>
<Td>
{item.trainingData} / {item.totalData}
@@ -187,9 +195,18 @@ const DataList = () => {
</MenuButton>
<MenuList>
<MenuItem onClick={() => handleExportData({ data: item, type: 'jsonl' })}>
jsonl
</MenuItem>
{item.type === 'QA' && (
<MenuItem
onClick={() => handleExportData({ data: item, type: 'jsonl' })}
>
jsonl
</MenuItem>
)}
{item.type === 'abstract' && (
<MenuItem onClick={() => handleExportData({ data: item, type: 'txt' })}>
txt
</MenuItem>
)}
</MenuList>
</Menu>

View File

@@ -97,7 +97,7 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
<Box mb={1}></Box>
<Textarea
rows={6}
maxLength={500}
maxLength={-1}
{...register('systemPrompt')}
placeholder={
'模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n\n注意改功能会影响对话的整体朝向'

View File

@@ -0,0 +1,177 @@
import { DataItem } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent, getOpenApiKey } from '@/service/utils/tools';
import type { ChatCompletionRequestMessage } from 'openai';
import { DataItemSchema } from '@/types/mongoSchema';
import { ChatModelNameEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill';
export async function generateAbstract(next = false): Promise<any> {
if (global.generatingAbstract && !next) return;
global.generatingAbstract = true;
const systemPrompt: ChatCompletionRequestMessage = {
role: 'system',
content: `我会向你发送一段长文本请从中总结出3~10个摘要尽量详细请按以下格式返回: "(1):"\n"(2):"\n"(3):"\n`
};
let dataItem: DataItemSchema | null = null;
try {
// 找出一个需要生成的 dataItem
dataItem = await DataItem.findOne({
status: { $ne: 0 },
times: { $gt: 0 },
type: 'abstract'
});
if (!dataItem) {
console.log('没有需要生成 【摘要】 的数据');
global.generatingAbstract = false;
return;
}
// 更新状态为生成中
await DataItem.findByIdAndUpdate(dataItem._id, {
status: 2
});
// 获取 openapi Key
let userApiKey, systemKey;
try {
const key = await getOpenApiKey(dataItem.userId);
userApiKey = key.userApiKey;
systemKey = key.systemKey;
} catch (error) {
// 余额不够了, 把用户所有记录改成闲置
await DataItem.updateMany({
userId: dataItem.userId,
status: 0
});
throw new Error('获取 openai key 失败');
}
console.log('正在生成一组摘要, ID:', dataItem._id);
const startTime = Date.now();
// 获取 openai 请求实例
const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 请求 chatgpt 获取摘要
const abstractResponse = await Promise.allSettled(
[0.5, 1].map((temperature) =>
chatAPI.createChatCompletion(
{
model: ChatModelNameEnum.GPT35,
temperature: temperature,
n: 1,
messages: [
systemPrompt,
{
role: 'user',
content: dataItem?.text || ''
}
]
},
{
timeout: 120000,
httpsAgent
}
)
)
);
// 过滤出成功的响应
const successAbstracts = abstractResponse.filter((item) => item.status === 'fulfilled');
// 提取摘要内容
const rawContents: string[] = successAbstracts.map(
(item: any) => item?.value?.data.choices[0].message?.content || ''
);
// 从 content 中提取摘要内容
const splitContents = rawContents.map((content) => splitText(content)).flat();
// 生成词向量
const vectorResponse = await Promise.allSettled(
splitContents.map((item) =>
chatAPI.createEmbedding({
model: 'text-embedding-ada-002',
input: item.abstract
})
)
);
// 筛选成功的向量请求
const vectorSuccessResponse = vectorResponse
.map((item: any, i) => {
if (item.status !== 'fulfilled') return '';
return {
abstract: splitContents[i].abstract,
abstractVector: item?.value?.data?.data?.[0]?.embedding
};
})
.filter((item) => item);
// 插入数据库,并修改状态
await DataItem.findByIdAndUpdate(dataItem._id, {
status: 0,
$push: {
rawResponse: {
$each: rawContents
},
result: {
$each: vectorSuccessResponse
}
}
});
// 计费
!userApiKey &&
splitContents.length > 0 &&
pushSplitDataBill({
userId: dataItem.userId,
type: 'abstract',
text:
systemPrompt.content +
dataItem.text +
rawContents.join('') +
rawContents.join('').substring(0, Math.floor(dataItem.text.length / 10)) // 向量价格是gpt35的1/10
});
console.log(
'生成摘要成功time:',
`${(Date.now() - startTime) / 1000}s`,
'摘要数量:',
splitContents.length
);
} catch (error: any) {
console.log('error: 生成摘要错误', dataItem?._id);
console.log('response:', error);
if (dataItem?._id) {
await DataItem.findByIdAndUpdate(dataItem._id, {
status: dataItem.times > 0 ? 1 : 0, // 还有重试次数则可以继续进行
$inc: {
// 剩余尝试次数-1
times: -1
}
});
}
}
generateAbstract(true);
}
/**
* 检查文本是否按格式返回
*/
function splitText(text: string) {
const regex = /\(\d+\):(\s*)(.*)(\s*)/g;
const matches = text.matchAll(regex); // 获取所有匹配到的结果
const result = []; // 存储最终的结果
for (const match of matches) {
if (match[2]) {
result.push({
abstract: match[2] as string
});
}
}
return result;
}

View File

@@ -20,7 +20,8 @@ export async function generateQA(next = false): Promise<any> {
// 找出一个需要生成的 dataItem
dataItem = await DataItem.findOne({
status: { $ne: 0 },
times: { $gt: 0 }
times: { $gt: 0 },
type: 'QA'
});
if (!dataItem) {
@@ -49,62 +50,72 @@ export async function generateQA(next = false): Promise<any> {
throw new Error('获取 openai key 失败');
}
console.log('正在生成一QA, ID:', dataItem._id, 'temperature: ', dataItem.temperature / 100);
console.log('正在生成一QA, ID:', dataItem._id);
const startTime = Date.now();
// 获取 openai 请求实例
const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 请求 chatgpt 获取回答
const response = await chatAPI.createChatCompletion(
{
model: ChatModelNameEnum.GPT35,
temperature: dataItem.temperature / 100,
n: 1,
messages: [
systemPrompt,
const response = await Promise.allSettled(
[0, 0.5, 0.8].map((temperature) =>
chatAPI.createChatCompletion(
{
role: 'user',
content: dataItem.text
model: ChatModelNameEnum.GPT35,
temperature: temperature,
n: 1,
messages: [
systemPrompt,
{
role: 'user',
content: dataItem?.text || ''
}
]
},
{
timeout: 120000,
httpsAgent
}
]
},
{
timeout: 120000,
httpsAgent
}
)
)
);
// 过滤出成功的响应
const successResponse = response.filter((item) => item.status === 'fulfilled');
// 提取响应内容
const rawContents: string[] = successResponse.map(
(item: any) => item?.value?.data.choices[0].message?.content || ''
);
const content = response.data.choices[0].message?.content;
// 从 content 中提取 QA
const splitResponse = splitText(content || '');
const splitResponses = rawContents.map((content) => splitText(content)).flat();
// 插入数据库,并修改状态
await DataItem.findByIdAndUpdate(dataItem._id, {
status: dataItem.temperature >= 90 ? 0 : 1, // 需要生成 4 组内容。0,0.3,0.6,0.9
temperature: dataItem.temperature >= 90 ? dataItem.temperature : dataItem.temperature + 30,
status: 0,
$push: {
rawResponse: content,
rawResponse: {
$each: rawContents
},
result: {
$each: splitResponse
$each: splitResponses
}
}
});
// 计费
!userApiKey &&
splitResponse.length > 0 &&
splitResponses.length > 0 &&
pushSplitDataBill({
userId: dataItem.userId,
text: systemPrompt.content + dataItem.text + content
type: 'QA',
text: systemPrompt.content + dataItem.text + rawContents.join('')
});
console.log(
'生成QA成功time:',
`${(Date.now() - startTime) / 1000}s`,
'QA数量',
splitResponse.length
splitResponses.length
);
} catch (error: any) {
console.log('error: 生成QA错误', dataItem?._id);
console.log('response:', error?.response);
// 重置状态
if (dataItem?._id) {
await DataItem.findByIdAndUpdate(dataItem._id, {
status: dataItem.times > 0 ? 1 : 0, // 还有重试次数则可以继续进行

View File

@@ -2,6 +2,7 @@ import { connectToDatabase, Bill, User } from '../mongo';
import { modelList, ChatModelNameEnum } from '@/constants/model';
import { encode } from 'gpt-token-utils';
import { formatPrice } from '@/utils/user';
import type { DataType } from '@/types/data';
export const pushChatBill = async ({
modelName,
@@ -59,7 +60,15 @@ export const pushChatBill = async ({
}
};
export const pushSplitDataBill = async ({ userId, text }: { userId: string; text: string }) => {
export const pushSplitDataBill = async ({
userId,
text,
type
}: {
userId: string;
text: string;
type: DataType;
}) => {
await connectToDatabase();
let billId;
@@ -83,7 +92,7 @@ export const pushSplitDataBill = async ({ userId, text }: { userId: string; text
// 插入 Bill 记录
const res = await Bill.create({
userId,
type: 'splitData',
type,
modelName: ChatModelNameEnum.GPT35,
textLen: text.length,
tokenLen: tokens.length,

View File

@@ -1,5 +1,6 @@
import { Schema, model, models, Model } from 'mongoose';
import { DataItemSchema as Datatype } from '@/types/mongoSchema';
import { DataSchema as Datatype } from '@/types/mongoSchema';
import { DataTypeTextMap } from '@/constants/data';
const DataSchema = new Schema({
userId: {
@@ -15,6 +16,11 @@ const DataSchema = new Schema({
type: Date,
default: () => new Date()
},
type: {
type: String,
required: true,
enum: Object.keys(DataTypeTextMap)
},
isDeleted: {
type: Boolean,
default: false

View File

@@ -1,5 +1,6 @@
import type { DataItemSchema as DataItemType } from '@/types/mongoSchema';
import { Schema, model, models, Model } from 'mongoose';
import { DataTypeTextMap } from '@/constants/data';
const DataItemSchema = new Schema({
userId: {
@@ -12,19 +13,23 @@ const DataItemSchema = new Schema({
ref: 'data',
required: true
},
type: {
type: String,
required: true,
enum: Object.keys(DataTypeTextMap)
},
times: {
// 剩余重试次数
type: Number,
default: 3
},
text: {
// 文本内容
type: String,
required: true
},
temperature: {
type: Number,
required: true
},
rawResponse: {
// 原始拆分结果
type: [String],
default: []
},
@@ -33,11 +38,21 @@ const DataItemSchema = new Schema({
{
q: {
type: String,
required: true
default: ''
},
a: {
type: String,
required: true
default: ''
},
abstract: {
// 摘要
type: String,
default: ''
},
abstractVector: {
// 摘要对应的向量
type: [Number],
default: []
}
}
],

View File

@@ -1,5 +1,7 @@
import mongoose from 'mongoose';
import { generateQA } from './events/generateQA';
import { generateAbstract } from './events/generateAbstract';
/**
* 连接 MongoDB 数据库
*/
@@ -24,8 +26,8 @@ export async function connectToDatabase(): Promise<void> {
global.mongodb = null;
}
// 递归 QA 生成
generateQA();
generateAbstract();
}
export * from './models/authCode';

2
src/types/data.d.ts vendored
View File

@@ -1,5 +1,7 @@
import type { DataSchema } from './mongoSchema';
export type DataType = 'QA' | 'abstract';
export interface DataListItem extends DataSchema {
trainingData: number;
totalData: number;

View File

@@ -3,6 +3,7 @@ import type { Mongoose } from 'mongoose';
declare global {
var mongodb: Mongoose | string | null;
var generatingQA: boolean;
var generatingAbstract: boolean;
var QRCode: any;
interface Window {
['pdfjs-dist/build/pdf']: any;

View File

@@ -1,5 +1,6 @@
import type { ChatItemType } from './chat';
import { ModelStatusEnum, TrainingStatusEnum, ChatModelNameEnum } from '@/constants/model';
import type { DataType } from './data';
export type ServiceName = 'openai';
@@ -102,19 +103,21 @@ export interface DataSchema {
userId: string;
name: string;
createTime: string;
type: DataType;
}
export interface DataItemSchema {
_id: string;
userId: string;
dataId: string;
type: DataType;
times: number;
temperature: number;
text: string;
rawResponse: string[];
result: {
q: string;
a: string;
q?: string;
a?: string;
abstract?: string;
}[];
status: 0 | 1 | 2;
}