feat: use last quote

archer
2023-05-30 21:18:08 +08:00
parent 59ddf09b94
commit 0cde9a10a8
7 changed files with 86 additions and 81 deletions

View File

@@ -5,7 +5,7 @@ import { RequestPaging } from '../types/index';
 import type { ShareChatSchema } from '@/types/mongoSchema';
 import type { ShareChatEditType } from '@/types/model';
 import { Obj2Query } from '@/utils/tools';
-import { QuoteItemType } from '@/pages/api/openapi/kb/appKbSearch';
+import type { QuoteItemType } from '@/pages/api/openapi/kb/appKbSearch';
 import type { Props as UpdateHistoryProps } from '@/pages/api/chat/history/updateChatHistory';
 /**
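The only change in this file is making the QuoteItemType import type-only. A minimal sketch of the distinction, assuming nothing beyond the module path shown above:

// A type-only import is erased by the TypeScript compiler, so importing a type
// from an API route file does not pull that route's runtime code into the client.
import type { QuoteItemType } from '@/pages/api/openapi/kb/appKbSearch';

// The previous form was a value import, which creates a real runtime dependency:
// import { QuoteItemType } from '@/pages/api/openapi/kb/appKbSearch';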

View File

@@ -50,6 +50,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
   // read the conversation content
+  const prompts = [...content, prompt[0]];
   const {
     code = 200,
     systemPrompts = [],
@@ -61,7 +62,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     const { code, searchPrompts, rawSearch, guidePrompt } = await appKbSearch({
       model,
       userId,
-      prompts,
+      fixedQuote: content[content.length - 1]?.quote || [],
+      prompt: prompt[0],
       similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
     });
@@ -114,7 +116,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     return res.end(response);
   }
-  prompts.splice(prompts.length - 3, 0, ...systemPrompts);
+  prompts.unshift(...systemPrompts);
   // content check
   await sensitiveCheck({
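The call-site changes in this commit all follow one pattern: callers stop passing the whole prompts array and instead pass the single newest message plus the quotes attached to the previous answer. A sketch of the new contract, assembled from the signature and type block later in this diff (the composite type name AppKbSearchArgs is invented here for illustration; ModelSchema, QuoteItemType and ChatItemSimpleType are the types the diff itself uses):

type AppKbSearchArgs = {
  model: ModelSchema;
  userId: string;
  fixedQuote: QuoteItemType[]; // quotes carried over from the last answer
  prompt: ChatItemSimpleType;  // only the newest message is embedded and searched
  similarity: number;
};

// Chat routes with history derive fixedQuote from the last history item;
// one-shot routes pass an empty array:
const fixedQuote = content[content.length - 1]?.quote || [];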

View File

@@ -47,7 +47,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     const { code, searchPrompts } = await appKbSearch({
       model,
       userId,
-      prompts,
+      fixedQuote: [],
+      prompt: prompts[prompts.length - 1],
       similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
     });
@@ -74,7 +75,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     return res.send(systemPrompts[0]?.value);
   }
-  prompts.splice(prompts.length - 3, 0, ...systemPrompts);
+  prompts.unshift(...systemPrompts);
   // content check
   await sensitiveCheck({

View File

@@ -75,10 +75,11 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
   // knowledge base search is enabled
   if (model.chat.relatedKbs.length > 0) {
     const { code, searchPrompts } = await appKbSearch({
-      prompts,
-      similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
       model,
-      userId
+      userId,
+      fixedQuote: [],
+      prompt: prompts[prompts.length - 1],
+      similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
     });
   // search result is empty
@@ -101,7 +102,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
     ];
   }
-  prompts.splice(prompts.length - 3, 0, ...systemPrompts);
+  prompts.unshift(...systemPrompts);
   // content check
   await sensitiveCheck({

View File

@@ -49,10 +49,11 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
   });
   const result = await appKbSearch({
+    model,
     userId,
-    prompts,
-    similarity,
-    model
+    fixedQuote: [],
+    prompt: prompts[prompts.length - 1],
+    similarity
   });
   jsonRes<Response>(res, {
@@ -70,67 +71,53 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
 export async function appKbSearch({
   model,
   userId,
-  prompts,
+  fixedQuote,
+  prompt,
   similarity
 }: {
-  userId: string;
-  prompts: ChatItemSimpleType[];
-  similarity: number;
   model: ModelSchema;
+  userId: string;
+  fixedQuote: QuoteItemType[];
+  prompt: ChatItemSimpleType;
+  similarity: number;
 }): Promise<Response> {
   const modelConstantsData = ChatModelMap[model.chat.chatModel];
-  // search two times.
-  const userPrompts = prompts.filter((item) => item.obj === 'Human');
-  const input: string[] = [
-    userPrompts[userPrompts.length - 1].value,
-    userPrompts[userPrompts.length - 2]?.value
-  ].filter((item) => item);
   // get vector
-  const promptVectors = await openaiEmbedding({
+  const promptVector = await openaiEmbedding({
     userId,
-    input,
+    input: [prompt.value],
     type: 'chat'
   });
   // search kb
-  const searchRes = await Promise.all(
-    promptVectors.map((promptVector) =>
-      PgClient.select<QuoteItemType>('modelData', {
-        fields: ['id', 'q', 'a'],
-        where: [
-          `kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`,
-          'AND',
-          `vector <=> '[${promptVector}]' < ${similarity}`
-        ],
-        order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
-        limit: promptVectors.length === 1 ? 15 : 10
-      }).then((res) => res.rows)
-    )
-  );
+  const { rows: searchRes } = await PgClient.select<QuoteItemType>('modelData', {
+    fields: ['id', 'q', 'a'],
+    where: [
+      `kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`,
+      'AND',
+      `vector <=> '[${promptVector[0]}]' < ${similarity}`
+    ],
+    order: [{ field: 'vector', mode: `<=> '[${promptVector[0]}]'` }],
+    limit: 8
+  });
   // filter out duplicate search results
   const idSet = new Set<string>();
-  const filterSearch = searchRes.map((search) =>
-    search.filter((item) => {
-      if (idSet.has(item.id)) {
-        return false;
-      }
-      idSet.add(item.id);
-      return true;
-    })
-  );
+  const filterSearch = [
+    ...searchRes.slice(0, 3),
+    ...fixedQuote.slice(0, 2),
+    ...searchRes.slice(3),
+    ...fixedQuote.slice(2, 5)
+  ].filter((item) => {
+    if (idSet.has(item.id)) {
+      return false;
+    }
+    idSet.add(item.id);
+    return true;
+  });
-  // slice search result by rate.
-  const sliceRateMap: Record<number, number[]> = {
-    1: [1],
-    2: [0.7, 0.3]
-  };
-  const sliceRate = sliceRateMap[searchRes.length] || sliceRateMap[0];
   // count the tokens of the fixed prompts
   const guidePrompt = model.chat.systemPrompt // user system prompt
     ? {
       obj: ChatRoleEnum.System,
@@ -154,24 +141,21 @@ export async function appKbSearch({
   const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({
     messages: [guidePrompt]
   });
-  const maxTokens = modelConstantsData.systemMaxToken - fixedSystemTokens;
-  const sliceResult = sliceRate.map((rate, i) =>
-    modelToolMap[model.chat.chatModel]
-      .tokenSlice({
-        maxToken: Math.round(maxTokens * rate),
-        messages: filterSearch[i].map((item) => ({
-          obj: ChatRoleEnum.System,
-          value: `${item.q}\n${item.a}`
-        }))
-      })
-      .map((item) => item.value)
-  );
+  const sliceResult = modelToolMap[model.chat.chatModel]
+    .tokenSlice({
+      maxToken: modelConstantsData.systemMaxToken - fixedSystemTokens,
+      messages: filterSearch.map((item) => ({
+        obj: ChatRoleEnum.System,
+        value: `${item.q}\n${item.a}`
+      }))
+    })
+    .map((item) => item.value);
   // slice filterSearch
-  const sliceSearch = filterSearch.map((item, i) => item.slice(0, sliceResult[i].length)).flat();
+  const rawSearch = filterSearch.slice(0, sliceResult.length);
   // system prompt
-  const systemPrompt = sliceResult.flat().join('\n').trim();
+  const systemPrompt = sliceResult.join('\n').trim();
   /* high similarity + do not reply */
   if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.hightSimilarity) {
@@ -206,7 +190,7 @@ export async function appKbSearch({
   return {
     code: 200,
-    rawSearch: sliceSearch,
+    rawSearch,
     guidePrompt: guidePrompt.value || '',
     searchPrompts: [
       {
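The rewritten appKbSearch embeds a single prompt, runs one pgvector query, interleaves the fresh hits with the fixed quotes from the last answer, de-duplicates by id, and token-slices the result. A self-contained sketch of the interleave-and-dedup step, with QuoteItemType reduced to the fields the diff selects (id, q, a):

type Quote = { id: string; q: string; a: string };

// Merge new search hits with the previous answer's quotes: the top three hits
// first, then up to two fixed quotes, then the remaining hits, then up to three
// more fixed quotes, dropping any id already seen earlier in the merged list.
function mergeQuotes(searchRes: Quote[], fixedQuote: Quote[]): Quote[] {
  const idSet = new Set<string>();
  return [
    ...searchRes.slice(0, 3),
    ...fixedQuote.slice(0, 2),
    ...searchRes.slice(3),
    ...fixedQuote.slice(2, 5)
  ].filter((item) => {
    if (idSet.has(item.id)) return false;
    idSet.add(item.id);
    return true;
  });
}

// A fixed quote that reappears in the new search results is kept only once, at
// the position of its first occurrence, so stale duplicates never crowd the context.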

View File

@@ -280,7 +280,8 @@ export const authChat = async ({
     {
       $project: {
         obj: '$content.obj',
-        value: '$content.value'
+        value: '$content.value',
+        quote: '$content.quote'
       }
     }
   ]);
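This projection change is what feeds the new fixedQuote parameter: history items loaded by authChat now carry the quote array saved with each answer. The resulting shape, sketched under the assumption that quote is absent on plain user messages:

// Sketch of a history item after the updated $project stage.
type HistoryItem = {
  obj: string;             // '$content.obj'   — role of the message
  value: string;           // '$content.value' — message text
  quote?: QuoteItemType[]; // '$content.quote' — citations saved with an answer, if any
};

// which is what lets the chat route above reuse the last answer's citations:
const fixedQuote = content[content.length - 1]?.quote || [];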

View File

@@ -89,39 +89,55 @@ export const ChatContextFilter = ({
   prompts: ChatItemSimpleType[];
   maxTokens: number;
 }) => {
+  const systemPrompts: ChatItemSimpleType[] = [];
+  const chatPrompts: ChatItemSimpleType[] = [];
   let rawTextLen = 0;
-  const formatPrompts = prompts.map<ChatItemSimpleType>((item) => {
+  prompts.forEach((item) => {
     const val = simplifyStr(item.value);
     rawTextLen += val.length;
-    return {
+    const data = {
       obj: item.obj,
       value: val
     };
+    if (item.obj === ChatRoleEnum.System) {
+      systemPrompts.push(data);
+    } else {
+      chatPrompts.push(data);
+    }
   });
   // if the text is short enough, token truncation is unnecessary
-  if (formatPrompts.length <= 2 || rawTextLen < maxTokens * 0.5) {
-    return formatPrompts;
+  if (rawTextLen < maxTokens * 0.5) {
+    return [...systemPrompts, ...chatPrompts];
   }
+  // subtract the system prompts' tokens from the budget
+  maxTokens -= modelToolMap[model].countTokens({
+    messages: systemPrompts
+  });
   // truncate the content by token count
   const chats: ChatItemSimpleType[] = [];
   // walk the conversation from newest to oldest
-  for (let i = formatPrompts.length - 1; i >= 0; i--) {
-    chats.unshift(formatPrompts[i]);
+  for (let i = chatPrompts.length - 1; i >= 0; i--) {
+    chats.unshift(chatPrompts[i]);
     const tokens = modelToolMap[model].countTokens({
       messages: chats
     });
-    /* total tokens exceed the limit; system prompts must be kept */
-    if (tokens >= maxTokens && formatPrompts[i].obj !== ChatRoleEnum.System) {
-      return chats.slice(1);
+    if (tokens >= maxTokens) {
+      chats.shift();
+      break;
     }
   }
-  return chats;
+  return [...systemPrompts, ...chats];
 };
 /* stream response */
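After this change, ChatContextFilter pins every system message and truncates only the conversational turns, newest first. A self-contained sketch of the strategy, using a naive four-characters-per-token estimate in place of the model-specific modelToolMap[model].countTokens (an assumption made here so the example runs on its own):

type Msg = { obj: 'System' | 'Human' | 'AI'; value: string };

// Crude stand-in for the real, model-specific token counter.
const countTokens = (messages: Msg[]) =>
  messages.reduce((sum, m) => sum + Math.ceil(m.value.length / 4), 0);

function contextFilter(prompts: Msg[], maxTokens: number): Msg[] {
  const systemPrompts = prompts.filter((p) => p.obj === 'System');
  const chatPrompts = prompts.filter((p) => p.obj !== 'System');

  // System messages are always kept, so reserve their budget up front.
  maxTokens -= countTokens(systemPrompts);

  // Keep the most recent turns that still fit within the remaining budget.
  const chats: Msg[] = [];
  for (let i = chatPrompts.length - 1; i >= 0; i--) {
    chats.unshift(chatPrompts[i]);
    if (countTokens(chats) >= maxTokens) {
      chats.shift(); // drop the turn that broke the budget and stop
      break;
    }
  }
  return [...systemPrompts, ...chats];
}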