feat: chat content use tiktoken count

This commit is contained in:
archer
2023-04-24 16:46:39 +08:00
parent adbaa8b37b
commit 1f112f7715
23 changed files with 182 additions and 836 deletions

View File

@@ -29,7 +29,6 @@
"eventsource-parser": "^0.1.0",
"formidable": "^2.1.1",
"framer-motion": "^9.0.6",
"gpt-token-utils": "^1.2.0",
"graphemer": "^1.4.0",
"hyperdown": "^2.4.29",
"immer": "^9.0.19",

8
pnpm-lock.yaml generated
View File

@@ -33,7 +33,6 @@ specifiers:
eventsource-parser: ^0.1.0
formidable: ^2.1.1
framer-motion: ^9.0.6
gpt-token-utils: ^1.2.0
graphemer: ^1.4.0
husky: ^8.0.3
hyperdown: ^2.4.29
@@ -86,7 +85,6 @@ dependencies:
eventsource-parser: registry.npmmirror.com/eventsource-parser/0.1.0
formidable: registry.npmmirror.com/formidable/2.1.1
framer-motion: registry.npmmirror.com/framer-motion/9.0.6_biqbaboplfbrettd7655fr4n2y
gpt-token-utils: registry.npmmirror.com/gpt-token-utils/1.2.0
graphemer: registry.npmmirror.com/graphemer/1.4.0
hyperdown: registry.npmmirror.com/hyperdown/2.4.29
immer: registry.npmmirror.com/immer/9.0.19
@@ -7668,12 +7666,6 @@ packages:
get-intrinsic: registry.npmmirror.com/get-intrinsic/1.2.0
dev: true
registry.npmmirror.com/gpt-token-utils/1.2.0:
resolution: {integrity: sha512-s8twaU38UE2Vp65JhQEjz8qvWhWY8KZYvmvYHapxlPT03Ok35Clq+gm9eE27wQILdFisseMVRSiC5lJR9GBklA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/gpt-token-utils/-/gpt-token-utils-1.2.0.tgz}
name: gpt-token-utils
version: 1.2.0
dev: false
registry.npmmirror.com/graceful-fs/4.2.10:
resolution: {integrity: sha512-9ByhssR2fPVsNZj478qUUbKfmL0+t5BDVyjShtyZZLiK7ZDAArFFfopyOTj0M05wE2tJPisA4iTnnXl2YoPvOA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/graceful-fs/-/graceful-fs-4.2.10.tgz}
name: graceful-fs

View File

