feat: de-duplicate inserted data; perf: input queue

This commit is contained in:
archer
2023-05-28 20:13:19 +08:00
parent 7e99f905bc
commit 516618b0cd
12 changed files with 187 additions and 105 deletions

View File

@@ -2,7 +2,10 @@ import { GET, POST, PUT, DELETE } from '../request';
import type { KbItemType } from '@/types/plugin'; import type { KbItemType } from '@/types/plugin';
import { RequestPaging } from '@/types/index'; import { RequestPaging } from '@/types/index';
import { TrainingModeEnum } from '@/constants/plugin'; import { TrainingModeEnum } from '@/constants/plugin';
import { Props as PushDataProps } from '@/pages/api/openapi/kb/pushData'; import {
Props as PushDataProps,
Response as PushDateResponse
} from '@/pages/api/openapi/kb/pushData';
export type KbUpdateParams = { id: string; name: string; tags: string; avatar: string }; export type KbUpdateParams = { id: string; name: string; tags: string; avatar: string };
@@ -46,7 +49,8 @@ export const getKbDataItemById = (dataId: string) =>
/** /**
* 直接push数据 * 直接push数据
*/ */
export const postKbDataFromList = (data: PushDataProps) => POST(`/openapi/kb/pushData`, data); export const postKbDataFromList = (data: PushDataProps) =>
POST<PushDateResponse>(`/openapi/kb/pushData`, data);
/** /**
* 更新一条数据 * 更新一条数据

View File

@@ -16,6 +16,10 @@ export type Props = {
prompt?: string; prompt?: string;
}; };
/** Result payload returned by `pushDataToKb` and sent back by the pushData endpoint. */
export type Response = {
// Number of records that survived de-duplication and were inserted as training data.
insertLen: number;
};
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
const { kbId, data, mode, prompt } = req.body as Props; const { kbId, data, mode, prompt } = req.body as Props;
@@ -28,7 +32,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
// 凭证校验 // 凭证校验
const { userId } = await authUser({ req }); const { userId } = await authUser({ req });
jsonRes(res, { jsonRes<Response>(res, {
data: await pushDataToKb({ data: await pushDataToKb({
kbId, kbId,
data, data,
@@ -51,16 +55,12 @@ export async function pushDataToKb({
data, data,
mode, mode,
prompt prompt
}: { userId: string } & Props) { }: { userId: string } & Props): Promise<Response> {
await authKb({ await authKb({
userId, userId,
kbId kbId
}); });
if (data.length === 0) {
return {};
}
// 过滤重复的 qa 内容 // 过滤重复的 qa 内容
const set = new Set(); const set = new Set();
const filterData: { const filterData: {
@@ -75,41 +75,54 @@ export async function pushDataToKb({
set.add(text); set.add(text);
} }
}); });
// 数据库去重 // 数据库去重
// const searchRes = await Promise.allSettled( const insertData = (
// data.map(async ({ q, a = '' }) => { await Promise.allSettled(
// if (!q) { filterData.map(async ({ q, a = '' }) => {
// return Promise.reject('q为空'); if (mode !== TrainingModeEnum.index) {
// } return Promise.resolve({
q,
a
});
}
// q = q.replace(/\\n/g, '\n'); if (!q) {
// a = a.replace(/\\n/g, '\n'); return Promise.reject('q为空');
}
// // Exactly the same data, not push q = q.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
// try { a = a.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
// const count = await PgClient.count('modelData', {
// where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]]
// });
// if (count > 0) { // Exactly the same data, not push
// return Promise.reject('已经存在'); try {
// } const { rows } = await PgClient.query(`
// } catch (error) { SELECT COUNT(*) > 0 AS exists
// error; FROM modelData
// } WHERE md5(q)=md5('${q}') AND md5(a)=md5('${a}') AND user_id='${userId}' AND kb_id='${kbId}'
// return Promise.resolve({ `);
// q, const exists = rows[0]?.exists || false;
// a
// }); if (exists) {
// }) return Promise.reject('已经存在');
// ); }
// const filterData = searchRes } catch (error) {
// .filter((item) => item.status === 'fulfilled') console.log(error);
// .map<{ q: string; a: string }>((item: any) => item.value); error;
}
return Promise.resolve({
q,
a
});
})
)
)
.filter((item) => item.status === 'fulfilled')
.map<{ q: string; a: string }>((item: any) => item.value);
// 插入记录 // 插入记录
await TrainingData.insertMany( await TrainingData.insertMany(
data.map((item) => ({ insertData.map((item) => ({
q: item.q, q: item.q,
a: item.a, a: item.a,
userId, userId,
@@ -119,9 +132,11 @@ export async function pushDataToKb({
})) }))
); );
startQueue(); insertData.length > 0 && startQueue();
return {}; return {
insertLen: insertData.length
};
} }
export const config = { export const config = {

View File

@@ -32,10 +32,10 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
await PgClient.update('modelData', { await PgClient.update('modelData', {
where: [['id', dataId], 'AND', ['user_id', userId]], where: [['id', dataId], 'AND', ['user_id', userId]],
values: [ values: [
{ key: 'a', value: a }, { key: 'a', value: a.replace(/'/g, '"') },
...(q ...(q
? [ ? [
{ key: 'q', value: q }, { key: 'q', value: q.replace(/'/g, '"') },
{ key: 'vector', value: `[${vector[0]}]` } { key: 'vector', value: `[${vector[0]}]` }
] ]
: []) : [])

View File

@@ -54,7 +54,7 @@ const InputDataModal = ({
setLoading(true); setLoading(true);
try { try {
const res = await postKbDataFromList({ const { insertLen } = await postKbDataFromList({
kbId, kbId,
data: [ data: [
{ {
@@ -65,14 +65,22 @@ const InputDataModal = ({
mode: TrainingModeEnum.index mode: TrainingModeEnum.index
}); });
if (insertLen === 0) {
toast({ toast({
title: res === 0 ? '可能已存在完全一致的数据' : '导入数据成功,需要一段时间训练', title: '已存在完全一致的数据',
status: 'warning'
});
} else {
toast({
title: '导入数据成功,需要一段时间训练',
status: 'success' status: 'success'
}); });
reset({ reset({
a: '', a: '',
q: '' q: ''
}); });
}
onSuccess(); onSuccess();
} catch (err: any) { } catch (err: any) {
toast({ toast({

View File

@@ -37,6 +37,7 @@ const SelectJsonModal = ({
const { toast } = useToast(); const { toast } = useToast();
const { File, onOpen } = useSelectFile({ fileType: '.csv', multiple: false }); const { File, onOpen } = useSelectFile({ fileType: '.csv', multiple: false });
const [fileData, setFileData] = useState<{ q: string; a: string }[]>([]); const [fileData, setFileData] = useState<{ q: string; a: string }[]>([]);
const [successData, setSuccessData] = useState(0);
const { openConfirm, ConfirmChild } = useConfirm({ const { openConfirm, ConfirmChild } = useConfirm({
content: '确认导入该数据集?' content: '确认导入该数据集?'
}); });
@@ -67,27 +68,35 @@ const SelectJsonModal = ({
[setSelecting, toast] [setSelecting, toast]
); );
const { mutate, isLoading } = useMutation({ const { mutate, isLoading: uploading } = useMutation({
mutationFn: async () => { mutationFn: async () => {
if (!fileData || fileData.length === 0) return; if (!fileData || fileData.length === 0) return;
const res = await postKbDataFromList({ let success = 0;
// subsection import
const step = 50;
for (let i = 0; i < fileData.length; i += step) {
const { insertLen } = await postKbDataFromList({
kbId, kbId,
data: fileData, data: fileData.slice(i, i + step),
mode: TrainingModeEnum.index mode: TrainingModeEnum.index
}); });
success += insertLen || 0;
setSuccessData((state) => state + step);
}
toast({ toast({
title: `导入数据成功,最终导入: ${res || 0} 条数据。需要一段时间训练`, title: `导入数据成功,最终导入: ${success} 条数据。需要一段时间训练`,
status: 'success', status: 'success',
duration: 4000 duration: 4000
}); });
onClose(); onClose();
onSuccess(); onSuccess();
}, },
onError() { onError(err) {
toast({ toast({
title: '导入文件失败', title: getErrText(err, '导入文件失败'),
status: 'error' status: 'error'
}); });
} }
@@ -121,15 +130,15 @@ const SelectJsonModal = ({
csv模板 csv模板
</Box> </Box>
<Flex alignItems={'center'}> <Flex alignItems={'center'}>
<Button isLoading={selecting} onClick={onOpen}> <Button isLoading={selecting} isDisabled={uploading} onClick={onOpen}>
csv csv
</Button> </Button>
<Box ml={4}> {fileData.length} </Box> <Box ml={4}> {fileData.length} 100</Box>
</Flex> </Flex>
</Box> </Box>
<Box flex={'3 0 0'} h={'100%'} overflow={'auto'} p={2} backgroundColor={'blackAlpha.50'}> <Box flex={'3 0 0'} h={'100%'} overflow={'auto'} p={2} backgroundColor={'blackAlpha.50'}>
{fileData.map((item, index) => ( {fileData.slice(0, 100).map((item, index) => (
<Box key={index}> <Box key={index}>
<Box> <Box>
Q{index + 1}. {item.q} Q{index + 1}. {item.q}
@@ -144,15 +153,15 @@ const SelectJsonModal = ({
<Flex px={6} pt={2} pb={4}> <Flex px={6} pt={2} pb={4}>
<Box flex={1}></Box> <Box flex={1}></Box>
<Button variant={'outline'} mr={3} onClick={onClose}> <Button variant={'outline'} isLoading={uploading} mr={3} onClick={onClose}>
</Button> </Button>
<Button <Button isDisabled={fileData.length === 0 || uploading} onClick={openConfirm(mutate)}>
isLoading={isLoading} {uploading ? (
isDisabled={fileData.length === 0} <Box>{Math.round((successData / fileData.length) * 100)}%</Box>
onClick={openConfirm(mutate)} ) : (
> '确认导入'
)}
</Button> </Button>
</Flex> </Flex>
</ModalContent> </ModalContent>

View File

@@ -55,9 +55,14 @@ const SelectFileModal = ({
const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true }); const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
const [mode, setMode] = useState<`${TrainingModeEnum}`>(TrainingModeEnum.index); const [mode, setMode] = useState<`${TrainingModeEnum}`>(TrainingModeEnum.index);
const [fileTextArr, setFileTextArr] = useState<string[]>(['']); const [fileTextArr, setFileTextArr] = useState<string[]>(['']);
const [splitRes, setSplitRes] = useState<{ tokens: number; chunks: string[] }>({ const [splitRes, setSplitRes] = useState<{
tokens: number;
chunks: string[];
successChunks: number;
}>({
tokens: 0, tokens: 0,
chunks: [] chunks: [],
successChunks: 0
}); });
const { openConfirm, ConfirmChild } = useConfirm({ const { openConfirm, ConfirmChild } = useConfirm({
content: `确认导入该文件,需要一定时间进行拆解,该任务无法终止!如果余额不足,未完成的任务会被直接清除。一共 ${ content: `确认导入该文件,需要一定时间进行拆解,该任务无法终止!如果余额不足,未完成的任务会被直接清除。一共 ${
@@ -104,19 +109,30 @@ const SelectFileModal = ({
[toast] [toast]
); );
const { mutate, isLoading } = useMutation({ const { mutate, isLoading: uploading } = useMutation({
mutationFn: async () => { mutationFn: async () => {
if (splitRes.chunks.length === 0) return; if (splitRes.chunks.length === 0) return;
await postKbDataFromList({ // subsection import
let success = 0;
const step = 50;
for (let i = 0; i < splitRes.chunks.length; i += step) {
const { insertLen } = await postKbDataFromList({
kbId, kbId,
data: splitRes.chunks.map((text) => ({ q: text, a: '' })), data: splitRes.chunks.slice(i, i + step).map((text) => ({ q: text, a: '' })),
prompt: `下面是"${prompt || '一段长文本'}"`, prompt: `下面是"${prompt || '一段长文本'}"`,
mode mode
}); });
success += insertLen;
setSplitRes((state) => ({
...state,
successChunks: state.successChunks + step
}));
}
toast({ toast({
title: '导入数据成功,需要一段拆解和训练. 重复数据会自动删除', title: `去重后共导入 ${success} 条数据,需要一段拆解和训练.`,
status: 'success' status: 'success'
}); });
onClose(); onClose();
@@ -148,7 +164,8 @@ const SelectFileModal = ({
setSplitRes({ setSplitRes({
tokens: splitRes.reduce((sum, item) => sum + item.tokens, 0), tokens: splitRes.reduce((sum, item) => sum + item.tokens, 0),
chunks: splitRes.map((item) => item.chunks).flat() chunks: splitRes.map((item) => item.chunks).flat(),
successChunks: 0
}); });
await promise; await promise;
@@ -235,6 +252,11 @@ const SelectFileModal = ({
...fileTextArr.slice(i + 1) ...fileTextArr.slice(i + 1)
]); ]);
}} }}
onBlur={(e) => {
if (fileTextArr.length > 1 && e.target.value === '') {
setFileTextArr((state) => [...state.slice(0, i), ...state.slice(i + 1)]);
}
}}
/> />
</Box> </Box>
))} ))}
@@ -242,19 +264,22 @@ const SelectFileModal = ({
</ModalBody> </ModalBody>
<Flex px={6} pt={2} pb={4}> <Flex px={6} pt={2} pb={4}>
<Button isLoading={btnLoading} onClick={onOpen}> <Button isLoading={btnLoading} isDisabled={uploading} onClick={onOpen}>
</Button> </Button>
<Box flex={1}></Box> <Box flex={1}></Box>
<Button variant={'outline'} colorScheme={'gray'} mr={3} onClick={onClose}> <Button variant={'outline'} isLoading={uploading} mr={3} onClick={onClose}>
</Button> </Button>
<Button <Button
isLoading={isLoading || btnLoading} isDisabled={uploading || btnLoading || fileTextArr[0] === ''}
isDisabled={isLoading || btnLoading || fileTextArr[0] === ''}
onClick={onclickImport} onClick={onclickImport}
> >
{uploading ? (
<Box>{Math.round((splitRes.successChunks / splitRes.chunks.length) * 100)}%</Box>
) : (
'确认导入'
)}
</Button> </Button>
</Flex> </Flex>
</ModalContent> </ModalContent>

View File

@@ -24,10 +24,10 @@ export const openaiError: Record<string, string> = {
'Bad Request': 'Bad Request~ 可能内容太多了', 'Bad Request': 'Bad Request~ 可能内容太多了',
'Bad Gateway': '网关异常,请重试' 'Bad Gateway': '网关异常,请重试'
}; };
export const openaiError2: Record<string, string> = { export const openaiAccountError: Record<string, string> = {
insufficient_quota: 'API 余额不足', // insufficient_quota: 'API 余额不足',
billing_not_active: 'openai 账号异常', invalid_api_key: 'openai 账号异常'
invalid_request_error: '无效的 openai 请求' // invalid_request_error: '无效的 openai 请求'
}; };
export const proxyError: Record<string, boolean> = { export const proxyError: Record<string, boolean> = {
ECONNABORTED: true, ECONNABORTED: true,

View File

@@ -2,7 +2,7 @@ import { TrainingData } from '@/service/mongo';
import { getApiKey } from '../utils/auth'; import { getApiKey } from '../utils/auth';
import { OpenAiChatEnum } from '@/constants/model'; import { OpenAiChatEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill'; import { pushSplitDataBill } from '@/service/events/pushBill';
import { openaiError2 } from '../errorCode'; import { openaiAccountError } from '../errorCode';
import { modelServiceToolMap } from '../utils/chat'; import { modelServiceToolMap } from '../utils/chat';
import { ChatRoleEnum } from '@/constants/chat'; import { ChatRoleEnum } from '@/constants/chat';
import { BillTypeEnum } from '@/constants/user'; import { BillTypeEnum } from '@/constants/user';
@@ -81,8 +81,6 @@ export async function generateQA(): Promise<any> {
type: 'training' type: 'training'
}); });
console.log(`正在生成一组QA。ID: ${trainingId}`);
const startTime = Date.now(); const startTime = Date.now();
// 请求 chatgpt 获取回答 // 请求 chatgpt 获取回答
@@ -137,7 +135,7 @@ A2:
const responseList = response.map((item) => item.result).flat(); const responseList = response.map((item) => item.result).flat();
// 创建 向量生成 队列 // 创建 向量生成 队列
pushDataToKb({ await pushDataToKb({
kbId, kbId,
data: responseList, data: responseList,
userId, userId,
@@ -161,8 +159,16 @@ A2:
console.log('生成QA错误:', err); console.log('生成QA错误:', err);
} }
// openai 账号异常或者账号余额不足,删除任务 // message error or openai account error
if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) { if (
err?.message === 'invalid message format' ||
openaiAccountError[err?.response?.data?.error?.code]
) {
await TrainingData.findByIdAndRemove(trainingId);
}
// 账号余额不足,删除任务
if (err === ERROR_ENUM.insufficientQuota) {
console.log('余额不足,删除向量生成任务'); console.log('余额不足,删除向量生成任务');
await TrainingData.deleteMany({ await TrainingData.deleteMany({
userId userId

View File

@@ -1,4 +1,4 @@
import { openaiError2 } from '../errorCode'; import { openaiAccountError } from '../errorCode';
import { insertKbItem } from '@/service/pg'; import { insertKbItem } from '@/service/pg';
import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding'; import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding';
import { TrainingData } from '../models/trainingData'; import { TrainingData } from '../models/trainingData';
@@ -111,8 +111,17 @@ export async function generateVector(): Promise<any> {
console.log('生成向量错误:', err); console.log('生成向量错误:', err);
} }
// openai 账号异常或者账号余额不足,删除任务 // message error or openai account error
if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) { if (
err?.message === 'invalid message format' ||
openaiAccountError[err?.response?.data?.error?.code]
) {
console.log('删除一个任务');
await TrainingData.findByIdAndRemove(trainingId);
}
// 账号余额不足,删除任务
if (err === ERROR_ENUM.insufficientQuota) {
console.log('余额不足,删除向量生成任务'); console.log('余额不足,删除向量生成任务');
await TrainingData.deleteMany({ await TrainingData.deleteMany({
userId userId

View File

@@ -134,9 +134,9 @@ export const pushGenerateVectorBill = async ({
text: string; text: string;
tokenLen: number; tokenLen: number;
}) => { }) => {
console.log( // console.log(
`vector generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}` // `vector generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}`
); // );
if (!isPay) return; if (!isPay) return;
let billId; let billId;

View File

@@ -177,8 +177,8 @@ export const insertKbItem = ({
values: data.map((item) => [ values: data.map((item) => [
{ key: 'user_id', value: userId }, { key: 'user_id', value: userId },
{ key: 'kb_id', value: kbId }, { key: 'kb_id', value: kbId },
{ key: 'q', value: item.q }, { key: 'q', value: item.q.replace(/'/g, '"') },
{ key: 'a', value: item.a }, { key: 'a', value: item.a.replace(/'/g, '"') },
{ key: 'vector', value: `[${item.vector}]` } { key: 'vector', value: `[${item.vector}]` }
]) ])
}); });

View File

@@ -1,5 +1,11 @@
import { NextApiResponse } from 'next'; import { NextApiResponse } from 'next';
import { openaiError, openaiError2, proxyError, ERROR_RESPONSE, ERROR_ENUM } from './errorCode'; import {
openaiError,
openaiAccountError,
proxyError,
ERROR_RESPONSE,
ERROR_ENUM
} from './errorCode';
import { clearCookie } from './utils/tools'; import { clearCookie } from './utils/tools';
export interface ResponseType<T = any> { export interface ResponseType<T = any> {
@@ -40,8 +46,8 @@ export const jsonRes = <T = any>(
msg = '接口连接异常'; msg = '接口连接异常';
} else if (error?.response?.data?.error?.message) { } else if (error?.response?.data?.error?.message) {
msg = error?.response?.data?.error?.message; msg = error?.response?.data?.error?.message;
} else if (openaiError2[error?.response?.data?.error?.type]) { } else if (openaiAccountError[error?.response?.data?.error?.code]) {
msg = openaiError2[error?.response?.data?.error?.type]; msg = openaiAccountError[error?.response?.data?.error?.code];
} else if (openaiError[error?.response?.statusText]) { } else if (openaiError[error?.response?.statusText]) {
msg = openaiError[error.response.statusText]; msg = openaiError[error.response.statusText];
} }