feat: kb data source

This commit is contained in:
archer
2023-05-30 23:26:29 +08:00
parent 176c5a4d79
commit 746b9af2de
18 changed files with 141 additions and 110 deletions

View File

@@ -8,8 +8,9 @@ CREATE TABLE IF NOT EXISTS modelData (
vector VECTOR(1536) NOT NULL,
user_id VARCHAR(50) NOT NULL,
kb_id VARCHAR(50) NOT NULL,
source VARCHAR(100),
q TEXT NOT NULL,
a TEXT NOT NULL
a TEXT NOT NULL,
);
-- 索引设置,按需取
-- CREATE INDEX IF NOT EXISTS modelData_userId_index ON modelData USING HASH (user_id);

View File

@@ -4,7 +4,6 @@ import type { ImageProps } from '@chakra-ui/react';
import { LOGO_ICON } from '@/constants/chat';
const Avatar = ({ w = '30px', ...props }: ImageProps) => {
console.log(props.src);
return (
<Image
fallbackSrc={LOGO_ICON}

View File

@@ -1,5 +1,4 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import type { KbDataItemType } from '@/types/plugin';
import { jsonRes } from '@/service/response';
import { connectToDatabase, TrainingData } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
@@ -9,9 +8,11 @@ import { TrainingModeEnum } from '@/constants/plugin';
import { startQueue } from '@/service/utils/tools';
import { PgClient } from '@/service/pg';
// One question/answer entry accepted in `Props.data`; `source` is an optional origin label
// (callers pass e.g. an uploaded file name).
// NOTE(review): the name reads like a typo for "DataItemType" — confirm before renaming,
// since it is referenced in several places in this file.
type DateItemType = { a: string; q: string; source?: string };
export type Props = {
kbId: string;
data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[];
data: DateItemType[];
mode: `${TrainingModeEnum}`;
prompt?: string;
};
@@ -63,10 +64,7 @@ export async function pushDataToKb({
// 过滤重复的 qa 内容
const set = new Set();
const filterData: {
a: string;
q: string;
}[] = [];
const filterData: DateItemType[] = [];
data.forEach((item) => {
const text = item.q + item.a;
@@ -79,11 +77,12 @@ export async function pushDataToKb({
// 数据库去重
const insertData = (
await Promise.allSettled(
filterData.map(async ({ q, a = '' }) => {
filterData.map(async ({ q, a = '', source }) => {
if (mode !== TrainingModeEnum.index) {
return Promise.resolve({
q,
a
a,
source
});
}
@@ -112,19 +111,21 @@ export async function pushDataToKb({
}
return Promise.resolve({
q,
a
a,
source
});
})
)
)
.filter((item) => item.status === 'fulfilled')
.map<{ q: string; a: string }>((item: any) => item.value);
.map<DateItemType>((item: any) => item.value);
// 插入记录
await TrainingData.insertMany(
insertData.map((item) => ({
q: item.q,
a: item.a,
source: item.source,
userId,
kbId,
mode,

View File

@@ -32,6 +32,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
await PgClient.update('modelData', {
where: [['id', dataId], 'AND', ['user_id', userId]],
values: [
{ key: 'source', value: '手动修改' },
{ key: 'a', value: a.replace(/'/g, '"') },
...(q
? [

View File

@@ -3,7 +3,7 @@ import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { PgClient } from '@/service/pg';
import type { PgKBDataItemType } from '@/types/pg';
import type { KbDataItemType } from '@/types/plugin';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -21,8 +21,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
const where: any = [['user_id', userId], 'AND', ['id', dataId]];
const searchRes = await PgClient.select<PgKBDataItemType>('modelData', {
fields: ['id', 'q', 'a'],
const searchRes = await PgClient.select<KbDataItemType>('modelData', {
fields: ['id', 'q', 'a', 'source'],
where,
limit: 1
});

View File

@@ -3,7 +3,7 @@ import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo';
import { authUser } from '@/service/utils/auth';
import { PgClient } from '@/service/pg';
import type { PgKBDataItemType } from '@/types/pg';
import type { KbDataItemType } from '@/types/plugin';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -31,11 +31,16 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
['user_id', userId],
'AND',
['kb_id', kbId],
...(searchText ? ['AND', `(q LIKE '%${searchText}%' OR a LIKE '%${searchText}%')`] : [])
...(searchText
? [
'AND',
`(q LIKE '%${searchText}%' OR a LIKE '%${searchText}%' OR source LIKE '%${searchText}%')`
]
: [])
];
const searchRes = await PgClient.select<PgKBDataItemType>('modelData', {
fields: ['id', 'q', 'a'],
const searchRes = await PgClient.select<KbDataItemType>('modelData', {
fields: ['id', 'q', 'a', 'source'],
where,
order: [{ field: 'id', mode: 'DESC' }],
limit: pageSize,

View File

@@ -1,4 +1,4 @@
import React, { useCallback, useState, useRef } from 'react';
import React, { useCallback, useState, useRef, useEffect } from 'react';
import {
Box,
TableContainer,
@@ -56,7 +56,7 @@ const DataCard = ({ kbId }: { kbId: string }) => {
const { toast } = useToast();
const {
data: modelDataList,
data: kbDataList,
isLoading,
Pagination,
total,
@@ -72,11 +72,6 @@ const DataCard = ({ kbId }: { kbId: string }) => {
defaultRequest: false
});
useQuery(['getKbData', kbId], () => {
getData(1);
return null;
});
const [editInputData, setEditInputData] = useState<InputDataType>();
const {
@@ -101,20 +96,14 @@ const DataCard = ({ kbId }: { kbId: string }) => {
);
const refetchData = useCallback(
(num = 1) => {
(num = pageNum) => {
getData(num);
refetch();
return null;
},
[getData, refetch]
[getData, pageNum, refetch]
);
// interval get data
useQuery(['refetchData'], () => refetchData(pageNum), {
refetchInterval: 5000,
enabled: qaListLen > 0 || vectorListLen > 0
});
// get all data and export csv
const { mutate: onclickExport, isLoading: isLoadingExport = false } = useMutation({
mutationFn: () => getExportDataList(kbId),
@@ -148,6 +137,17 @@ const DataCard = ({ kbId }: { kbId: string }) => {
}
});
// interval get data
useQuery(['refetchData'], () => refetchData(1), {
refetchInterval: 5000,
enabled: qaListLen > 0 || vectorListLen > 0
});
useQuery(['getKbData', kbId], () => {
setSearchText('');
getData(1);
return null;
});
return (
<Box position={'relative'}>
<Flex>
@@ -239,18 +239,22 @@ const DataCard = ({ kbId }: { kbId: string }) => {
</Tooltip>
</Th>
<Th></Th>
<Th></Th>
<Th></Th>
</Tr>
</Thead>
<Tbody>
{modelDataList.map((item) => (
<Tr key={item.id}>
{kbDataList.map((item) => (
<Tr key={item.id} fontSize={'sm'}>
<Td>
<Box {...tdStyles.current}>{item.q}</Box>
</Td>
<Td>
<Box {...tdStyles.current}>{item.a || '-'}</Box>
</Td>
<Td maxW={'15%'} whiteSpace={'pre-wrap'} userSelect={'all'}>
{item.source?.trim() || '-'}
</Td>
<Td>
<IconButton
mr={5}

View File

@@ -56,13 +56,14 @@ const InputDataModal = ({
try {
const { insertLen } = await postKbDataFromList({
kbId,
mode: TrainingModeEnum.index,
data: [
{
a: e.a,
q: e.q
q: e.q,
source: '手动录入'
}
],
mode: TrainingModeEnum.index
]
});
if (insertLen === 0) {

View File

@@ -37,6 +37,7 @@ const SelectJsonModal = ({
const { toast } = useToast();
const { File, onOpen } = useSelectFile({ fileType: '.csv', multiple: false });
const [fileData, setFileData] = useState<{ q: string; a: string }[]>([]);
const [fileName, setFileName] = useState('');
const [successData, setSuccessData] = useState(0);
const { openConfirm, ConfirmChild } = useConfirm({
content: '确认导入该数据集?'
@@ -46,6 +47,7 @@ const SelectJsonModal = ({
async (e: File[]) => {
const file = e[0];
setSelecting(true);
setFileName(file.name);
try {
const { header, data } = await readCsvContent(file);
if (header[0] !== 'question' || header[1] !== 'answer') {
@@ -75,11 +77,14 @@ const SelectJsonModal = ({
let success = 0;
// subsection import
const step = 50;
const step = 100;
for (let i = 0; i < fileData.length; i += step) {
const { insertLen } = await postKbDataFromList({
kbId,
data: fileData.slice(i, i + step),
data: fileData.slice(i, i + step).map((item) => ({
...item,
source: fileName
})),
mode: TrainingModeEnum.index
});
success += insertLen || 0;
@@ -129,13 +134,14 @@ const SelectJsonModal = ({
>
csv模板
</Box>
<Flex alignItems={'center'}>
<Box>
<Button isLoading={selecting} isDisabled={uploading} onClick={onOpen}>
csv
</Button>
<Box ml={4}> {fileData.length} 100</Box>
</Flex>
<Box mt={4}>
{fileName} {fileData.length} 100
</Box>
</Box>
</Box>
<Box flex={'3 0 0'} h={'100%'} overflow={'auto'} p={2} backgroundColor={'blackAlpha.50'}>
{fileData.slice(0, 100).map((item, index) => (

View File

@@ -1,4 +1,4 @@
import React, { useState, useCallback, useMemo } from 'react';
import React, { useState, useCallback } from 'react';
import {
Box,
Flex,
@@ -54,15 +54,17 @@ const SelectFileModal = ({
const [prompt, setPrompt] = useState('');
const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
const [mode, setMode] = useState<`${TrainingModeEnum}`>(TrainingModeEnum.index);
const [fileTextArr, setFileTextArr] = useState<string[]>(['']);
const [files, setFiles] = useState<{ filename: string; text: string }[]>([
{ filename: '文本1', text: '' }
]);
const [splitRes, setSplitRes] = useState<{
tokens: number;
chunks: string[];
chunks: { filename: string; value: string }[];
successChunks: number;
}>({
tokens: 0,
chunks: [],
successChunks: 0
successChunks: 0,
chunks: []
});
const { openConfirm, ConfirmChild } = useConfirm({
content: `确认导入该文件,需要一定时间进行拆解,该任务无法终止!如果余额不足,未完成的任务会被直接清除。一共 ${
@@ -78,21 +80,21 @@ const SelectFileModal = ({
files.forEach((file) => {
promise = promise.then(async () => {
const extension = file?.name?.split('.')?.pop()?.toLowerCase();
let text = '';
switch (extension) {
case 'txt':
case 'md':
text = await readTxtContent(file);
break;
case 'pdf':
text = await readPdfContent(file);
break;
case 'doc':
case 'docx':
text = await readDocContent(file);
break;
}
text && setFileTextArr((state) => [text].concat(state));
const text = await (async () => {
switch (extension) {
case 'txt':
case 'md':
return readTxtContent(file);
case 'pdf':
return readPdfContent(file);
case 'doc':
case 'docx':
return readDocContent(file);
}
return '';
})();
text && setFiles((state) => [{ filename: file.name, text }].concat(state));
return;
});
});
@@ -115,11 +117,13 @@ const SelectFileModal = ({
// subsection import
let success = 0;
const step = 50;
const step = 100;
for (let i = 0; i < splitRes.chunks.length; i += step) {
const { insertLen } = await postKbDataFromList({
kbId,
data: splitRes.chunks.slice(i, i + step).map((text) => ({ q: text, a: '' })),
data: splitRes.chunks
.slice(i, i + step)
.map((item) => ({ q: item.value, a: '', source: item.filename })),
prompt: `下面是"${prompt || '一段长文本'}"`,
mode
});
@@ -149,26 +153,32 @@ const SelectFileModal = ({
const onclickImport = useCallback(async () => {
setBtnLoading(true);
try {
let promise = Promise.resolve();
const splitRes = await Promise.all(
fileTextArr
.filter((item) => item)
.map((item) =>
splitText_token({
text: item,
...modeMap[mode]
})
)
);
const splitRes = files
.map((item) =>
splitText_token({
text: item.text,
...modeMap[mode]
})
)
.map((item, i) => ({
...item,
filename: files[i].filename
}))
.filter((item) => item.tokens > 0);
setSplitRes({
tokens: splitRes.reduce((sum, item) => sum + item.tokens, 0),
chunks: splitRes.map((item) => item.chunks).flat(),
chunks: splitRes
.map((item) =>
item.chunks.map((chunk) => ({
filename: item.filename,
value: chunk
}))
)
.flat(),
successChunks: 0
});
await promise;
openConfirm(mutate)();
} catch (error) {
toast({
@@ -177,7 +187,7 @@ const SelectFileModal = ({
});
}
setBtnLoading(false);
}, [fileTextArr, mode, mutate, openConfirm, toast]);
}, [files, mode, mutate, openConfirm, toast]);
return (
<Modal isOpen={true} onClose={onClose} isCentered>
@@ -204,7 +214,7 @@ const SelectFileModal = ({
>
<Box mt={2} px={5} maxW={['100%', '70%']} textAlign={'justify'} color={'blackAlpha.600'}>
{fileExtension} Gpt会自动对文本进行 QA
tokens{fileTextArr.length}
tokens{files.length}
</Box>
{/* 拆分模式 */}
<Flex w={'100%'} px={5} alignItems={'center'} mt={4}>
@@ -235,26 +245,26 @@ const SelectFileModal = ({
)}
{/* 文本内容 */}
<Box flex={'1 0 0'} px={5} h={0} w={'100%'} overflowY={'auto'} mt={4}>
{fileTextArr.slice(0, 100).map((item, i) => (
{files.slice(0, 100).map((item, i) => (
<Box key={i} mb={5}>
<Box mb={1}>{i + 1}</Box>
<Box mb={1}>{item.filename}</Box>
<Textarea
placeholder="文件内容,空内容会自动忽略"
maxLength={-1}
rows={10}
fontSize={'xs'}
whiteSpace={'pre-wrap'}
value={item}
value={item.text}
onChange={(e) => {
setFileTextArr([
...fileTextArr.slice(0, i),
e.target.value,
...fileTextArr.slice(i + 1)
setFiles([
...files.slice(0, i),
{ ...item, text: e.target.value },
...files.slice(i + 1)
]);
}}
onBlur={(e) => {
if (fileTextArr.length > 1 && e.target.value === '') {
setFileTextArr((state) => [...state.slice(0, i), ...state.slice(i + 1)]);
if (files.length > 1 && e.target.value === '') {
setFiles((state) => [...state.slice(0, i), ...state.slice(i + 1)]);
}
}}
/>
@@ -272,7 +282,7 @@ const SelectFileModal = ({
</Button>
<Button
isDisabled={uploading || btnLoading || fileTextArr[0] === ''}
isDisabled={uploading || btnLoading || files[0]?.text === ''}
onClick={onclickImport}
>
{uploading ? (

View File

@@ -61,7 +61,8 @@ export async function generateQA(): Promise<any> {
userId: 1,
kbId: 1,
prompt: 1,
q: 1
q: 1,
source: 1
});
// task preemption
@@ -137,7 +138,10 @@ A2:
// 创建 向量生成 队列
await pushDataToKb({
kbId,
data: responseList,
data: responseList.map((item) => ({
...item,
source: data.source
})),
userId,
mode: TrainingModeEnum.index
});

View File

@@ -57,7 +57,8 @@ export async function generateVector(): Promise<any> {
userId: 1,
kbId: 1,
q: 1,
a: 1
a: 1,
source: 1
});
// task preemption
@@ -91,6 +92,7 @@ export async function generateVector(): Promise<any> {
data: vectors.map((vector, i) => ({
q: dataItems[i].q,
a: dataItems[i].a,
source: data.source,
vector
}))
});

View File

@@ -38,8 +38,9 @@ const TrainingDataSchema = new Schema({
type: String,
default: ''
},
vectorList: {
type: Object
source: {
type: String,
default: ''
}
});

View File

@@ -172,12 +172,14 @@ export const insertKbItem = ({
vector: number[];
q: string;
a: string;
source?: string;
}[];
}) => {
return PgClient.insert('modelData', {
values: data.map((item) => [
{ key: 'user_id', value: userId },
{ key: 'kb_id', value: kbId },
{ key: 'source', value: item.source?.slice(0, 30)?.trim() || '' },
{ key: 'q', value: item.q.replace(/'/g, '"') },
{ key: 'a', value: item.a.replace(/'/g, '"') },
{ key: 'vector', value: `[${item.vector}]` }

View File

@@ -78,6 +78,7 @@ export interface TrainingDataSchema {
prompt: string;
q: string;
a: string;
source: string;
}
export interface ChatSchema {

7
src/types/pg.d.ts vendored
View File

@@ -1,7 +0,0 @@
/**
 * Row shape passed as the type argument to `PgClient.select` when reading
 * from the `modelData` table. Property names use the snake_case column
 * spelling — presumably mapped one-to-one to columns; verify against PgClient.
 */
export interface PgKBDataItemType {
// Row identifier (matched against `dataId` in the where clauses).
id: string;
// Contents of the `q` text column.
q: string;
// Contents of the `a` text column.
a: string;
// Owner of the row (`user_id` column).
user_id: string;
// Knowledge base the row belongs to (`kb_id` column).
kb_id: string;
}

View File

@@ -10,8 +10,7 @@ export interface KbDataItemType {
id: string;
q: string; // 提问词
a: string; // 原文
kbId: string;
userId: string;
source: string;
}
export type TextPluginRequestParams = {

View File

@@ -1,6 +1,7 @@
import mammoth from 'mammoth';
import Papa from 'papaparse';
import { getOpenAiEncMap } from './plugin/openai';
import { getErrText } from './tools';
/**
* 读取 txt 文件内容
@@ -145,7 +146,7 @@ export const fileDownload = ({
* slideLen - The size of the before and after Text
* maxLen > slideLen
*/
export const splitText_token = async ({
export const splitText_token = ({
text,
maxLen,
slideLen
@@ -184,8 +185,8 @@ export const splitText_token = async ({
chunks,
tokens
};
} catch (error) {
return Promise.reject(error);
} catch (err) {
throw new Error(getErrText(err));
}
};