feat: system prompt

This commit is contained in:
archer
2023-05-23 19:13:01 +08:00
parent b8f08eb33e
commit 6014a56e54
9 changed files with 118 additions and 61 deletions

View File

@@ -1,4 +1,4 @@
import { NEW_CHATID_HEADER, QUOTE_LEN_HEADER } from '@/constants/chat'; import { GUIDE_PROMPT_HEADER, NEW_CHATID_HEADER, QUOTE_LEN_HEADER } from '@/constants/chat';
interface StreamFetchProps { interface StreamFetchProps {
url: string; url: string;
@@ -7,7 +7,7 @@ interface StreamFetchProps {
abortSignal: AbortController; abortSignal: AbortController;
} }
export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchProps) => export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchProps) =>
new Promise<{ responseText: string; newChatId: string; quoteLen: number }>( new Promise<{ responseText: string; newChatId: string; systemPrompt: string; quoteLen: number }>(
async (resolve, reject) => { async (resolve, reject) => {
try { try {
const res = await fetch(url, { const res = await fetch(url, {
@@ -24,6 +24,7 @@ export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchPr
const decoder = new TextDecoder(); const decoder = new TextDecoder();
const newChatId = decodeURIComponent(res.headers.get(NEW_CHATID_HEADER) || ''); const newChatId = decodeURIComponent(res.headers.get(NEW_CHATID_HEADER) || '');
const systemPrompt = decodeURIComponent(res.headers.get(GUIDE_PROMPT_HEADER) || '').trim();
const quoteLen = res.headers.get(QUOTE_LEN_HEADER) const quoteLen = res.headers.get(QUOTE_LEN_HEADER)
? Number(res.headers.get(QUOTE_LEN_HEADER)) ? Number(res.headers.get(QUOTE_LEN_HEADER))
: 0; : 0;
@@ -35,7 +36,7 @@ export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchPr
const { done, value } = await reader?.read(); const { done, value } = await reader?.read();
if (done) { if (done) {
if (res.status === 200) { if (res.status === 200) {
resolve({ responseText, newChatId, quoteLen }); resolve({ responseText, newChatId, quoteLen, systemPrompt });
} else { } else {
const parseError = JSON.parse(responseText); const parseError = JSON.parse(responseText);
reject(parseError?.message || '请求异常'); reject(parseError?.message || '请求异常');
@@ -49,7 +50,7 @@ export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchPr
read(); read();
} catch (err: any) { } catch (err: any) {
if (err?.message === 'The user aborted a request.') { if (err?.message === 'The user aborted a request.') {
return resolve({ responseText, newChatId, quoteLen: 0 }); return resolve({ responseText, newChatId, quoteLen: 0, systemPrompt: '' });
} }
reject(typeof err === 'string' ? err : err?.message || '请求异常'); reject(typeof err === 'string' ? err : err?.message || '请求异常');
} }

View File

@@ -1,5 +1,6 @@
export const NEW_CHATID_HEADER = 'response-new-chat-id'; export const NEW_CHATID_HEADER = 'response-new-chat-id';
export const QUOTE_LEN_HEADER = 'response-quote-len'; export const QUOTE_LEN_HEADER = 'response-quote-len';
export const GUIDE_PROMPT_HEADER = 'response-guide-prompt';
export enum ChatRoleEnum { export enum ChatRoleEnum {
System = 'System', System = 'System',

View File

@@ -8,7 +8,7 @@ import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill'; import { pushChatBill } from '@/service/events/pushBill';
import { resStreamResponse } from '@/service/utils/chat'; import { resStreamResponse } from '@/service/utils/chat';
import { appKbSearch } from '../openapi/kb/appKbSearch'; import { appKbSearch } from '../openapi/kb/appKbSearch';
import { ChatRoleEnum, QUOTE_LEN_HEADER } from '@/constants/chat'; import { ChatRoleEnum, QUOTE_LEN_HEADER, GUIDE_PROMPT_HEADER } from '@/constants/chat';
import { BillTypeEnum } from '@/constants/user'; import { BillTypeEnum } from '@/constants/user';
import { sensitiveCheck } from '@/service/api/text'; import { sensitiveCheck } from '@/service/api/text';
import { NEW_CHATID_HEADER } from '@/constants/chat'; import { NEW_CHATID_HEADER } from '@/constants/chat';
@@ -53,11 +53,12 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { const {
code = 200, code = 200,
systemPrompts = [], systemPrompts = [],
quote = [] quote = [],
guidePrompt = ''
} = await (async () => { } = await (async () => {
// 使用了知识库搜索 // 使用了知识库搜索
if (model.chat.relatedKbs.length > 0) { if (model.chat.relatedKbs.length > 0) {
const { code, searchPrompts, rawSearch } = await appKbSearch({ const { code, searchPrompts, rawSearch, guidePrompt } = await appKbSearch({
model, model,
userId, userId,
prompts, prompts,
@@ -67,11 +68,13 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
return { return {
code, code,
quote: rawSearch, quote: rawSearch,
systemPrompts: searchPrompts systemPrompts: searchPrompts,
guidePrompt
}; };
} }
if (model.chat.systemPrompt) { if (model.chat.systemPrompt) {
return { return {
guidePrompt: model.chat.systemPrompt,
systemPrompts: [ systemPrompts: [
{ {
obj: ChatRoleEnum.System, obj: ChatRoleEnum.System,
@@ -86,7 +89,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// get conversationId. create a newId if it is null // get conversationId. create a newId if it is null
const conversationId = chatId || String(new Types.ObjectId()); const conversationId = chatId || String(new Types.ObjectId());
!chatId && res.setHeader(NEW_CHATID_HEADER, conversationId); !chatId && res.setHeader(NEW_CHATID_HEADER, conversationId);
res.setHeader(QUOTE_LEN_HEADER, quote.length); if (showModelDetail) {
guidePrompt && res.setHeader(GUIDE_PROMPT_HEADER, encodeURIComponent(guidePrompt));
res.setHeader(QUOTE_LEN_HEADER, quote.length);
}
// search result is empty // search result is empty
if (code === 201) { if (code === 201) {
@@ -151,8 +157,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
prompt[0], prompt[0],
{ {
...prompt[1], ...prompt[1],
value: responseContent,
quote: showModelDetail ? quote : [], quote: showModelDetail ? quote : [],
value: responseContent systemPrompt: showModelDetail ? guidePrompt : ''
} }
], ],
userId userId

View File

@@ -73,7 +73,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
_id: '$content._id', _id: '$content._id',
obj: '$content.obj', obj: '$content.obj',
value: '$content.value', value: '$content.value',
quoteLen: { $size: '$content.quote' } systemPrompt: '$content.systemPrompt',
quoteLen: { $size: { $ifNull: ['$content.quote', []] } }
} }
} }
]); ]);

View File

@@ -57,10 +57,8 @@ export async function saveChat({
_id: item._id ? new mongoose.Types.ObjectId(item._id) : undefined, _id: item._id ? new mongoose.Types.ObjectId(item._id) : undefined,
obj: item.obj, obj: item.obj,
value: item.value, value: item.value,
quote: item.quote?.map((item) => ({ systemPrompt: item.systemPrompt,
...item, quote: item.quote || []
isEdit: false
}))
})); }));
// 没有 chatId, 创建一个对话 // 没有 chatId, 创建一个对话

View File

@@ -22,6 +22,7 @@ type Props = {
type Response = { type Response = {
code: 200 | 201; code: 200 | 201;
rawSearch: QuoteItemType[]; rawSearch: QuoteItemType[];
guidePrompt: string;
searchPrompts: { searchPrompts: {
obj: ChatRoleEnum; obj: ChatRoleEnum;
value: string; value: string;
@@ -131,36 +132,29 @@ export async function appKbSearch({
}; };
const sliceRate = sliceRateMap[searchRes.length] || sliceRateMap[0]; const sliceRate = sliceRateMap[searchRes.length] || sliceRateMap[0];
// 计算固定提示词的 token 数量 // 计算固定提示词的 token 数量
const fixedPrompts = [
// user system prompt const guidePrompt = model.chat.systemPrompt // user system prompt
...(model.chat.systemPrompt ? {
? [ obj: ChatRoleEnum.System,
{ value: model.chat.systemPrompt
obj: ChatRoleEnum.System, }
value: model.chat.systemPrompt : model.chat.searchMode === ModelVectorSearchModeEnum.noContext
} ? {
] obj: ChatRoleEnum.System,
: model.chat.searchMode === ModelVectorSearchModeEnum.noContext value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.`
? [ }
{ : {
obj: ChatRoleEnum.System, obj: ChatRoleEnum.System,
value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.` value: `玩一个问答游戏,规则为:
}
]
: [
{
obj: ChatRoleEnum.System,
value: `玩一个问答游戏,规则为:
1.你完全忘记你已有的知识 1.你完全忘记你已有的知识
2.你只回答关于"${model.name}"的问题 2.你只回答关于"${model.name}"的问题
3.你只从知识库中选择内容进行回答 3.你只从知识库中选择内容进行回答
4.如果问题不在知识库中,你会回答:"我不知道。" 4.如果问题不在知识库中,你会回答:"我不知道。"
请务必遵守规则` 请务必遵守规则`
} };
])
];
const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({ const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({
messages: fixedPrompts messages: [guidePrompt]
}); });
const maxTokens = modelConstantsData.systemMaxToken - fixedSystemTokens; const maxTokens = modelConstantsData.systemMaxToken - fixedSystemTokens;
const sliceResult = sliceRate.map((rate, i) => const sliceResult = sliceRate.map((rate, i) =>
@@ -186,6 +180,7 @@ export async function appKbSearch({
return { return {
code: 201, code: 201,
rawSearch: [], rawSearch: [],
guidePrompt: '',
searchPrompts: [ searchPrompts: [
{ {
obj: ChatRoleEnum.System, obj: ChatRoleEnum.System,
@@ -199,6 +194,7 @@ export async function appKbSearch({
return { return {
code: 200, code: 200,
rawSearch: [], rawSearch: [],
guidePrompt: model.chat.systemPrompt || '',
searchPrompts: model.chat.systemPrompt searchPrompts: model.chat.systemPrompt
? [ ? [
{ {
@@ -213,12 +209,13 @@ export async function appKbSearch({
return { return {
code: 200, code: 200,
rawSearch: sliceSearch, rawSearch: sliceSearch,
guidePrompt: guidePrompt.value || '',
searchPrompts: [ searchPrompts: [
{ {
obj: ChatRoleEnum.System, obj: ChatRoleEnum.System,
value: `知识库:${systemPrompt}` value: `知识库:${systemPrompt}`
}, },
...fixedPrompts guidePrompt
] ]
}; };
} }

View File

@@ -76,6 +76,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
const isLeavePage = useRef(false); const isLeavePage = useRef(false);
const [showHistoryQuote, setShowHistoryQuote] = useState<string>(); const [showHistoryQuote, setShowHistoryQuote] = useState<string>();
const [showSystemPrompt, setShowSystemPrompt] = useState('');
const [messageContextMenuData, setMessageContextMenuData] = useState<{ const [messageContextMenuData, setMessageContextMenuData] = useState<{
// message messageContextMenuData // message messageContextMenuData
left: number; left: number;
@@ -177,7 +178,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
})); }));
// 流请求,获取数据 // 流请求,获取数据
const { newChatId, quoteLen } = await streamFetch({ const { newChatId, quoteLen, systemPrompt } = await streamFetch({
url: '/api/chat/chat', url: '/api/chat/chat',
data: { data: {
prompt, prompt,
@@ -221,14 +222,15 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
return { return {
...item, ...item,
status: 'finish', status: 'finish',
quoteLen quoteLen,
systemPrompt
}; };
}) })
})); }));
// refresh history // refresh history
loadHistory({ pageNum: 1, init: true });
setTimeout(() => { setTimeout(() => {
loadHistory({ pageNum: 1, init: true });
generatingMessage(); generatingMessage();
}, 100); }, 100);
}, },
@@ -699,6 +701,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
})} })}
> >
<Avatar <Avatar
className="avatar"
src={ src={
item.obj === 'Human' item.obj === 'Human'
? userInfo?.avatar || '/icon/human.png' ? userInfo?.avatar || '/icon/human.png'
@@ -727,19 +730,35 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
isChatting={isChatting && index === chatData.history.length - 1} isChatting={isChatting && index === chatData.history.length - 1}
formatLink formatLink
/> />
{!!item.quoteLen && ( <Flex>
<Button {!!item.systemPrompt && (
size={'xs'} <Button
mt={2} mt={2}
fontWeight={'normal'} mr={3}
colorScheme={'gray'} size={'xs'}
variant={'outline'} fontWeight={'normal'}
w={'90px'} colorScheme={'gray'}
onClick={() => setShowHistoryQuote(item._id)} variant={'outline'}
> px={[2, 4]}
{item.quoteLen} onClick={() => setShowSystemPrompt(item.systemPrompt || '')}
</Button> >
)}
</Button>
)}
{!!item.quoteLen && (
<Button
mt={2}
size={'xs'}
fontWeight={'normal'}
colorScheme={'gray'}
variant={'outline'}
px={[2, 4]}
onClick={() => setShowHistoryQuote(item._id)}
>
{item.quoteLen}
</Button>
)}
</Flex>
</Card> </Card>
</Box> </Box>
) : ( ) : (
@@ -876,6 +895,19 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
onClose={() => setShowHistoryQuote(undefined)} onClose={() => setShowHistoryQuote(undefined)}
/> />
)} )}
{/* system prompt show modal */}
{
<Modal isOpen={!!showSystemPrompt} onClose={() => setShowSystemPrompt('')}>
<ModalOverlay />
<ModalContent maxW={'min(90vw, 600px)'} maxH={'80vh'} minH={'50vh'} overflow={'overlay'}>
<ModalCloseButton />
<ModalHeader></ModalHeader>
<ModalBody pt={0} whiteSpace={'pre-wrap'} textAlign={'justify'} fontSize={'xs'}>
{showSystemPrompt}
</ModalBody>
</ModalContent>
</Modal>
}
{/* context menu */} {/* context menu */}
{messageContextMenuData && ( {messageContextMenuData && (
<Box <Box

View File

@@ -48,13 +48,32 @@ const ChatSchema = new Schema({
required: true required: true
}, },
quote: { quote: {
type: [{ id: String, q: String, a: String, isEdit: Boolean }], type: [
{
id: {
type: String,
required: true
},
q: {
type: String,
default: ''
},
a: {
type: String,
default: ''
},
isEdit: {
type: String,
default: false
}
}
],
default: [] default: []
},
systemPrompt: {
type: String,
default: ''
} }
// systemPrompt: {
// type: String,
// default: ''
// }
} }
], ],
default: [] default: []

1
src/types/chat.d.ts vendored
View File

@@ -9,6 +9,7 @@ export type ChatItemSimpleType = {
value: string; value: string;
quoteLen?: number; quoteLen?: number;
quote?: QuoteItemType[]; quote?: QuoteItemType[];
systemPrompt?: string;
}; };
export type ChatItemType = { export type ChatItemType = {
_id: string; _id: string;