feat: use Tiktokenizer to count tokens
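The diff spans three files: the file-import modal component, the shared file utilities it imports as '@/utils/file', and '@/utils/tools', which gains the new counter. Previously token estimates came from gpt-token-utils' `encode(text).length`; now each chunk is wrapped as a ChatML message and counted with @dqbd/tiktoken via `countChatTokens`, so the estimate includes the chat-format markers that chat completions are billed for. Short illustrative sketches (not part of the commit) follow each file's hunks.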
@@ -14,7 +14,6 @@ import {
 } from '@chakra-ui/react';
 import { useToast } from '@/hooks/useToast';
 import { useSelectFile } from '@/hooks/useSelectFile';
-import { encode } from 'gpt-token-utils';
 import { useConfirm } from '@/hooks/useConfirm';
 import { readTxtContent, readPdfContent, readDocContent } from '@/utils/file';
 import { useMutation } from '@tanstack/react-query';
@@ -22,6 +21,7 @@ import { postModelDataSplitData } from '@/api/model';
 import { formatPrice } from '@/utils/user';
 import Radio from '@/components/Radio';
 import { splitText } from '@/utils/file';
+import { countChatTokens } from '@/utils/tools';
 
 const fileExtension = '.txt,.doc,.docx,.pdf,.md';
 
@@ -29,11 +29,11 @@ const modeMap = {
   qa: {
     maxLen: 2800,
     slideLen: 800,
-    price: 3,
+    price: 4,
     isPrompt: true
   },
   subsection: {
-    maxLen: 1000,
+    maxLen: 800,
     slideLen: 300,
     price: 0.4,
     isPrompt: false
@@ -55,19 +55,19 @@ const SelectFileModal = ({
   const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
   const [mode, setMode] = useState<'qa' | 'subsection'>('qa');
   const [fileTextArr, setFileTextArr] = useState<string[]>(['']);
+  const [splitRes, setSplitRes] = useState<{ tokens: number; chunks: string[] }>({
+    tokens: 0,
+    chunks: []
+  });
   const { openConfirm, ConfirmChild } = useConfirm({
-    content: '确认导入该文件,需要一定时间进行拆解,该任务无法终止!如果余额不足,任务讲被终止。'
+    content: `确认导入该文件,需要一定时间进行拆解,该任务无法终止!如果余额不足,未完成的任务会被直接清除。一共 ${
+      splitRes.chunks.length
+    } 组,大约 ${splitRes.tokens} 个tokens, 约 ${formatPrice(
+      splitRes.tokens * modeMap[mode].price
+    )} 元`
   });
 
-  const fileText = useMemo(() => {
-    const chunks = fileTextArr.map((item) =>
-      splitText({
-        text: item,
-        ...modeMap[mode]
-      })
-    );
-    return chunks.join('');
-  }, [fileTextArr, mode]);
+  const fileText = useMemo(() => fileTextArr.join(''), [fileTextArr]);
 
   const onSelectFile = useCallback(
     async (e: File[]) => {
@@ -106,18 +106,11 @@ const SelectFileModal = ({
 
   const { mutate, isLoading } = useMutation({
     mutationFn: async () => {
-      if (!fileText) return;
-      const chunks = fileTextArr
-        .map((item) =>
-          splitText({
-            text: item,
-            ...modeMap[mode]
-          })
-        )
-        .flat();
+      if (splitRes.chunks.length === 0) return;
 
       await postModelDataSplitData({
         modelId,
-        chunks,
+        chunks: splitRes.chunks,
         prompt: `下面是"${prompt || '一段长文本'}"`,
         mode
       });
@@ -136,6 +129,28 @@ const SelectFileModal = ({
     }
   });
 
+  const onclickImport = useCallback(() => {
+    const chunks = fileTextArr
+      .map((item) =>
+        splitText({
+          text: item,
+          ...modeMap[mode]
+        })
+      )
+      .flat();
+    // count tokens
+    const tokens = chunks.map((item) =>
+      countChatTokens({ messages: [{ role: 'system', content: item }] })
+    );
+
+    setSplitRes({
+      tokens: tokens.reduce((sum, item) => sum + item, 0),
+      chunks
+    });
+
+    openConfirm(mutate)();
+  }, [fileTextArr, mode, mutate, openConfirm]);
+
   return (
     <Modal isOpen={true} onClose={onClose} isCentered>
       <ModalOverlay />
@@ -152,10 +167,9 @@ const SelectFileModal = ({
           justifyContent={'center'}
           fontSize={'sm'}
         >
-          <Box mt={2} px={4} maxW={['100%']} textAlign={'justify'} color={'blackAlpha.600'}>
+          <Box mt={2} px={5} maxW={['100%', '70%']} textAlign={'justify'} color={'blackAlpha.600'}>
            支持 {fileExtension} 文件。模型会自动对文本进行 QA 拆分,需要较长训练时间,拆分需要消耗
-            tokens,账号余额不足时,未拆分的数据会被删除。当前一共 {encode(fileText).length}{' '}
-            个tokens,大约 {formatPrice(encode(fileText).length * modeMap[mode].price)}元
+            tokens,账号余额不足时,未拆分的数据会被删除。
           </Box>
           {/* 拆分模式 */}
           <Flex w={'100%'} px={5} alignItems={'center'} mt={4}>
@@ -217,7 +231,7 @@ const SelectFileModal = ({
           <Button variant={'outline'} colorScheme={'gray'} mr={3} onClick={onClose}>
             取消
           </Button>
-          <Button isLoading={isLoading} isDisabled={fileText === ''} onClick={openConfirm(mutate)}>
+          <Button isLoading={isLoading} isDisabled={fileText === ''} onClick={onclickImport}>
             确认导入
           </Button>
         </Flex>
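Two things changed in the modal beyond the encoder swap: splitting and token counting now run once in `onclickImport` when the user clicks the import button, instead of inside a `useMemo` that re-ran `splitText` (and tokenized everything) on every edit, and the cost estimate moved from the modal body into the confirm dialog, where it can use the exact chunks that will be submitted. A minimal sketch of that estimate path, reusing only names from the diff (the `subsection` numbers mirror the updated `modeMap`; treat it as illustrative, not the component code):

import { splitText } from '@/utils/file';
import { countChatTokens } from '@/utils/tools';

// mirrors the updated modeMap.subsection entry from this commit
const subsection = { maxLen: 800, slideLen: 300, price: 0.4 };

// split first, then count each chunk as a single system message,
// exactly as onclickImport does above
function estimateImport(fileTextArr: string[]) {
  const chunks = fileTextArr.map((text) => splitText({ text, ...subsection })).flat();
  const tokens = chunks
    .map((content) => countChatTokens({ messages: [{ role: 'system', content }] }))
    .reduce((sum, n) => sum + n, 0);
  // the confirm dialog renders formatPrice(tokens * subsection.price)
  return { chunks, tokens };
}

The next hunks are in the splitting utilities themselves ('@/utils/file', judging by the modal's `splitText` import):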
@@ -1,6 +1,6 @@
 import mammoth from 'mammoth';
 import Papa from 'papaparse';
-import { encode } from 'gpt-token-utils';
+import { countChatTokens } from './tools';
 
 /**
  * 读取 txt 文件内容
@@ -164,7 +164,7 @@ export const splitText = ({
   const chunks: { sum: number; arr: string[] }[] = [{ sum: 0, arr: [] }];
 
   for (let i = 0; i < textArr.length; i++) {
-    const tokenLen = encode(textArr[i]).length;
+    const tokenLen = countChatTokens({ messages: [{ role: 'system', content: textArr[i] }] });
     chunks[chunks.length - 1].sum += tokenLen;
     chunks[chunks.length - 1].arr.push(textArr[i]);
 
@@ -174,7 +174,7 @@ export const splitText = ({
       const chunk: { sum: number; arr: string[] } = { sum: 0, arr: [] };
       for (let j = chunks[chunks.length - 1].arr.length - 1; j >= 0; j--) {
         const chunkText = chunks[chunks.length - 1].arr[j];
-        const tokenLen = encode(chunkText).length;
+        const tokenLen = countChatTokens({ messages: [{ role: 'system', content: chunkText }] });
         chunk.sum += tokenLen;
         chunk.arr.unshift(chunkText);
 
@@ -185,7 +185,6 @@ export const splitText = ({
       chunks.push(chunk);
     }
   }
 
   const result = chunks.map((item) => item.arr.join(''));
   return result;
 };
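For context on the two `tokenLen` call sites: `splitText` packs sentences into chunks of at most `maxLen` tokens, and when a chunk overflows it seeds the next one by walking backwards and carrying roughly `slideLen` tokens of overlap (the exit condition of that backwards loop sits outside the hunks shown, so it is assumed here). A simplified, self-contained sketch of the sliding window, with the counter injected so either the old or the new encoder could back it — an illustration, not the repo's implementation:

// Simplified sketch of the chunking loop above; tokenLen would be
// countChatTokens after this commit, encode(...).length before it.
function splitWithOverlap(
  sentences: string[],
  tokenLen: (s: string) => number,
  maxLen: number, // hard token budget per chunk, e.g. 800 for subsection mode
  slideLen: number // overlap carried into the next chunk, e.g. 300
): string[] {
  const chunks: { sum: number; arr: string[] }[] = [{ sum: 0, arr: [] }];
  for (const sentence of sentences) {
    const last = chunks[chunks.length - 1];
    last.sum += tokenLen(sentence);
    last.arr.push(sentence);
    if (last.sum > maxLen) {
      // walk backwards from the end of the full chunk, collecting
      // ~slideLen tokens of overlap to seed the next chunk
      const next: { sum: number; arr: string[] } = { sum: 0, arr: [] };
      for (let j = last.arr.length - 1; j >= 0 && next.sum < slideLen; j--) {
        next.sum += tokenLen(last.arr[j]);
        next.arr.unshift(last.arr[j]);
      }
      chunks.push(next);
    }
  }
  return chunks.map((c) => c.arr.join(''));
}

The final file is '@/utils/tools', where `countChatTokens` itself is built: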
@@ -1,5 +1,27 @@
 import crypto from 'crypto';
 import { useToast } from '@/hooks/useToast';
+import { encoding_for_model, type Tiktoken } from '@dqbd/tiktoken';
+import Graphemer from 'graphemer';
 
+const textDecoder = new TextDecoder();
+const graphemer = new Graphemer();
+const encMap = {
+  'gpt-3.5-turbo': encoding_for_model('gpt-3.5-turbo', {
+    '<|im_start|>': 100264,
+    '<|im_end|>': 100265,
+    '<|im_sep|>': 100266
+  }),
+  'gpt-4': encoding_for_model('gpt-4', {
+    '<|im_start|>': 100264,
+    '<|im_end|>': 100265,
+    '<|im_sep|>': 100266
+  }),
+  'gpt-4-32k': encoding_for_model('gpt-4-32k', {
+    '<|im_start|>': 100264,
+    '<|im_end|>': 100265,
+    '<|im_sep|>': 100266
+  })
+};
+
 /**
  * copy text data
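One note on `encMap`: ChatML markers such as `<|im_start|>` are special tokens, which tiktoken refuses to encode unless they are explicitly allowed, so the commit registers them with fixed ids and later passes `'all'` to `encode`. A small standalone check, assuming the same @dqbd/tiktoken API used above:

import { encoding_for_model } from '@dqbd/tiktoken';

// register the ChatML markers so encode(text, 'all') can tokenize them
const enc = encoding_for_model('gpt-3.5-turbo', {
  '<|im_start|>': 100264,
  '<|im_end|>': 100265,
  '<|im_sep|>': 100266
});
const ids = enc.encode('<|im_start|>user\nHi<|im_end|>', 'all');
console.log(ids.length); // token count for one formatted message
enc.free(); // the encoder is wasm-backed and should be freed when done

The module-level encoders in the diff are created once and kept for the process lifetime, which is presumably why no `free()` appears there.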
@@ -51,3 +73,60 @@ export const Obj2Query = (obj: Record<string, string | number>) => {
   }
   return queryParams.toString();
 };
+
+/* 格式化 chat 聊天内容 */
+function getChatGPTEncodingText(
+  messages: { role: 'system' | 'user' | 'assistant'; content: string; name?: string }[],
+  model: 'gpt-3.5-turbo' | 'gpt-4' | 'gpt-4-32k'
+) {
+  const isGpt3 = model === 'gpt-3.5-turbo';
+
+  const msgSep = isGpt3 ? '\n' : '';
+  const roleSep = isGpt3 ? '\n' : '<|im_sep|>';
+
+  return [
+    messages
+      .map(({ name = '', role, content }) => {
+        return `<|im_start|>${name || role}${roleSep}${content}<|im_end|>`;
+      })
+      .join(msgSep),
+    `<|im_start|>assistant${roleSep}`
+  ].join(msgSep);
+}
+function text2TokensLen(encoder: Tiktoken, inputText: string) {
+  const encoding = encoder.encode(inputText, 'all');
+  const segments: { text: string; tokens: { id: number; idx: number }[] }[] = [];
+
+  let byteAcc: number[] = [];
+  let tokenAcc: { id: number; idx: number }[] = [];
+  let inputGraphemes = graphemer.splitGraphemes(inputText);
+
+  for (let idx = 0; idx < encoding.length; idx++) {
+    const token = encoding[idx]!;
+    byteAcc.push(...encoder.decode_single_token_bytes(token));
+    tokenAcc.push({ id: token, idx });
+
+    const segmentText = textDecoder.decode(new Uint8Array(byteAcc));
+    const graphemes = graphemer.splitGraphemes(segmentText);
+
+    if (graphemes.every((item, idx) => inputGraphemes[idx] === item)) {
+      segments.push({ text: segmentText, tokens: tokenAcc });
+
+      byteAcc = [];
+      tokenAcc = [];
+      inputGraphemes = inputGraphemes.slice(graphemes.length);
+    }
+  }
+
+  return segments.reduce((memo, i) => memo + i.tokens.length, 0) ?? 0;
+}
+export const countChatTokens = ({
+  model = 'gpt-3.5-turbo',
+  messages
+}: {
+  model?: 'gpt-4' | 'gpt-4-32k' | 'gpt-3.5-turbo';
+  messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
+}) => {
+  const text = getChatGPTEncodingText(messages, model);
+  return text2TokensLen(encMap[model], text);
+};
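To make the serialization concrete: gpt-3.5-turbo uses newline separators, while the gpt-4 models use `<|im_sep|>` and no message separator. Working through `getChatGPTEncodingText` for a single message `[{ role: 'user', content: 'Hi' }]`:

// gpt-3.5-turbo (msgSep = '\n', roleSep = '\n'):
//   <|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n
//
// gpt-4 / gpt-4-32k (msgSep = '', roleSep = '<|im_sep|>'):
//   <|im_start|>user<|im_sep|>Hi<|im_end|><|im_start|>assistant<|im_sep|>

The count is taken over this full string, so it includes the role markers and the trailing assistant primer, not just the message contents.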
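Finally, a usage sketch of the exported helper (the grapheme bookkeeping in `text2TokensLen` appears to be adapted from the Tiktokenizer visualizer that the commit title references; only the summed token count is used here):

import { countChatTokens } from '@/utils/tools';

// counts tokens over the fully formatted chat text shown above
const n = countChatTokens({
  model: 'gpt-3.5-turbo',
  messages: [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Hello!' }
  ]
});
console.log(n);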