perf: reduce chat content configuration; automatically truncate the context

archer
2023-03-28 00:07:32 +08:00
parent 7fb6f62cf6
commit 7a6d0ea650
14 changed files with 144 additions and 172 deletions

View File

@@ -7,9 +7,7 @@ export type InitChatResponse = {
name: string;
avatar: string;
intro: string;
secret: ModelSchema.secret;
chatModel: ModelSchema.service.chatModel; // chat model name
modelName: ModelSchema.service.modelName; // underlying model
history: ChatItemType[];
isExpiredTime: boolean;
};

View File

@@ -7,8 +7,7 @@ import { useQuery } from '@tanstack/react-query';
const unAuthPage: { [key: string]: boolean } = {
'/': true,
'/login': true,
'/chat': true
'/login': true
};
const Auth = ({ children }: { children: JSX.Element }) => {

View File

@@ -44,6 +44,9 @@ export const introPage = `
`;
export const chatProblem = `
**Content length**
A single message can be up to 4000 tokens and the context up to 8000 tokens; when the context is too long it will be truncated.
**Model questions**
In most cases, choose the chatGPT model directly: it is cheap and works well.

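The limits quoted above are measured in tokens, not characters. Below is a minimal pre-flight check sketched under stated assumptions: it only uses the `encode` helper from gpt-token-utils that this commit already relies on, the 4000/8000 figures come straight from the text above, and the constant and function names are hypothetical.

```ts
import { encode } from 'gpt-token-utils';

// Hypothetical limits mirroring the documented policy.
const SINGLE_MESSAGE_MAX_TOKENS = 4000;
const CONTEXT_MAX_TOKENS = 8000;

// Count tokens the same way the server does and reject oversized input early.
export const isMessageTooLong = (text: string) =>
  encode(text).length > SINGLE_MESSAGE_MAX_TOKENS;

export const isContextTooLong = (texts: string[]) =>
  texts.reduce((sum, t) => sum + encode(t).length, 0) > CONTEXT_MAX_TOKENS;
```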
View File

@@ -27,17 +27,17 @@ export const modelList: ModelConstantsData[] = [
trainedMaxToken: 2000,
maxTemperature: 2,
price: 3
},
{
serviceCompany: 'openai',
name: 'GPT3',
model: ChatModelNameEnum.GPT3,
trainName: 'davinci',
maxToken: 4000,
trainedMaxToken: 2000,
maxTemperature: 2,
price: 30
}
// {
// serviceCompany: 'openai',
// name: 'GPT3',
// model: ChatModelNameEnum.GPT3,
// trainName: 'davinci',
// maxToken: 4000,
// trainedMaxToken: 2000,
// maxTemperature: 2,
// price: 30
// }
];
export enum TrainingStatusEnum {

View File

@@ -10,6 +10,7 @@ import type { ModelSchema } from '@/types/mongoSchema';
import { PassThrough } from 'stream';
import { modelList } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { openaiChatFilter } from '@/service/utils/tools';
/* send the prompt */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -32,6 +33,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
prompt: ChatItemType;
chatId: string;
};
const { authorization } = req.headers;
if (!chatId || !prompt) {
throw new Error('缺少参数');
@@ -46,12 +48,18 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// read the conversation content
const prompts = [...chat.content, prompt];
// context length filtering
const maxContext = model.security.contextMaxLen;
const filterPrompts =
prompts.length > maxContext ? prompts.slice(prompts.length - maxContext) : prompts;
// if a system prompt is configured, insert it automatically
if (model.systemPrompt) {
prompts.unshift({
obj: 'SYSTEM',
value: model.systemPrompt
});
}
// format the text content
// keep within the token budget to avoid exceeding the limit
const filterPrompts = openaiChatFilter(prompts, 7500);
// format the text content into the chatgpt message format
const map = {
Human: ChatCompletionRequestMessageRoleEnum.User,
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
@@ -63,15 +71,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
content: item.value
})
);
// if a system prompt is configured, insert it automatically
if (model.systemPrompt) {
formatPrompts.unshift({
role: 'system',
content: model.systemPrompt
});
}
// console.log(formatPrompts);
// calculate the temperature
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
if (!modelConstantsData) {

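Putting this hunk together: the handler now prepends the system prompt, truncates the whole conversation to a roughly 7500-token budget with `openaiChatFilter`, and only then maps messages to chatgpt roles. A minimal sketch of that flow follows; `buildMessages` is a hypothetical helper name, and `ChatItemType` is assumed to expose `obj`/`value` as used in the diff above.

```ts
import { ChatCompletionRequestMessageRoleEnum } from 'openai';
import { openaiChatFilter } from '@/service/utils/tools';
import type { ChatItemType } from '@/types/chat';

// Hypothetical helper: system prompt first, then token truncation, then role mapping.
const buildMessages = (prompts: ChatItemType[], systemPrompt?: string) => {
  if (systemPrompt) {
    // inserted before filtering so the truncation step can preserve it
    prompts.unshift({ obj: 'SYSTEM', value: systemPrompt } as ChatItemType);
  }
  // keep the conversation within roughly 7500 tokens
  const filterPrompts = openaiChatFilter(prompts, 7500);
  const map = {
    Human: ChatCompletionRequestMessageRoleEnum.User,
    AI: ChatCompletionRequestMessageRoleEnum.Assistant,
    SYSTEM: ChatCompletionRequestMessageRoleEnum.System
  };
  return filterPrompts.map((item) => ({
    role: map[item.obj],
    content: item.value
  }));
};
```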
View File

@@ -14,7 +14,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
const { authorization } = req.headers;
if (!authorization) {
throw new Error('无权操作');
throw new Error('无权生成对话');
}
if (!modelId) {

View File

@@ -3,10 +3,14 @@ import { jsonRes } from '@/service/response';
import { connectToDatabase, Chat } from '@/service/mongo';
import type { ChatPopulate } from '@/types/mongoSchema';
import type { InitChatResponse } from '@/api/response/chat';
import { authToken } from '@/service/utils/tools';
/* get my model */
/* initialize my chat window; requires authentication */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
const { authorization } = req.headers;
const userId = await authToken(authorization);
const { chatId } = req.query as { chatId: string };
if (!chatId) {
@@ -16,7 +20,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
await connectToDatabase();
// fetch the chat data
const chat = await Chat.findById<ChatPopulate>(chatId).populate({
const chat = await Chat.findOne<ChatPopulate>({
_id: chatId,
userId
}).populate({
path: 'modelId',
options: {
strictPopulate: false
@@ -27,31 +34,17 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
throw new Error('聊天框不存在');
}
if (chat.loadAmount > 0) {
await Chat.updateOne(
{
_id: chat._id
},
{
$inc: { loadAmount: -1 }
}
);
}
// filter out deleted content
chat.content = chat.content.filter((item) => item.deleted !== true);
const model = chat.modelId;
jsonRes<InitChatResponse>(res, {
code: 201,
data: {
chatId: chat._id,
isExpiredTime: chat.loadAmount === 0 || chat.expiredTime <= Date.now(),
modelId: model._id,
name: model.name,
avatar: model.avatar,
intro: model.intro,
secret: model.security,
modelName: model.service.modelName,
chatModel: model.service.chatModel,
history: chat.content

View File

@@ -231,12 +231,12 @@ const SlideBar = ({
</>
</RenderButton>
<RenderButton onClick={onOpenShare}>
{/* <RenderButton onClick={onOpenShare}>
<>
<MyIcon name="share" fill={'white'} w={'16px'} h={'16px'} mr={4} />
分享
</>
</RenderButton>
</RenderButton> */}
<RenderButton onClick={() => router.push('/number/setting')}>
<>
<MyIcon name="pay" fill={'white'} w={'16px'} h={'16px'} mr={4} />

View File

@@ -14,7 +14,6 @@ import {
Textarea,
Box,
Flex,
Button,
useDisclosure,
Drawer,
DrawerOverlay,
@@ -36,8 +35,8 @@ import { useCopyData } from '@/utils/tools';
import { streamFetch } from '@/api/fetch';
import SlideBar from './components/SlideBar';
import Empty from './components/Empty';
import { getToken } from '@/utils/user';
import Icon from '@/components/Icon';
import { encode } from 'gpt-token-utils';
const Markdown = dynamic(() => import('@/components/Markdown'));
@@ -60,11 +59,9 @@ const Chat = ({ chatId }: { chatId: string }) => {
name: '',
avatar: '',
intro: '',
secret: {},
chatModel: '',
modelName: '',
history: [],
isExpiredTime: false
history: []
}); // overall chat window data
const [inputVal, setInputVal] = useState(''); // the current input
@@ -72,15 +69,6 @@ const Chat = ({ chatId }: { chatId: string }) => {
() => chatData.history[chatData.history.length - 1]?.status === 'loading',
[chatData.history]
);
const chatWindowError = useMemo(() => {
if (chatData.isExpiredTime) {
return {
text: '聊天框已过期'
};
}
return '';
}, [chatData]);
const { copyData } = useCopyData();
const { isPc, media } = useScreen();
const { setLoading } = useGlobalStore();
@@ -125,31 +113,6 @@ const Chat = ({ chatId }: { chatId: string }) => {
onCloseSlider();
}, [chatData, onCloseSlider, router, toast]);
// gpt3 method
const gpt3ChatPrompt = useCallback(
async (newChatList: ChatSiteItemType[]) => {
// request the content
const response = await postGPT3SendPrompt({
prompt: newChatList,
chatId: chatId as string
});
// update the AI message
setChatData((state) => ({
...state,
history: state.history.map((item, index) => {
if (index !== state.history.length - 1) return item;
return {
...item,
status: 'finish',
value: response
};
})
}));
},
[chatId]
);
// gpt chat
const gptChatPrompt = useCallback(
async (prompts: ChatSiteItemType) => {
@@ -476,83 +439,74 @@ const Chat = ({ chatId }: { chatId: string }) => {
</Box>
{/* send area */}
<Box m={media('20px auto', '0 auto')} w={'100%'} maxW={media('min(750px, 100%)', 'auto')}>
{!!chatWindowError ? (
<Box textAlign={'center'}>
<Box color={'red'}>{chatWindowError.text}</Box>
<Flex py={5} justifyContent={'center'}>
{getToken() && <Button onClick={resetChat}></Button>}
</Flex>
<Box
py={5}
position={'relative'}
boxShadow={`0 0 15px rgba(0,0,0,0.1)`}
border={media('1px solid', '0')}
borderColor={useColorModeValue('gray.200', 'gray.700')}
borderRadius={['none', 'md']}
backgroundColor={useColorModeValue('white', 'gray.700')}
>
{/* input box */}
<Textarea
ref={TextareaDom}
w={'100%'}
pr={'45px'}
py={0}
border={'none'}
_focusVisible={{
border: 'none'
}}
placeholder="提问"
resize={'none'}
value={inputVal}
rows={1}
height={'22px'}
lineHeight={'22px'}
maxHeight={'150px'}
maxLength={-1}
overflowY={'auto'}
color={useColorModeValue('blackAlpha.700', 'white')}
onChange={(e) => {
const textarea = e.target;
setInputVal(textarea.value);
textarea.style.height = textareaMinH;
textarea.style.height = `${textarea.scrollHeight}px`;
}}
onKeyDown={(e) => {
// trigger quick send
if (isPc && e.keyCode === 13 && !e.shiftKey) {
sendPrompt();
e.preventDefault();
}
// select all content
// @ts-ignore
e.key === 'a' && e.ctrlKey && e.target?.select();
}}
/>
{/* send and loading buttons */}
<Box position={'absolute'} bottom={5} right={media('20px', '10px')}>
{isChatting ? (
<Image
style={{ transform: 'translateY(4px)' }}
src={'/icon/chatting.svg'}
width={30}
height={30}
alt={''}
/>
) : (
<Box cursor={'pointer'} onClick={sendPrompt}>
<Icon
name={'chatSend'}
width={'20px'}
height={'20px'}
fill={useColorModeValue('#718096', 'white')}
></Icon>
</Box>
)}
</Box>
) : (
<Box
py={5}
position={'relative'}
boxShadow={`0 0 15px rgba(0,0,0,0.1)`}
border={media('1px solid', '0')}
borderColor={useColorModeValue('gray.200', 'gray.700')}
borderRadius={['none', 'md']}
backgroundColor={useColorModeValue('white', 'gray.700')}
>
{/* input box */}
<Textarea
ref={TextareaDom}
w={'100%'}
pr={'45px'}
py={0}
border={'none'}
_focusVisible={{
border: 'none'
}}
placeholder="提问"
resize={'none'}
value={inputVal}
rows={1}
height={'22px'}
lineHeight={'22px'}
maxHeight={'150px'}
maxLength={chatData?.secret.contentMaxLen || -1}
overflowY={'auto'}
color={useColorModeValue('blackAlpha.700', 'white')}
onChange={(e) => {
const textarea = e.target;
setInputVal(textarea.value);
textarea.style.height = textareaMinH;
textarea.style.height = `${textarea.scrollHeight}px`;
}}
onKeyDown={(e) => {
// trigger quick send
if (isPc && e.keyCode === 13 && !e.shiftKey) {
sendPrompt();
e.preventDefault();
}
// select all content
// @ts-ignore
e.key === 'a' && e.ctrlKey && e.target?.select();
}}
/>
{/* send and loading buttons */}
<Box position={'absolute'} bottom={5} right={media('20px', '10px')}>
{isChatting ? (
<Image
style={{ transform: 'translateY(4px)' }}
src={'/icon/chatting.svg'}
width={30}
height={30}
alt={''}
/>
) : (
<Box cursor={'pointer'} onClick={sendPrompt}>
<Icon
name={'chatSend'}
width={'20px'}
height={'20px'}
fill={useColorModeValue('#718096', 'white')}
></Icon>
</Box>
)}
</Box>
</Box>
)}
</Box>
</Box>
</Flex>
</Flex>

View File

@@ -22,7 +22,7 @@ import { useToast } from '@/hooks/useToast';
import { useLoading } from '@/hooks/useLoading';
import { formatPrice } from '@/utils/user';
import { modelList, ChatModelNameEnum } from '@/constants/model';
import { encode, decode } from 'gpt-token-utils';
import { encode } from 'gpt-token-utils';
const fileExtension = '.txt,.doc,.docx,.pdf,.md';

View File

@@ -105,7 +105,7 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
/>
</Box>
</Card>
<Card p={4}>
{/* <Card p={4}>
<Box fontWeight={'bold'}>安全策略</Box>
<FormControl mt={2}>
<Flex alignItems={'center'}>
@@ -201,7 +201,7 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
<Box ml={3}>次</Box>
</Flex>
</FormControl>
</Card>
</Card> */}
</>
);
};

View File

@@ -31,7 +31,6 @@ export const jsonRes = <T = any>(
console.log('error->');
console.log('code:', error.code);
console.log('statusText:', error?.response?.statusText);
console.log('data len:', error?.response?.config?.data.length);
console.log('msg:', msg);
}

View File

@@ -32,14 +32,11 @@ export const authChat = async (chatId: string, authorization?: string) => {
return Promise.reject('模型不存在');
}
// security check
if (chat.loadAmount === 0 || chat.expiredTime <= Date.now()) {
return Promise.reject('聊天框已过期');
}
// share check
// credential check
if (!chat.isShare) {
await authToken(authorization);
} else if (chat.loadAmount === 0 || chat.expiredTime <= Date.now()) {
return Promise.reject('聊天框已过期');
}
// get the user's apiKey
@@ -47,6 +44,7 @@ export const authChat = async (chatId: string, authorization?: string) => {
const userApiKey = user.accounts?.find((item: any) => item.type === 'openai')?.value;
// no apiKey, check the balance
if (!userApiKey && formatPrice(user.balance) <= 0) {
return Promise.reject('该账号余额不足');
}

View File

@@ -4,6 +4,8 @@ import { User } from '../models/user';
import tunnel from 'tunnel';
import type { UserModelSchema } from '@/types/mongoSchema';
import { formatPrice } from '@/utils/user';
import { ChatItemType } from '@/types/chat';
import { encode } from 'gpt-token-utils';
/* password hashing */
export const hashPassword = (psw: string) => {
@@ -91,3 +93,29 @@ export const httpsAgent =
}
})
: undefined;
/* token-based truncation */
export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) => {
let res: ChatItemType[] = [];
let systemPrompt: ChatItemType | null = null;
// keep the SYSTEM prompt and reserve its tokens out of the budget
if (prompts[0]?.obj === 'SYSTEM') {
systemPrompt = prompts.shift() as ChatItemType;
maxTokens -= encode(systemPrompt.value).length;
}
// take messages from newest to oldest
for (let i = prompts.length - 1; i >= 0; i--) {
const tokens = encode(prompts[i].value).length;
if (maxTokens >= tokens) {
res.unshift(prompts[i]);
maxTokens -= tokens;
} else {
break;
}
}
return systemPrompt ? [systemPrompt, ...res] : res;
};
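
A short usage sketch of the new filter, under stated assumptions: the history entries are invented for illustration, and `ChatItemType` is assumed to need only `obj` and `value` here.

```ts
import { openaiChatFilter } from '@/service/utils/tools';
import type { ChatItemType } from '@/types/chat';

const history = [
  { obj: 'SYSTEM', value: 'You are a helpful assistant.' },
  { obj: 'Human', value: 'First question ...' },
  { obj: 'AI', value: 'First answer ...' },
  { obj: 'Human', value: 'Latest question ...' }
] as ChatItemType[];

// The SYSTEM entry is always kept; Human/AI turns are dropped oldest-first
// once the 7500-token budget is used up.
const filterPrompts = openaiChatFilter(history, 7500);
```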