@@ -5,24 +5,28 @@ export enum ModelDataStatusEnum {
waiting = 'waiting'
}
export enum ChatModelNameEnum {
GPT35 = 'gpt-3.5-turbo',
VECTOR_GPT = 'VECTOR_GPT',
VECTOR = 'text-embedding-ada-002'
export const embeddingModel = 'text-embedding-ada-002';
export enum ChatModelEnum {
'GPT35' = 'gpt-3.5-turbo',
'GPT4' = 'gpt-4',
'GPT432k' = 'gpt-4-32k'
}
export const ChatModelNameMap = {
[ChatModelNameEnum.GPT35]: 'gpt-3.5-turbo',
[ChatModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo',
[ChatModelNameEnum.VECTOR]: 'text-embedding-ada-002'
export enum ModelNameEnum {
GPT35 = 'gpt-3.5-turbo',
VECTOR_GPT = 'VECTOR_GPT'
}
export const Model2ChatModelMap: Record<`${ModelNameEnum}`, `${ChatModelEnum}`> = {
[ModelNameEnum.GPT35]: 'gpt-3.5-turbo',
[ModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo'
};
export type ModelConstantsData = {
icon: 'model' | 'dbModel';
name: string;
model: `${ChatModelNameEnum}`;
model: `${ModelNameEnum}`;
trainName: string; // 空字符串代表不能训练
maxToken: number;
contextMaxToken: number;
maxTemperature: number;
price: number; // 多少钱 / 1token单位: 0.00001元
@@ -32,20 +36,18 @@ export const modelList: ModelConstantsData[] = [
{
icon: 'model',
name: 'chatGPT',
model: ChatModelNameEnum.GPT35,
model: ModelNameEnum.GPT35,
trainName: '',
maxToken: 4000,
contextMaxToken: 7000,
contextMaxToken: 4096,
maxTemperature: 1.5,
price: 3
},
{
icon: 'dbModel',
name: '知识库',
model: ChatModelNameEnum.VECTOR_GPT,
model: ModelNameEnum.VECTOR_GPT,
trainName: 'vector',
maxToken: 4000,
contextMaxToken: 7000,
contextMaxToken: 4096,
maxTemperature: 1,
price: 3
}
@@ -133,8 +135,8 @@ export const defaultModel: ModelSchema = {
},
service: {
trainId: '',
chatModel: ChatModelNameEnum.GPT35,
modelName: ChatModelNameEnum.GPT35
chatModel: ModelNameEnum.GPT35,
modelName: ModelNameEnum.GPT35
},
security: {
domain: ['*'],

View File

@@ -2,7 +2,6 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/auth';
import { httpsAgent, openaiChatFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
@@ -64,42 +63,23 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
}
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
// 格式化文本内容成 chatgpt 格式
const map = {
Human: ChatCompletionRequestMessageRoleEnum.User,
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
};
const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
(item: ChatItemType) => ({
role: map[item.obj],
content: item.value
})
);
const filterPrompts = openaiChatFilter({
model: model.service.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
// console.log({
// model: model.service.chatModel,
// temperature: temperature,
// // max_tokens: modelConstantsData.maxToken,
// messages: formatPrompts,
// frequency_penalty: 0.5, // 越大,重复内容越少
// presence_penalty: -0.5, // 越大,越容易出现新内容
// stream: true,
// stop: ['.!?。']
// });
// console.log(filterPrompts);
// 获取 chatAPI
const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature: temperature,
// max_tokens: modelConstantsData.maxToken,
messages: formatPrompts,
temperature,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: true,
@@ -121,7 +101,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
stream,
chatResponse
});
const promptsContent = formatPrompts.map((item) => item.content).join('');
// 只有使用平台的 key 才计费
pushChatBill({
@@ -129,7 +108,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
modelName: model.service.modelName,
userId,
chatId,
text: promptsContent + responseContent
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});
} catch (err: any) {
if (step === 1) {

View File

@@ -2,10 +2,8 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo';
import { authChat } from '@/service/utils/auth';
import { httpsAgent, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import type { ModelSchema } from '@/types/mongoSchema';
import { PassThrough } from 'stream';
import {
modelList,
@@ -105,9 +103,13 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
value: model.systemPrompt
});
} else {
// 有匹配情况下,添加知识库内容。
// 系统提示词过滤,最多 3000 tokens
const systemPrompt = systemPromptFilter(formatRedisPrompt, 3000);
// 有匹配情况下,system 添加知识库内容。
// 系统提示词过滤,最多 2500 tokens
const systemPrompt = systemPromptFilter({
model: model.service.chatModel,
prompts: formatRedisPrompt,
maxTokens: 2500
});
prompts.unshift({
obj: 'SYSTEM',
@@ -124,21 +126,13 @@ ${
}
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
const filterPrompts = openaiChatFilter({
model: model.service.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// 格式化文本内容成 chatgpt 格式
const map = {
Human: ChatCompletionRequestMessageRoleEnum.User,
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
};
const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
(item: ChatItemType) => ({
role: map[item.obj],
content: item.value
})
);
// console.log(formatPrompts);
// console.log(filterPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
@@ -146,9 +140,8 @@ ${
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature: temperature,
// max_tokens: modelConstantsData.maxToken,
messages: formatPrompts,
temperature,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: true
@@ -170,14 +163,13 @@ ${
chatResponse
});
const promptsContent = formatPrompts.map((item) => item.content).join('');
// 只有使用平台的 key 才计费
pushChatBill({
isPay: !userApiKey,
modelName: model.service.modelName,
userId,
chatId,
text: promptsContent + responseContent
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});
// jsonRes(res);
} catch (err: any) {

View File

@@ -4,7 +4,7 @@ import { connectToDatabase, DataItem, Data } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { generateQA } from '@/service/events/generateQA';
import { generateAbstract } from '@/service/events/generateAbstract';
import { encode } from 'gpt-token-utils';
import { countChatTokens } from '@/utils/tools';
/* 拆分数据成QA */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -34,7 +34,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
chunks.forEach((chunk) => {
splitText += chunk;
const tokens = encode(splitText).length;
const tokens = countChatTokens({ messages: [{ role: 'system', content: splitText }] });
if (tokens >= 780) {
dataItems.push({
userId,

View File

@@ -3,14 +3,14 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { ModelStatusEnum, modelList, ChatModelNameEnum, ChatModelNameMap } from '@/constants/model';
import { ModelStatusEnum, modelList, ModelNameEnum, Model2ChatModelMap } from '@/constants/model';
import { Model } from '@/service/models/model';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { name, serviceModelName } = req.body as {
name: string;
serviceModelName: `${ChatModelNameEnum}`;
serviceModelName: `${ModelNameEnum}`;
};
const { authorization } = req.headers;
@@ -48,7 +48,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
status: ModelStatusEnum.running,
service: {
trainId: '',
chatModel: ChatModelNameMap[modelItem.model], // 聊天时用的模型
chatModel: Model2ChatModelMap[modelItem.model], // 聊天时用的模型
modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作
}
});

View File

@@ -75,21 +75,13 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
}
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
const filterPrompts = openaiChatFilter({
model: model.service.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// 格式化文本内容成 chatgpt 格式
const map = {
Human: ChatCompletionRequestMessageRoleEnum.User,
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
};
const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
(item: ChatItemType) => ({
role: map[item.obj],
content: item.value
})
);
// console.log(formatPrompts);
// console.log(filterPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
@@ -99,9 +91,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature: temperature,
// max_tokens: modelConstantsData.maxToken,
messages: formatPrompts,
temperature,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: isStream,
@@ -133,14 +124,12 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
});
}
const promptsContent = formatPrompts.map((item) => item.content).join('');
// 只有使用平台的 key 才计费
pushChatBill({
isPay: true,
modelName: model.service.modelName,
userId,
text: promptsContent + responseContent
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});
} catch (err: any) {
if (step === 1) {

View File

@@ -3,15 +3,14 @@ import { connectToDatabase, Model } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/auth';
import { authOpenApiKey } from '@/service/utils/tools';
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import {
ChatModelNameEnum,
ModelNameEnum,
modelList,
ChatModelNameMap,
ModelVectorSearchModeMap
ModelVectorSearchModeMap,
ChatModelEnum
} from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
@@ -60,9 +59,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
throw new Error('找不到模型');
}
const modelConstantsData = modelList.find(
(item) => item.model === ChatModelNameEnum.VECTOR_GPT
);
const modelConstantsData = modelList.find((item) => item.model === ModelNameEnum.VECTOR_GPT);
if (!modelConstantsData) {
throw new Error('模型已下架');
}
@@ -74,7 +71,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 请求一次 chatgpt 拆解需求
const promptResponse = await chatAPI.createChatCompletion(
{
model: ChatModelNameMap[ChatModelNameEnum.GPT35],
model: ChatModelEnum.GPT35,
temperature: 0,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
@@ -122,7 +119,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
]
},
{
timeout: 120000,
timeout: 180000,
httpsAgent: httpsAgent(true)
}
);
@@ -163,30 +160,26 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
// textArr 筛选,最多 2500 tokens
const systemPrompt = systemPromptFilter(formatRedisPrompt, 2500);
// system 筛选,最多 2500 tokens
const systemPrompt = systemPromptFilter({
model: model.service.chatModel,
prompts: formatRedisPrompt,
maxTokens: 2500
});
prompts.unshift({
obj: 'SYSTEM',
value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:${systemPrompt}`
});
// 控制 tokens 数量,防止超出
const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
// 控制上下文 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.service.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// 格式化文本内容成 chatgpt 格式
const map = {
Human: ChatCompletionRequestMessageRoleEnum.User,
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
};
const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
(item: ChatItemType) => ({
role: map[item.obj],
content: item.value
})
);
// console.log(formatPrompts);
// console.log(filterPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
@@ -195,13 +188,13 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
{
model: model.service.chatModel,
temperature,
messages: formatPrompts,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: isStream
},
{
timeout: 120000,
timeout: 180000,
responseType: isStream ? 'stream' : 'json',
httpsAgent: httpsAgent(true)
}
@@ -228,13 +221,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
console.log('laf gpt done. time:', `${(Date.now() - startTime) / 1000}s`);
const promptsContent = formatPrompts.map((item) => item.content).join('');
pushChatBill({
isPay: true,
modelName: model.service.modelName,
userId,
text: promptsContent + responseContent
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});
} catch (err: any) {
if (step === 1) {

View File

@@ -126,8 +126,12 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
});
} else {
// 有匹配或者低匹配度模式情况下,添加知识库内容。
// 系统提示词过滤,最多 3000 tokens
const systemPrompt = systemPromptFilter(formatRedisPrompt, 3000);
// 系统提示词过滤,最多 2500 tokens
const systemPrompt = systemPromptFilter({
model: model.service.chatModel,
prompts: formatRedisPrompt,
maxTokens: 2500
});
prompts.unshift({
obj: 'SYSTEM',
@@ -144,21 +148,13 @@ ${
}
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
const filterPrompts = openaiChatFilter({
model: model.service.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// 格式化文本内容成 chatgpt 格式
const map = {
Human: ChatCompletionRequestMessageRoleEnum.User,
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
};
const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
(item: ChatItemType) => ({
role: map[item.obj],
content: item.value
})
);
// console.log(formatPrompts);
// console.log(filterPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
@@ -166,14 +162,14 @@ ${
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature: temperature,
messages: formatPrompts,
temperature,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: isStream
},
{
timeout: 120000,
timeout: 180000,
responseType: isStream ? 'stream' : 'json',
httpsAgent: httpsAgent(true)
}
@@ -198,12 +194,11 @@ ${
});
}
const promptsContent = formatPrompts.map((item) => item.content).join('');
pushChatBill({
isPay: true,
modelName: model.service.modelName,
userId,
text: promptsContent + responseContent
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});
// jsonRes(res);
} catch (err: any) {

View File

@@ -21,7 +21,7 @@ import {
import { useToast } from '@/hooks/useToast';
import { useScreen } from '@/hooks/useScreen';
import { useQuery } from '@tanstack/react-query';
import { ChatModelNameEnum } from '@/constants/model';
import { ModelNameEnum } from '@/constants/model';
import dynamic from 'next/dynamic';
import { useGlobalStore } from '@/store/global';
import { useCopyData } from '@/utils/tools';
@@ -178,8 +178,8 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
const gptChatPrompt = useCallback(
async (prompts: ChatSiteItemType) => {
const urlMap: Record<string, string> = {
[ChatModelNameEnum.GPT35]: '/api/chat/chatGpt',
[ChatModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt'
[ModelNameEnum.GPT35]: '/api/chat/chatGpt',
[ModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt'
};
if (!urlMap[chatData.modelName]) return Promise.reject('找不到模型');

View File

@@ -1,97 +0,0 @@
import React, { useState } from 'react';
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalFooter,
ModalBody,
ModalCloseButton,
Button,
Input,
Select,
FormControl,
FormErrorMessage
} from '@chakra-ui/react';
import { postData } from '@/api/data';
import { useMutation } from '@tanstack/react-query';
import { useForm, SubmitHandler } from 'react-hook-form';
import { DataType } from '@/types/data';
import { DataTypeTextMap } from '@/constants/data';
export interface CreateDataProps {
name: string;
type: DataType;
}
const CreateDataModal = ({
onClose,
onSuccess
}: {
onClose: () => void;
onSuccess: () => void;
}) => {
const [inputVal, setInputVal] = useState('');
const {
getValues,
register,
handleSubmit,
formState: { errors }
} = useForm<CreateDataProps>({
defaultValues: {
name: '',
type: 'abstract'
}
});
const { isLoading, mutate } = useMutation({
mutationFn: (e: CreateDataProps) => postData(e),
onSuccess() {
onSuccess();
onClose();
}
});
return (
<Modal isOpen={true} onClose={onClose}>
<ModalOverlay />
<ModalContent>
<ModalHeader></ModalHeader>
<ModalCloseButton />
<ModalBody>
<FormControl mb={8} isInvalid={!!errors.name}>
<Input
placeholder="数据集名称"
{...register('name', {
required: '数据集名称不能为空'
})}
/>
<FormErrorMessage position={'absolute'} fontSize="xs">
{!!errors.name && errors.name.message}
</FormErrorMessage>
</FormControl>
<FormControl>
<Select placeholder="数据集类型" {...register('type', {})}>
{Object.entries(DataTypeTextMap).map(([key, value]) => (
<option key={key} value={key}>
{value}
</option>
))}
</Select>
</FormControl>
</ModalBody>
<ModalFooter>
<Button colorScheme={'gray'} onClick={onClose}>
</Button>
<Button ml={3} isLoading={isLoading} onClick={handleSubmit(mutate as any)}>
</Button>
</ModalFooter>
</ModalContent>
</Modal>
);
};
export default CreateDataModal;

View File

@@ -1,229 +0,0 @@
import React, { useState, useCallback } from 'react';
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalFooter,
ModalBody,
ModalCloseButton,
Button,
Box,
Flex,
Textarea
} from '@chakra-ui/react';
import { useTabs } from '@/hooks/useTabs';
import { useConfirm } from '@/hooks/useConfirm';
import { useSelectFile } from '@/hooks/useSelectFile';
import { readTxtContent, readPdfContent, readDocContent } from '@/utils/file';
import { postSplitData } from '@/api/data';
import { useMutation } from '@tanstack/react-query';
import { useToast } from '@/hooks/useToast';
import { useLoading } from '@/hooks/useLoading';
import { formatPrice } from '@/utils/user';
import { modelList, ChatModelNameEnum } from '@/constants/model';
import { encode } from 'gpt-token-utils';
const fileExtension = '.txt,.doc,.docx,.pdf,.md';
const ImportDataModal = ({
dataId,
onClose,
onSuccess
}: {
dataId: string;
onClose: () => void;
onSuccess: () => void;
}) => {
const { openConfirm, ConfirmChild } = useConfirm({
content: '确认提交生成任务?该任务无法终止!'
});
const { toast } = useToast();
const { setIsLoading, Loading } = useLoading();
const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
const { tabs, activeTab, setActiveTab } = useTabs({
tabs: [
{ id: 'text', label: '文本' },
{ id: 'doc', label: '文件' }
// { id: 'url', label: '链接' }
]
});
const [textInput, setTextInput] = useState('');
const [fileText, setFileText] = useState('');
const { mutate: handleClickSubmit, isLoading } = useMutation({
mutationFn: async () => {
let text = '';
if (activeTab === 'text') {
text = textInput;
} else if (activeTab === 'doc') {
text = fileText;
} else if (activeTab === 'url') {
}
if (!text) return;
return postSplitData(dataId, text);
},
onSuccess() {
toast({
title: '任务提交成功',
status: 'success'
});
onClose();
onSuccess();
},
onError(err: any) {
toast({
title: err?.message || '提交任务异常',
status: 'error'
});
}
});
const onSelectFile = useCallback(
async (e: File[]) => {
setIsLoading(true);
try {
const fileTexts = (
await Promise.all(
e.map((file) => {
// @ts-ignore
const extension = file?.name?.split('.').pop().toLowerCase();
switch (extension) {
case 'txt':
case 'md':
return readTxtContent(file);
case 'pdf':
return readPdfContent(file);
case 'doc':
case 'docx':
return readDocContent(file);
default:
return '';
}
})
)
)
.join('\n')
.replace(/\n+/g, '\n');
setFileText(fileTexts);
console.log(encode(fileTexts));
} catch (error: any) {
console.log(error);
toast({
title: typeof error === 'string' ? error : '解析文件失败',
status: 'error'
});
}
setIsLoading(false);
},
[setIsLoading, toast]
);
return (
<Modal isOpen={true} onClose={onClose}>
<ModalOverlay />
<ModalContent position={'relative'} maxW={['90vw', '800px']}>
<ModalHeader>
QA
<Box ml={2} as={'span'} fontSize={'sm'} color={'blackAlpha.600'}>
{formatPrice(
modelList.find((item) => item.model === ChatModelNameEnum.GPT35)?.price || 0,
1000
)}
/1K tokens
</Box>
</ModalHeader>
<ModalCloseButton />
<ModalBody display={'flex'}>
<Box>
{tabs.map((item) => (
<Button
key={item.id}
display={'block'}
variant={activeTab === item.id ? 'solid' : 'outline'}
_notLast={{
mb: 3
}}
onClick={() => setActiveTab(item.id)}
>
{item.label}
</Button>
))}
</Box>
<Box flex={'1 0 0'} w={0} ml={3} minH={'200px'}>
{activeTab === 'text' && (
<>
<Textarea
h={'100%'}
maxLength={-1}
value={textInput}
placeholder={'请粘贴或输入需要处理的文本'}
onChange={(e) => setTextInput(e.target.value)}
/>
<Box mt={2}>
{textInput.length} {encode(textInput).length} tokens
</Box>
</>
)}
{activeTab === 'doc' && (
<Flex
flexDirection={'column'}
p={2}
h={'100%'}
alignItems={'center'}
justifyContent={'center'}
border={'1px solid '}
borderColor={'blackAlpha.200'}
borderRadius={'md'}
fontSize={'sm'}
>
<Button onClick={onOpen}></Button>
<Box mt={2}> {fileExtension} </Box>
{fileText && (
<>
<Box mt={2}>
{fileText.length} {encode(fileText).length} tokens
</Box>
<Box
maxH={'300px'}
w={'100%'}
overflow={'auto'}
p={2}
backgroundColor={'blackAlpha.50'}
whiteSpace={'pre'}
fontSize={'xs'}
>
{fileText}
</Box>
</>
)}
</Flex>
)}
</Box>
</ModalBody>
<ModalFooter>
<Button colorScheme={'gray'} onClick={onClose}>
</Button>
<Button
ml={3}
isLoading={isLoading}
isDisabled={!textInput && !fileText}
onClick={openConfirm(handleClickSubmit)}
>
</Button>
</ModalFooter>
<Loading />
</ModalContent>
<ConfirmChild />
<File onSelect={onSelectFile} />
</Modal>
);
};
export default ImportDataModal;

View File

@@ -1,67 +0,0 @@
import React from 'react';
import { Box, Card } from '@chakra-ui/react';
import ScrollData from '@/components/ScrollData';
import { getDataItems } from '@/api/data';
import { usePaging } from '@/hooks/usePaging';
import type { DataItemSchema } from '@/types/mongoSchema';
const DataDetail = ({ dataName, dataId }: { dataName: string; dataId: string }) => {
const {
nextPage,
isLoadAll,
requesting,
data: dataItems
} = usePaging<DataItemSchema>({
api: getDataItems,
pageSize: 10,
params: {
dataId
}
});
return (
<Card py={4} h={'100%'} display={'flex'} flexDirection={'column'}>
<Box px={6} fontSize={'xl'} fontWeight={'bold'}>
{dataName}
</Box>
<ScrollData
flex={'1 0 0'}
h={0}
px={6}
mt={3}
isLoadAll={isLoadAll}
requesting={requesting}
nextPage={nextPage}
fontSize={'xs'}
whiteSpace={'pre-wrap'}
>
{dataItems.map((item) => (
<Box key={item._id}>
{item.result.map((result, i) => (
<Box key={i} mb={3}>
{item.type === 'QA' && (
<>
<Box fontWeight={'bold'}>Q: {result.q}</Box>
<Box>A: {result.a}</Box>
</>
)}
{item.type === 'abstract' && <Box fontSize={'sm'}>{result.abstract}</Box>}
</Box>
))}
</Box>
))}
</ScrollData>
</Card>
);
};
export default DataDetail;
export async function getServerSideProps(context: any) {
return {
props: {
dataName: context.query?.dataName || '',
dataId: context.query?.dataId || ''
}
};
}

View File

@@ -1,235 +0,0 @@
import React, { useState, useCallback } from 'react';
import {
Card,
Box,
Flex,
Button,
Table,
Thead,
Tbody,
Tr,
Th,
Td,
TableContainer,
useDisclosure,
Input,
Menu,
MenuButton,
MenuList,
MenuItem
} from '@chakra-ui/react';
import { getDataList, updateDataName, delData, getDataItems } from '@/api/data';
import type { DataListItem } from '@/types/data';
import dayjs from 'dayjs';
import dynamic from 'next/dynamic';
import { useRouter } from 'next/router';
import { useConfirm } from '@/hooks/useConfirm';
import { useRequest } from '@/hooks/useRequest';
import { DataItemSchema } from '@/types/mongoSchema';
import { DataTypeTextMap } from '@/constants/data';
import { customAlphabet } from 'nanoid';
import { useQuery } from '@tanstack/react-query';
const nanoid = customAlphabet('.,', 1);
const CreateDataModal = dynamic(() => import('./components/CreateDataModal'));
const ImportDataModal = dynamic(() => import('./components/ImportDataModal'));
export type ExportDataType = 'jsonl' | 'txt';
const DataList = () => {
const router = useRouter();
const [ImportDataId, setImportDataId] = useState<string>();
const { openConfirm, ConfirmChild } = useConfirm({
content: '删除数据集,将删除里面的所有内容,请确认!'
});
const {
isOpen: isOpenCreateDataModal,
onOpen: onOpenCreateDataModal,
onClose: onCloseCreateDataModal
} = useDisclosure();
const { data: dataList = [], refetch } = useQuery(['getDataList'], getDataList, {
refetchInterval: 10000
});
const { mutate: handleDelData, isLoading: isDeleting } = useRequest({
mutationFn: (dataId: string) => delData(dataId),
successToast: '删除数据集成功',
errorToast: '删除数据集异常',
onSuccess() {
refetch();
}
});
const { mutate: handleExportData, isLoading: isExporting } = useRequest({
mutationFn: async ({ data, type }: { data: DataListItem; type: ExportDataType }) => ({
type,
data: await getDataItems({ dataId: data._id, pageNum: 1, pageSize: data.totalData }).then(
(res) => res.data
)
}),
successToast: '导出数据集成功',
errorToast: '导出数据集异常',
onSuccess(res: { type: ExportDataType; data: DataItemSchema[] }) {
// 合并数据
const data = res.data.map((item) => item.result).flat();
let text = '';
// 生成 jsonl
data.forEach((item) => {
if (res.type === 'jsonl' && item.q && item.a) {
const result = JSON.stringify({
prompt: `${item.q.toLocaleLowerCase()}${nanoid()}</s>`,
completion: ` ${item.a}###`
});
text += `${result}\n`;
} else if (res.type === 'txt' && item.abstract) {
text += `${item.abstract}\n`;
}
});
// 去掉最后一个 \n
text = text.substring(0, text.length - 1);
// 导出为文件
const blob = new Blob([text], { type: 'application/json;charset=utf-8' });
// 创建下载链接
const downloadLink = document.createElement('a');
downloadLink.href = window.URL.createObjectURL(blob);
downloadLink.download = `data.${res.type}`;
// 添加链接到页面并触发下载
document.body.appendChild(downloadLink);
downloadLink.click();
document.body.removeChild(downloadLink);
}
});
return (
<Box display={['block', 'flex']} flexDirection={'column'} h={'100%'}>
<Card px={6} py={4}>
<Flex>
<Box flex={1} mr={1}>
<Box fontSize={'xl'} fontWeight={'bold'}>
</Box>
<Box fontSize={'xs'} color={'blackAlpha.600'}>
QA
</Box>
</Box>
<Button variant={'outline'} onClick={onOpenCreateDataModal}>
</Button>
</Flex>
</Card>
{/* 数据表 */}
<TableContainer
mt={3}
flex={'1 0 0'}
h={['auto', '0']}
overflowY={'auto'}
px={6}
py={4}
backgroundColor={'white'}
borderRadius={'md'}
boxShadow={'base'}
>
<Table>
<Thead>
<Tr>
<Th></Th>
<Th></Th>
<Th></Th>
<Th> / </Th>
<Th></Th>
</Tr>
</Thead>
<Tbody>
{dataList.map((item, i) => (
<Tr key={item._id}>
<Td>
<Input
minW={'150px'}
placeholder="请输入数据集名称"
defaultValue={item.name}
size={'sm'}
onBlur={(e) => {
if (!e.target.value || e.target.value === item.name) return;
updateDataName(item._id, e.target.value);
}}
/>
</Td>
<Td>{DataTypeTextMap[item.type || 'QA']}</Td>
<Td>{dayjs(item.createTime).format('YYYY/MM/DD HH:mm')}</Td>
<Td>
{item.trainingData} / {item.totalData}
</Td>
<Td>
<Button
size={'sm'}
variant={'outline'}
colorScheme={'gray'}
mr={2}
onClick={() =>
router.push(`/data/detail?dataId=${item._id}&dataName=${item.name}`)
}
>
</Button>
<Button
size={'sm'}
variant={'outline'}
mr={2}
onClick={() => setImportDataId(item._id)}
>
</Button>
{/* <Menu>
<MenuButton as={Button} mr={2} size={'sm'} isLoading={isExporting}>
导出
</MenuButton>
<MenuList>
{item.type === 'QA' && (
<MenuItem onClick={() => handleExportData({ data: item, type: 'jsonl' })}>
jsonl
</MenuItem>
)}
{item.type === 'abstract' && (
<MenuItem onClick={() => handleExportData({ data: item, type: 'txt' })}>
txt
</MenuItem>
)}
</MenuList>
</Menu> */}
<Button
size={'sm'}
colorScheme={'red'}
isLoading={isDeleting}
onClick={openConfirm(() => handleDelData(item._id))}
>
</Button>
</Td>
</Tr>
))}
</Tbody>
</Table>
</TableContainer>
{ImportDataId && (
<ImportDataModal
dataId={ImportDataId}
onClose={() => setImportDataId(undefined)}
onSuccess={refetch}
/>
)}
{isOpenCreateDataModal && (
<CreateDataModal onClose={onCloseCreateDataModal} onSuccess={refetch} />
)}
<ConfirmChild />
</Box>
);
};
export default DataList;

View File

@@ -13,15 +13,11 @@ import {
Textarea
} from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast';
import { customAlphabet } from 'nanoid';
import { encode } from 'gpt-token-utils';
import { useConfirm } from '@/hooks/useConfirm';
import { useMutation } from '@tanstack/react-query';
import { postModelDataSplitData, getWebContent } from '@/api/model';
import { formatPrice } from '@/utils/user';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
const SelectUrlModal = ({
onClose,
onSuccess,
@@ -106,9 +102,6 @@ const SelectUrlModal = ({
QA tokens
</Box>
<Box mt={2}>
{encode(webText).length} tokens {formatPrice(encode(webText).length * 3)}
</Box>
<Flex w={'100%'} alignItems={'center'} my={4}>
<Box flex={'0 0 70px'}></Box>
<Input

View File

@@ -4,7 +4,7 @@ import { httpsAgent } from '@/service/utils/tools';
import { getOpenApiKey } from '../utils/openai';
import type { ChatCompletionRequestMessage } from 'openai';
import { DataItemSchema } from '@/types/mongoSchema';
import { ChatModelNameEnum } from '@/constants/model';
import { ChatModelEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill';
export async function generateAbstract(next = false): Promise<any> {
@@ -68,7 +68,7 @@ export async function generateAbstract(next = false): Promise<any> {
// 请求 chatgpt 获取摘要
const abstractResponse = await chatAPI.createChatCompletion(
{
model: ChatModelNameEnum.GPT35,
model: ChatModelEnum.GPT35,
temperature: 0.8,
n: 1,
messages: [

View File

@@ -3,7 +3,7 @@ import { getOpenAIApi } from '@/service/utils/auth';
import { httpsAgent } from '@/service/utils/tools';
import { getOpenApiKey } from '../utils/openai';
import type { ChatCompletionRequestMessage } from 'openai';
import { ChatModelNameEnum } from '@/constants/model';
import { ChatModelEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill';
import { generateVector } from './generateVector';
import { openaiError2 } from '../errorCode';
@@ -84,7 +84,7 @@ A2:
chatAPI
.createChatCompletion(
{
model: ChatModelNameEnum.GPT35,
model: ChatModelEnum.GPT35,
temperature: 0.8,
n: 1,
messages: [

View File

@@ -1,27 +1,34 @@
import { connectToDatabase, Bill, User } from '../mongo';
import { modelList, ChatModelNameEnum } from '@/constants/model';
import { encode } from 'gpt-token-utils';
import {
modelList,
ChatModelEnum,
ModelNameEnum,
Model2ChatModelMap,
embeddingModel
} from '@/constants/model';
import { BillTypeEnum } from '@/constants/user';
import type { DataType } from '@/types/data';
import { countChatTokens } from '@/utils/tools';
export const pushChatBill = async ({
isPay,
modelName,
userId,
chatId,
text
messages
}: {
isPay: boolean;
modelName: string;
modelName: `${ModelNameEnum}`;
userId: string;
chatId?: '' | string;
text: string;
messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
}) => {
let billId;
let billId = '';
try {
// 计算 token 数量
const tokens = Math.floor(encode(text).length * 0.75);
const tokens = countChatTokens({ model: Model2ChatModelMap[modelName] as any, messages });
const text = messages.map((item) => item.content).join('');
console.log(
`chat generate success. text len: ${text.length}. token len: ${tokens}. pay:${isPay}`
@@ -88,7 +95,7 @@ export const pushSplitDataBill = async ({
if (isPay) {
try {
// 获取模型单价格, 都是用 gpt35 拆分
const modelItem = modelList.find((item) => item.model === ChatModelNameEnum.GPT35);
const modelItem = modelList.find((item) => item.model === ChatModelEnum.GPT35);
const unitPrice = modelItem?.price || 3;
// 计算价格
const price = unitPrice * tokenLen;
@@ -97,7 +104,7 @@ export const pushSplitDataBill = async ({
const res = await Bill.create({
userId,
type,
modelName: ChatModelNameEnum.GPT35,
modelName: ChatModelEnum.GPT35,
textLen: text.length,
tokenLen,
price
@@ -149,7 +156,7 @@ export const pushGenerateVectorBill = async ({
const res = await Bill.create({
userId,
type: BillTypeEnum.vector,
modelName: ChatModelNameEnum.VECTOR,
modelName: embeddingModel,
textLen: text.length,
tokenLen,
price

View File

@@ -5,7 +5,7 @@ import { getOpenAIApi } from '@/service/utils/auth';
import { httpsAgent } from './tools';
import { User } from '../models/user';
import { formatPrice } from '@/utils/user';
import { ChatModelNameEnum } from '@/constants/model';
import { embeddingModel } from '@/constants/model';
import { pushGenerateVectorBill } from '../events/pushBill';
/* 获取用户 api 的 openai 信息 */
@@ -80,7 +80,7 @@ export const openaiCreateEmbedding = async ({
const res = await chatAPI
.createEmbedding(
{
model: ChatModelNameEnum.VECTOR,
model: embeddingModel,
input: text
},
{
@@ -134,11 +134,11 @@ export const gpt35StreamResponse = ({
try {
const json = JSON.parse(data);
const content: string = json?.choices?.[0].delta.content || '';
// console.log('content:', content);
if (!content || (responseContent === '' && content === '\n')) return;
responseContent += content;
!stream.destroyed && stream.push(content.replace(/\n/g, '<br/>'));
if (!stream.destroyed && content) {
stream.push(content.replace(/\n/g, '<br/>'));
}
} catch (error) {
error;
}

View File

@@ -2,10 +2,12 @@ import type { NextApiRequest } from 'next';
import crypto from 'crypto';
import jwt from 'jsonwebtoken';
import { ChatItemType } from '@/types/chat';
import { encode } from 'gpt-token-utils';
import { OpenApi, User } from '../mongo';
import { formatPrice } from '@/utils/user';
import { ERROR_ENUM } from '../errorCode';
import { countChatTokens } from '@/utils/tools';
import { ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatModelEnum } from '@/constants/model';
/* 密码加密 */
export const hashPassword = (psw: string) => {
@@ -86,8 +88,16 @@ export const authOpenApiKey = async (req: NextApiRequest) => {
export const httpsAgent = (fast: boolean) =>
fast ? global.httpsAgentFast : global.httpsAgentNormal;
/* tokens 截断 */
export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) => {
/* 聊天内容 tokens 截断 */
export const openaiChatFilter = ({
model,
prompts,
maxTokens
}: {
model: `${ChatModelEnum}`;
prompts: ChatItemType[];
maxTokens: number;
}) => {
const formatPrompts = prompts.map((item) => ({
obj: item.obj,
value: item.value
@@ -97,41 +107,64 @@ export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) =>
.trim()
}));
let res: ChatItemType[] = [];
let chats: ChatItemType[] = [];
let systemPrompt: ChatItemType | null = null;
// System 词保留
if (formatPrompts[0]?.obj === 'SYSTEM') {
systemPrompt = formatPrompts.shift() as ChatItemType;
maxTokens -= encode(formatPrompts[0].value).length;
}
// 从后往前截取
// 格式化文本内容成 chatgpt 格式
const map = {
Human: ChatCompletionRequestMessageRoleEnum.User,
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
};
let messages: { role: ChatCompletionRequestMessageRoleEnum; content: string }[] = [];
// 从后往前截取对话内容
for (let i = formatPrompts.length - 1; i >= 0; i--) {
const tokens = encode(formatPrompts[i].value).length;
res.unshift(formatPrompts[i]);
chats.unshift(formatPrompts[i]);
messages = (systemPrompt ? [systemPrompt, ...chats] : chats).map((item) => ({
role: map[item.obj],
content: item.value
}));
const tokens = countChatTokens({
model,
messages
});
/* 整体 tokens 超出范围 */
if (tokens >= maxTokens) {
break;
}
maxTokens -= tokens;
}
return systemPrompt ? [systemPrompt, ...res] : res;
return messages;
};
/* system 内容截断 */
export const systemPromptFilter = (prompts: string[], maxTokens: number) => {
export const systemPromptFilter = ({
model,
prompts,
maxTokens
}: {
model: 'gpt-4' | 'gpt-4-32k' | 'gpt-3.5-turbo';
prompts: string[];
maxTokens: number;
}) => {
let splitText = '';
// 从前往前截取
for (let i = 0; i < prompts.length; i++) {
const prompt = prompts[i];
const prompt = prompts[i].replace(/\n+/g, '\n');
splitText += `${prompt}\n`;
const tokens = encode(splitText).length;
const tokens = countChatTokens({ model, messages: [{ role: 'system', content: splitText }] });
if (tokens >= maxTokens) {
break;
}

View File

@@ -2,8 +2,9 @@ import type { ChatItemType } from './chat';
import {
ModelStatusEnum,
TrainingStatusEnum,
ChatModelNameEnum,
ModelVectorSearchModeEnum
ModelNameEnum,
ModelVectorSearchModeEnum,
ChatModelEnum
} from '@/constants/model';
import type { DataType } from './data';
@@ -45,8 +46,8 @@ export interface ModelSchema {
};
service: {
trainId: string; // 训练的模型训练后就是训练的模型id
chatModel: string; // 聊天时用的模型,训练后就是训练的模型
modelName: `${ChatModelNameEnum}`; // 底层模型名称,不会变
chatModel: `${ChatModelEnum}`; // 聊天时用的模型,训练后就是训练的模型
modelName: `${ModelNameEnum}`; // 底层模型名称,不会变
};
security: {
domain: string[];

View File

@@ -2,6 +2,7 @@ import crypto from 'crypto';
import { useToast } from '@/hooks/useToast';
import { encoding_for_model, type Tiktoken } from '@dqbd/tiktoken';
import Graphemer from 'graphemer';
import { ChatModelEnum } from '@/constants/model';
const textDecoder = new TextDecoder();
const graphemer = new Graphemer();
@@ -124,7 +125,7 @@ export const countChatTokens = ({
model = 'gpt-3.5-turbo',
messages
}: {
model?: 'gpt-4' | 'gpt-4-32k' | 'gpt-3.5-turbo';
model?: `${ChatModelEnum}`;
messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
}) => {
const text = getChatGPTEncodingText(messages, model);