From 944e876aaa7333b57d9f585fbdd0bf8feaebc31a Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Tue, 23 May 2023 15:09:57 +0800 Subject: [PATCH] feat: chat quote --- src/api/chat.ts | 14 +- src/api/fetch.ts | 89 ++++--- src/constants/chat.ts | 3 +- src/pages/_app.tsx | 2 +- src/pages/api/chat/chat.ts | 179 ++++++++------ src/pages/api/chat/init.ts | 2 +- src/pages/api/chat/saveChat.ts | 104 ++++---- src/pages/api/chat/shareChat/chat.ts | 120 +++++----- src/pages/api/openapi/chat/chat.ts | 79 +++--- src/pages/api/openapi/chat/lastChatResult.ts | 38 +++ src/pages/api/openapi/kb/appKbSearch.ts | 224 ++++++++++++++++++ .../api/openapi/plugin/openaiEmbedding.ts | 77 ++++++ .../api/openapi/text/gptMessagesSlice.ts | 119 ++++++++++ src/pages/chat/index.tsx | 141 +++++------ src/pages/chat/share.tsx | 61 ++--- .../detail/components/ModelEditForm.tsx | 4 +- src/service/events/generateVector.ts | 16 +- src/service/models/chat.ts | 10 +- src/service/plugins/searchKb.ts | 175 -------------- src/service/utils/auth.ts | 17 +- src/service/utils/chat/claude.ts | 15 +- src/service/utils/chat/index.ts | 13 +- src/service/utils/chat/openai.ts | 58 +---- src/types/chat.d.ts | 3 +- src/utils/chat/claude.ts | 3 - src/utils/file.ts | 2 +- src/utils/{chat => plugin}/index.ts | 18 +- src/utils/{chat => plugin}/openai.ts | 0 src/utils/tools.ts | 7 + 29 files changed, 933 insertions(+), 660 deletions(-) create mode 100644 src/pages/api/openapi/chat/lastChatResult.ts create mode 100644 src/pages/api/openapi/kb/appKbSearch.ts create mode 100644 src/pages/api/openapi/plugin/openaiEmbedding.ts create mode 100644 src/pages/api/openapi/text/gptMessagesSlice.ts delete mode 100644 src/service/plugins/searchKb.ts delete mode 100644 src/utils/chat/claude.ts rename src/utils/{chat => plugin}/index.ts (64%) rename src/utils/{chat => plugin}/openai.ts (100%) diff --git a/src/api/chat.ts b/src/api/chat.ts index 93c755663..2e527492e 100644 --- a/src/api/chat.ts +++ b/src/api/chat.ts @@ -1,10 +1,11 @@ import { GET, POST, DELETE } from './request'; -import type { ChatItemType, HistoryItemType } from '@/types/chat'; +import type { HistoryItemType } from '@/types/chat'; import type { InitChatResponse, InitShareChatResponse } from './response/chat'; import { RequestPaging } from '../types/index'; import type { ShareChatSchema } from '@/types/mongoSchema'; import type { ShareChatEditType } from '@/types/model'; import { Obj2Query } from '@/utils/tools'; +import { Response as LastChatResultResponseType } from '@/pages/api/openapi/chat/lastChatResult'; /** * 获取初始化聊天内容 @@ -24,15 +25,10 @@ export const getChatHistory = (data: RequestPaging) => export const delChatHistoryById = (id: string) => GET(`/chat/removeHistory?id=${id}`); /** - * 存储一轮对话 + * get latest chat result by chatId */ -export const postSaveChat = (data: { - modelId: string; - newChatId: '' | string; - chatId: '' | string; - prompts: [ChatItemType, ChatItemType]; -}) => POST('/chat/saveChat', data); - +export const getChatResult = (chatId: string) => + GET('/openapi/chat/lastChatResult', { chatId }); /** * 删除一句对话 */ diff --git a/src/api/fetch.ts b/src/api/fetch.ts index 6042fda35..c47e0bad1 100644 --- a/src/api/fetch.ts +++ b/src/api/fetch.ts @@ -1,4 +1,4 @@ -import { SYSTEM_PROMPT_HEADER, NEW_CHATID_HEADER } from '@/constants/chat'; +import { NEW_CHATID_HEADER } from '@/constants/chat'; interface StreamFetchProps { url: string; @@ -7,55 +7,52 @@ interface StreamFetchProps { abortSignal: AbortController; } export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchProps) => - new Promise<{ responseText: string; systemPrompt: string; newChatId: string }>( - async (resolve, reject) => { - try { - const res = await fetch(url, { - method: 'POST', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify(data), - signal: abortSignal.signal - }); - const reader = res.body?.getReader(); - if (!reader) return; + new Promise<{ responseText: string; newChatId: string }>(async (resolve, reject) => { + try { + const res = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(data), + signal: abortSignal.signal + }); + const reader = res.body?.getReader(); + if (!reader) return; - const decoder = new TextDecoder(); + const decoder = new TextDecoder(); - const systemPrompt = decodeURIComponent(res.headers.get(SYSTEM_PROMPT_HEADER) || '').trim(); - const newChatId = decodeURIComponent(res.headers.get(NEW_CHATID_HEADER) || ''); + const newChatId = decodeURIComponent(res.headers.get(NEW_CHATID_HEADER) || ''); - let responseText = ''; + let responseText = ''; - const read = async () => { - try { - const { done, value } = await reader?.read(); - if (done) { - if (res.status === 200) { - resolve({ responseText, systemPrompt, newChatId }); - } else { - const parseError = JSON.parse(responseText); - reject(parseError?.message || '请求异常'); - } - - return; + const read = async () => { + try { + const { done, value } = await reader?.read(); + if (done) { + if (res.status === 200) { + resolve({ responseText, newChatId }); + } else { + const parseError = JSON.parse(responseText); + reject(parseError?.message || '请求异常'); } - const text = decoder.decode(value); - responseText += text; - onMessage(text); - read(); - } catch (err: any) { - if (err?.message === 'The user aborted a request.') { - return resolve({ responseText, systemPrompt, newChatId }); - } - reject(typeof err === 'string' ? err : err?.message || '请求异常'); + + return; } - }; - read(); - } catch (err: any) { - console.log(err, '===='); - reject(typeof err === 'string' ? err : err?.message || '请求异常'); - } + const text = decoder.decode(value); + responseText += text; + onMessage(text); + read(); + } catch (err: any) { + if (err?.message === 'The user aborted a request.') { + return resolve({ responseText, newChatId }); + } + reject(typeof err === 'string' ? err : err?.message || '请求异常'); + } + }; + read(); + } catch (err: any) { + console.log(err, '===='); + reject(typeof err === 'string' ? err : err?.message || '请求异常'); } - ); + }); diff --git a/src/constants/chat.ts b/src/constants/chat.ts index 7f580ff34..2c7cc579f 100644 --- a/src/constants/chat.ts +++ b/src/constants/chat.ts @@ -1,5 +1,4 @@ -export const SYSTEM_PROMPT_HEADER = 'System-Prompt-Header'; -export const NEW_CHATID_HEADER = 'Chat-Id-Header'; +export const NEW_CHATID_HEADER = 'response-new-chat-id'; export enum ChatRoleEnum { System = 'System', diff --git a/src/pages/_app.tsx b/src/pages/_app.tsx index 50f63e2ca..fef78b9cc 100644 --- a/src/pages/_app.tsx +++ b/src/pages/_app.tsx @@ -55,7 +55,7 @@ export default function App({ Component, pageProps }: AppProps) { - + diff --git a/src/pages/api/chat/chat.ts b/src/pages/api/chat/chat.ts index dcc459359..9301db82e 100644 --- a/src/pages/api/chat/chat.ts +++ b/src/pages/api/chat/chat.ts @@ -2,19 +2,24 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase } from '@/service/mongo'; import { authChat } from '@/service/utils/auth'; import { modelServiceToolMap } from '@/service/utils/chat'; -import { ChatItemSimpleType } from '@/types/chat'; +import { ChatItemType } from '@/types/chat'; import { jsonRes } from '@/service/response'; import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; import { resStreamResponse } from '@/service/utils/chat'; -import { searchKb } from '@/service/plugins/searchKb'; +import { appKbSearch } from '../openapi/kb/appKbSearch'; import { ChatRoleEnum } from '@/constants/chat'; import { BillTypeEnum } from '@/constants/user'; import { sensitiveCheck } from '@/service/api/text'; +import { NEW_CHATID_HEADER } from '@/constants/chat'; +import { saveChat } from './saveChat'; +import { Types } from 'mongoose'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { - let step = 0; // step=1时,表示开始了流响应 + res.on('close', () => { + res.end(); + }); res.on('error', () => { console.log('error: ', 'request error'); res.end(); @@ -22,9 +27,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) try { const { chatId, prompt, modelId } = req.body as { - prompt: ChatItemSimpleType; + prompt: [ChatItemType, ChatItemType]; modelId: string; - chatId: '' | string; + chatId?: string; }; if (!modelId || !prompt) { @@ -44,42 +49,69 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const modelConstantsData = ChatModelMap[model.chat.chatModel]; // 读取对话内容 - const prompts = [...content, prompt]; - let systemPrompts: { - obj: ChatRoleEnum; - value: string; - }[] = []; + const prompts = [...content, prompt[0]]; + const { + code = 200, + systemPrompts = [], + quote = [] + } = await (async () => { + // 使用了知识库搜索 + if (model.chat.relatedKbs.length > 0) { + const { code, searchPrompts, rawSearch } = await appKbSearch({ + model, + userId, + prompts, + similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity + }); - // 使用了知识库搜索 - if (model.chat.relatedKbs.length > 0) { - const { code, searchPrompts } = await searchKb({ - userOpenAiKey, - prompts, - similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity, - model, + return { + code, + quote: rawSearch, + systemPrompts: searchPrompts + }; + } + if (model.chat.systemPrompt) { + return { + systemPrompts: [ + { + obj: ChatRoleEnum.System, + value: model.chat.systemPrompt + } + ] + }; + } + return {}; + })(); + + // get conversationId. create a newId if it is null + const conversationId = chatId || String(new Types.ObjectId()); + !chatId && res?.setHeader(NEW_CHATID_HEADER, conversationId); + + // search result is empty + if (code === 201) { + const response = systemPrompts[0]?.value; + await saveChat({ + chatId, + newChatId: conversationId, + modelId, + prompts: [ + prompt[0], + { + ...prompt[1], + quote: [], + value: response + } + ], userId }); - - // search result is empty - if (code === 201) { - return res.send(searchPrompts[0]?.value); - } - - systemPrompts = searchPrompts; - } else if (model.chat.systemPrompt) { - systemPrompts = [ - { - obj: ChatRoleEnum.System, - value: model.chat.systemPrompt - } - ]; + return res.end(response); } prompts.splice(prompts.length - 3, 0, ...systemPrompts); // content check await sensitiveCheck({ - input: [...systemPrompts, prompt].map((item) => item.value).join('') + input: [...systemPrompts, prompt[0]].map((item) => item.value).join('') }); // 计算温度 @@ -87,54 +119,65 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) 2 ); - // 发出请求 + // 发出 chat 请求 const { streamResponse } = await modelServiceToolMap[model.chat.chatModel].chatCompletion({ apiKey: userOpenAiKey || systemAuthKey, temperature: +temperature, messages: prompts, stream: true, res, - chatId + chatId: conversationId }); console.log('api response time:', `${(Date.now() - startTime) / 1000}s`); - step = 1; + if (res.closed) return res.end(); - const { totalTokens, finishMessages } = await resStreamResponse({ - model: model.chat.chatModel, - res, - chatResponse: streamResponse, - prompts, - systemPrompt: showModelDetail - ? prompts - .filter((item) => item.obj === ChatRoleEnum.System) - .map((item) => item.value) - .join('\n') - : '' - }); - - // 只有使用平台的 key 才计费 - pushChatBill({ - isPay: !userOpenAiKey, - chatModel: model.chat.chatModel, - userId, - chatId, - textLen: finishMessages.map((item) => item.value).join('').length, - tokens: totalTokens, - type: BillTypeEnum.chat - }); - } catch (err: any) { - if (step === 1) { - // 直接结束流 - res.end(); - console.log('error,结束'); - } else { - res.status(500); - jsonRes(res, { - code: 500, - error: err + try { + const { totalTokens, finishMessages, responseContent } = await resStreamResponse({ + model: model.chat.chatModel, + res, + chatResponse: streamResponse, + prompts }); + + // save chat + await saveChat({ + chatId, + newChatId: conversationId, + modelId, + prompts: [ + prompt[0], + { + ...prompt[1], + quote: showModelDetail ? quote : [], + value: responseContent + } + ], + userId + }); + + res.end(); + + // 只有使用平台的 key 才计费 + pushChatBill({ + isPay: !userOpenAiKey, + chatModel: model.chat.chatModel, + userId, + chatId: conversationId, + textLen: finishMessages.map((item) => item.value).join('').length, + tokens: totalTokens, + type: BillTypeEnum.chat + }); + } catch (error) { + res.end(); + console.log('error,结束', error); } + } catch (err: any) { + res.status(500); + jsonRes(res, { + code: 500, + error: err + }); } } diff --git a/src/pages/api/chat/init.ts b/src/pages/api/chat/init.ts index 7d68b454e..f24174ca5 100644 --- a/src/pages/api/chat/init.ts +++ b/src/pages/api/chat/init.ts @@ -73,7 +73,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) _id: '$content._id', obj: '$content.obj', value: '$content.value', - systemPrompt: '$content.systemPrompt' + quote: '$content.quote' } } ]); diff --git a/src/pages/api/chat/saveChat.ts b/src/pages/api/chat/saveChat.ts index 4205a24b3..d2e1e934f 100644 --- a/src/pages/api/chat/saveChat.ts +++ b/src/pages/api/chat/saveChat.ts @@ -6,15 +6,17 @@ import { authModel } from '@/service/utils/auth'; import { authUser } from '@/service/utils/auth'; import mongoose from 'mongoose'; +type Props = { + newChatId?: string; + chatId?: string; + modelId: string; + prompts: [ChatItemType, ChatItemType]; +}; + /* 聊天内容存存储 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { chatId, modelId, prompts, newChatId } = req.body as { - newChatId: '' | string; - chatId: '' | string; - modelId: string; - prompts: [ChatItemType, ChatItemType]; - }; + const { chatId, modelId, prompts, newChatId } = req.body as Props; if (!prompts) { throw new Error('缺少参数'); @@ -22,44 +24,17 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const { userId } = await authUser({ req, authToken: true }); - await connectToDatabase(); + const nId = await saveChat({ + chatId, + modelId, + prompts, + newChatId, + userId + }); - const content = prompts.map((item) => ({ - _id: new mongoose.Types.ObjectId(item._id), - obj: item.obj, - value: item.value, - systemPrompt: item.systemPrompt - })); - - await authModel({ modelId, userId, authOwner: false }); - - // 没有 chatId, 创建一个对话 - if (!chatId) { - const { _id } = await Chat.create({ - _id: newChatId ? new mongoose.Types.ObjectId(newChatId) : undefined, - userId, - modelId, - content, - title: content[0].value.slice(0, 20), - latestChat: content[1].value - }); - return jsonRes(res, { - data: _id - }); - } else { - // 已经有记录,追加入库 - await Chat.findByIdAndUpdate(chatId, { - $push: { - content: { - $each: content - } - }, - title: content[0].value.slice(0, 20), - latestChat: content[1].value, - updateTime: new Date() - }); - } - jsonRes(res); + jsonRes(res, { + data: nId + }); } catch (err) { jsonRes(res, { code: 500, @@ -67,3 +42,46 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) }); } } + +export async function saveChat({ + chatId, + newChatId, + modelId, + prompts, + userId +}: Props & { userId: string }) { + await connectToDatabase(); + await authModel({ modelId, userId, authOwner: false }); + + const content = prompts.map((item) => ({ + _id: item._id ? new mongoose.Types.ObjectId(item._id) : undefined, + obj: item.obj, + value: item.value, + quote: item.quote + })); + + // 没有 chatId, 创建一个对话 + if (!chatId) { + const { _id } = await Chat.create({ + _id: newChatId ? new mongoose.Types.ObjectId(newChatId) : undefined, + userId, + modelId, + content, + title: content[0].value.slice(0, 20), + latestChat: content[1].value + }); + return _id; + } else { + // 已经有记录,追加入库 + await Chat.findByIdAndUpdate(chatId, { + $push: { + content: { + $each: content + } + }, + title: content[0].value.slice(0, 20), + latestChat: content[1].value, + updateTime: new Date() + }); + } +} diff --git a/src/pages/api/chat/shareChat/chat.ts b/src/pages/api/chat/shareChat/chat.ts index 3ebbbcfde..b32bd8ccc 100644 --- a/src/pages/api/chat/shareChat/chat.ts +++ b/src/pages/api/chat/shareChat/chat.ts @@ -7,14 +7,13 @@ import { jsonRes } from '@/service/response'; import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model'; import { pushChatBill, updateShareChatBill } from '@/service/events/pushBill'; import { resStreamResponse } from '@/service/utils/chat'; -import { searchKb } from '@/service/plugins/searchKb'; import { ChatRoleEnum } from '@/constants/chat'; import { BillTypeEnum } from '@/constants/user'; import { sensitiveCheck } from '@/service/api/text'; +import { appKbSearch } from '../../openapi/kb/appKbSearch'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { - let step = 0; // step=1 时,表示开始了流响应 res.on('error', () => { console.log('error: ', 'request error'); res.end(); @@ -42,34 +41,37 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const modelConstantsData = ChatModelMap[model.chat.chatModel]; - let systemPrompts: { - obj: ChatRoleEnum; - value: string; - }[] = []; + const { code = 200, systemPrompts = [] } = await (async () => { + // 使用了知识库搜索 + if (model.chat.relatedKbs.length > 0) { + const { code, searchPrompts } = await appKbSearch({ + model, + userId, + prompts, + similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity + }); - // 使用了知识库搜索 - if (model.chat.relatedKbs.length > 0) { - const { code, searchPrompts } = await searchKb({ - userOpenAiKey, - prompts, - similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity, - model, - userId - }); - - // search result is empty - if (code === 201) { - return res.send(searchPrompts[0]?.value); + return { + code, + systemPrompts: searchPrompts + }; } + if (model.chat.systemPrompt) { + return { + systemPrompts: [ + { + obj: ChatRoleEnum.System, + value: model.chat.systemPrompt + } + ] + }; + } + return {}; + })(); - systemPrompts = searchPrompts; - } else if (model.chat.systemPrompt) { - systemPrompts = [ - { - obj: ChatRoleEnum.System, - value: model.chat.systemPrompt - } - ]; + // search result is empty + if (code === 201) { + return res.send(systemPrompts[0]?.value); } prompts.splice(prompts.length - 3, 0, ...systemPrompts); @@ -96,40 +98,40 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) console.log('api response time:', `${(Date.now() - startTime) / 1000}s`); - step = 1; + if (res.closed) return res.end(); - const { totalTokens, finishMessages } = await resStreamResponse({ - model: model.chat.chatModel, - res, - chatResponse: streamResponse, - prompts, - systemPrompt: '' - }); - - /* bill */ - pushChatBill({ - isPay: !userOpenAiKey, - chatModel: model.chat.chatModel, - userId, - textLen: finishMessages.map((item) => item.value).join('').length, - tokens: totalTokens, - type: BillTypeEnum.chat - }); - updateShareChatBill({ - shareId, - tokens: totalTokens - }); - } catch (err: any) { - if (step === 1) { - // 直接结束流 - res.end(); - console.log('error,结束'); - } else { - res.status(500); - jsonRes(res, { - code: 500, - error: err + try { + const { totalTokens, finishMessages } = await resStreamResponse({ + model: model.chat.chatModel, + res, + chatResponse: streamResponse, + prompts }); + + res.end(); + + /* bill */ + pushChatBill({ + isPay: !userOpenAiKey, + chatModel: model.chat.chatModel, + userId, + textLen: finishMessages.map((item) => item.value).join('').length, + tokens: totalTokens, + type: BillTypeEnum.chat + }); + updateShareChatBill({ + shareId, + tokens: totalTokens + }); + } catch (error) { + res.end(); + console.log('error,结束', error); } + } catch (err: any) { + res.status(500); + jsonRes(res, { + code: 500, + error: err + }); } } diff --git a/src/pages/api/openapi/chat/chat.ts b/src/pages/api/openapi/chat/chat.ts index e15d861cd..a03853cd6 100644 --- a/src/pages/api/openapi/chat/chat.ts +++ b/src/pages/api/openapi/chat/chat.ts @@ -6,15 +6,19 @@ import { ChatItemSimpleType } from '@/types/chat'; import { jsonRes } from '@/service/response'; import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; -import { searchKb } from '@/service/plugins/searchKb'; import { ChatRoleEnum } from '@/constants/chat'; import { withNextCors } from '@/service/utils/tools'; import { BillTypeEnum } from '@/constants/user'; import { sensitiveCheck } from '@/service/api/text'; +import { NEW_CHATID_HEADER } from '@/constants/chat'; +import { Types } from 'mongoose'; +import { appKbSearch } from '../kb/appKbSearch'; /* 发送提示词 */ export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { - let step = 0; // step=1时,表示开始了流响应 + res.on('close', () => { + res.end(); + }); res.on('error', () => { console.log('error: ', 'request error'); res.end(); @@ -70,7 +74,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex // 使用了知识库搜索 if (model.chat.relatedKbs.length > 0) { - const { code, searchPrompts } = await searchKb({ + const { code, searchPrompts } = await appKbSearch({ prompts, similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity, model, @@ -109,6 +113,10 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex 2 ); + // get conversationId. create a newId if it is null + const conversationId = chatId || String(new Types.ObjectId()); + !chatId && res?.setHeader(NEW_CHATID_HEADER, conversationId); + // 发出请求 const { streamResponse, responseMessages, responseText, totalTokens } = await modelServiceToolMap[model.chat.chatModel].chatCompletion({ @@ -117,30 +125,41 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex messages: prompts, stream: isStream, res, - chatId + chatId: conversationId }); console.log('api response time:', `${(Date.now() - startTime) / 1000}s`); - let textLen = 0; - let tokens = totalTokens; + if (res.closed) return res.end(); - if (isStream) { - step = 1; - const { finishMessages, totalTokens } = await resStreamResponse({ - model: model.chat.chatModel, - res, - chatResponse: streamResponse, - prompts - }); - textLen = finishMessages.map((item) => item.value).join('').length; - tokens = totalTokens; - } else { - textLen = responseMessages.map((item) => item.value).join('').length; - jsonRes(res, { - data: responseText - }); - } + const { textLen = 0, tokens = totalTokens } = await (async () => { + if (isStream) { + try { + const { finishMessages, totalTokens } = await resStreamResponse({ + model: model.chat.chatModel, + res, + chatResponse: streamResponse, + prompts + }); + res.end(); + return { + textLen: finishMessages.map((item) => item.value).join('').length, + tokens: totalTokens + }; + } catch (error) { + res.end(); + console.log('error,结束', error); + } + } else { + jsonRes(res, { + data: responseText + }); + return { + textLen: responseMessages.map((item) => item.value).join('').length + }; + } + return {}; + })(); pushChatBill({ isPay: true, @@ -151,16 +170,10 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex type: BillTypeEnum.openapiChat }); } catch (err: any) { - if (step === 1) { - // 直接结束流 - res.end(); - console.log('error,结束'); - } else { - res.status(500); - jsonRes(res, { - code: 500, - error: err - }); - } + res.status(500); + jsonRes(res, { + code: 500, + error: err + }); } }); diff --git a/src/pages/api/openapi/chat/lastChatResult.ts b/src/pages/api/openapi/chat/lastChatResult.ts new file mode 100644 index 000000000..15147f0ad --- /dev/null +++ b/src/pages/api/openapi/chat/lastChatResult.ts @@ -0,0 +1,38 @@ +import type { NextApiRequest, NextApiResponse } from 'next'; +import { jsonRes } from '@/service/response'; +import { Chat } from '@/service/mongo'; +import { authUser } from '@/service/utils/auth'; +import { QuoteItemType } from '../kb/appKbSearch'; + +type Props = { + chatId: string; +}; +export type Response = { + quote: QuoteItemType[]; +}; + +/* 聊天内容存存储 */ +export default async function handler(req: NextApiRequest, res: NextApiResponse) { + try { + const { chatId } = req.query as Props; + + if (!chatId) { + throw new Error('缺少参数'); + } + + const { userId } = await authUser({ req }); + + const chatItem = await Chat.findOne({ _id: chatId, userId }, { content: { $slice: -1 } }); + + jsonRes(res, { + data: { + quote: chatItem?.content[0]?.quote || [] + } + }); + } catch (err) { + jsonRes(res, { + code: 500, + error: err + }); + } +} diff --git a/src/pages/api/openapi/kb/appKbSearch.ts b/src/pages/api/openapi/kb/appKbSearch.ts new file mode 100644 index 000000000..6b831d259 --- /dev/null +++ b/src/pages/api/openapi/kb/appKbSearch.ts @@ -0,0 +1,224 @@ +import type { NextApiRequest, NextApiResponse } from 'next'; +import { jsonRes } from '@/service/response'; +import { authUser } from '@/service/utils/auth'; +import { PgClient } from '@/service/pg'; +import { withNextCors } from '@/service/utils/tools'; +import type { ChatItemSimpleType } from '@/types/chat'; +import type { ModelSchema } from '@/types/mongoSchema'; +import { ModelVectorSearchModeEnum } from '@/constants/model'; +import { authModel } from '@/service/utils/auth'; +import { ChatModelMap } from '@/constants/model'; +import { ChatRoleEnum } from '@/constants/chat'; +import { openaiEmbedding } from '../plugin/openaiEmbedding'; +import { ModelDataStatusEnum } from '@/constants/model'; +import { modelToolMap } from '@/utils/plugin'; + +export type QuoteItemType = { id: string; q: string; a: string }; +type Props = { + prompts: ChatItemSimpleType[]; + similarity: number; + appId: string; +}; +type Response = { + code: 200 | 201; + rawSearch: QuoteItemType[]; + searchPrompts: { + obj: ChatRoleEnum; + value: string; + }[]; +}; + +export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { + try { + const { userId } = await authUser({ req }); + + if (!userId) { + throw new Error('userId is empty'); + } + + const { prompts, similarity, appId } = req.body as Props; + + if (!similarity || !Array.isArray(prompts) || !appId) { + throw new Error('params is error'); + } + + // auth model + const { model } = await authModel({ + modelId: appId, + userId + }); + + const result = await appKbSearch({ + userId, + prompts, + similarity, + model + }); + + jsonRes(res, { + data: result + }); + } catch (err) { + console.log(err); + jsonRes(res, { + code: 500, + error: err + }); + } +}); + +export async function appKbSearch({ + model, + userId, + prompts, + similarity +}: { + userId: string; + prompts: ChatItemSimpleType[]; + similarity: number; + model: ModelSchema; +}): Promise { + const modelConstantsData = ChatModelMap[model.chat.chatModel]; + + // search two times. + const userPrompts = prompts.filter((item) => item.obj === 'Human'); + + const input: string[] = [ + userPrompts[userPrompts.length - 1].value, + userPrompts[userPrompts.length - 2]?.value + ].filter((item) => item); + + // get vector + const promptVectors = await openaiEmbedding({ + userId, + input + }); + + // search kb + const searchRes = await Promise.all( + promptVectors.map((promptVector) => + PgClient.select<{ id: string; q: string; a: string }>('modelData', { + fields: ['id', 'q', 'a'], + where: [ + ['status', ModelDataStatusEnum.ready], + 'AND', + `kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`, + 'AND', + `vector <=> '[${promptVector}]' < ${similarity}` + ], + order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }], + limit: promptVectors.length === 1 ? 15 : 10 + }).then((res) => res.rows) + ) + ); + + // filter same search result + const idSet = new Set(); + const filterSearch = searchRes.map((search) => + search.filter((item) => { + if (idSet.has(item.id)) { + return false; + } + idSet.add(item.id); + return true; + }) + ); + + // slice search result by rate. + const sliceRateMap: Record = { + 1: [1], + 2: [0.7, 0.3] + }; + const sliceRate = sliceRateMap[searchRes.length] || sliceRateMap[0]; + // 计算固定提示词的 token 数量 + const fixedPrompts = [ + // user system prompt + ...(model.chat.systemPrompt + ? [ + { + obj: ChatRoleEnum.System, + value: model.chat.systemPrompt + } + ] + : model.chat.searchMode === ModelVectorSearchModeEnum.noContext + ? [ + { + obj: ChatRoleEnum.System, + value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.` + } + ] + : [ + { + obj: ChatRoleEnum.System, + value: `玩一个问答游戏,规则为: +1.你完全忘记你已有的知识 +2.你只回答关于"${model.name}"的问题 +3.你只从知识库中选择内容进行回答 +4.如果问题不在知识库中,你会回答:"我不知道。" +请务必遵守规则` + } + ]) + ]; + const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({ + messages: fixedPrompts + }); + const maxTokens = modelConstantsData.systemMaxToken - fixedSystemTokens; + const sliceResult = sliceRate.map((rate, i) => + modelToolMap[model.chat.chatModel] + .tokenSlice({ + maxToken: Math.round(maxTokens * rate), + messages: filterSearch[i].map((item) => ({ + obj: ChatRoleEnum.System, + value: `${item.q}\n${item.a}` + })) + }) + .map((item) => item.value) + ); + + // slice filterSearch + const sliceSearch = filterSearch.map((item, i) => item.slice(0, sliceResult[i].length)).flat(); + + // system prompt + const systemPrompt = sliceResult.flat().join('\n').trim(); + + /* 高相似度+不回复 */ + if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity) { + return { + code: 201, + rawSearch: [], + searchPrompts: [ + { + obj: ChatRoleEnum.System, + value: '对不起,你的问题不在知识库中。' + } + ] + }; + } + /* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */ + if (!systemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) { + return { + code: 200, + rawSearch: [], + searchPrompts: model.chat.systemPrompt + ? [ + { + obj: ChatRoleEnum.System, + value: model.chat.systemPrompt + } + ] + : [] + }; + } + + return { + code: 200, + rawSearch: sliceSearch, + searchPrompts: [ + { + obj: ChatRoleEnum.System, + value: `知识库:${systemPrompt}` + }, + ...fixedPrompts + ] + }; +} diff --git a/src/pages/api/openapi/plugin/openaiEmbedding.ts b/src/pages/api/openapi/plugin/openaiEmbedding.ts new file mode 100644 index 000000000..d1bb776ba --- /dev/null +++ b/src/pages/api/openapi/plugin/openaiEmbedding.ts @@ -0,0 +1,77 @@ +import type { NextApiRequest, NextApiResponse } from 'next'; +import { jsonRes } from '@/service/response'; +import { authUser } from '@/service/utils/auth'; +import { PgClient } from '@/service/pg'; +import { withNextCors } from '@/service/utils/tools'; +import { getApiKey } from '@/service/utils/auth'; +import { getOpenAIApi } from '@/service/utils/chat/openai'; +import { embeddingModel } from '@/constants/model'; +import { axiosConfig } from '@/service/utils/tools'; +import { pushGenerateVectorBill } from '@/service/events/pushBill'; + +type Props = { + input: string[]; +}; +type Response = number[][]; + +export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { + try { + const { userId } = await authUser({ req }); + let { input } = req.query as Props; + + if (!Array.isArray(input)) { + throw new Error('缺少参数'); + } + + jsonRes(res, { + data: await openaiEmbedding({ userId, input, mustPay: true }) + }); + } catch (err) { + console.log(err); + jsonRes(res, { + code: 500, + error: err + }); + } +}); + +export async function openaiEmbedding({ + userId, + input, + mustPay = false +}: { userId: string; mustPay?: boolean } & Props) { + const { userOpenAiKey, systemAuthKey } = await getApiKey({ + model: 'gpt-3.5-turbo', + userId, + mustPay + }); + + // 获取 chatAPI + const chatAPI = getOpenAIApi(); + + // 把输入的内容转成向量 + const result = await chatAPI + .createEmbedding( + { + model: embeddingModel, + input + }, + { + timeout: 60000, + ...axiosConfig(userOpenAiKey || systemAuthKey) + } + ) + .then((res) => ({ + tokenLen: res.data.usage.total_tokens || 0, + vectors: res.data.data.map((item) => item.embedding) + })); + + pushGenerateVectorBill({ + isPay: !userOpenAiKey, + userId, + text: input.join(''), + tokenLen: result.tokenLen + }); + + return result.vectors; +} diff --git a/src/pages/api/openapi/text/gptMessagesSlice.ts b/src/pages/api/openapi/text/gptMessagesSlice.ts new file mode 100644 index 000000000..f376135d0 --- /dev/null +++ b/src/pages/api/openapi/text/gptMessagesSlice.ts @@ -0,0 +1,119 @@ +// Next.js API route support: https://nextjs.org/docs/api-routes/introduction +import type { NextApiRequest, NextApiResponse } from 'next'; +import { type Tiktoken } from '@dqbd/tiktoken'; +import { jsonRes } from '@/service/response'; +import { authUser } from '@/service/utils/auth'; +import Graphemer from 'graphemer'; +import type { ChatItemSimpleType } from '@/types/chat'; +import { ChatCompletionRequestMessage } from 'openai'; +import { getOpenAiEncMap } from '@/utils/plugin/openai'; +import { adaptChatItem_openAI } from '@/utils/plugin/openai'; + +type ModelType = 'gpt-3.5-turbo' | 'gpt-4' | 'gpt-4-32k'; + +type Props = { + messages: ChatItemSimpleType[]; + model: ModelType; + maxLen: number; +}; +type Response = ChatItemSimpleType[]; + +export default async function handler(req: NextApiRequest, res: NextApiResponse) { + try { + await authUser({ req }); + + const { messages, model, maxLen } = req.body as Props; + + if (!Array.isArray(messages) || !model || !maxLen) { + throw new Error('params is error'); + } + + return jsonRes(res, { + data: gpt_chatItemTokenSlice({ + messages, + model, + maxToken: maxLen + }) + }); + } catch (err) { + jsonRes(res, { + code: 500, + error: err + }); + } +} + +export function gpt_chatItemTokenSlice({ + messages, + model, + maxToken +}: { + messages: ChatItemSimpleType[]; + model: ModelType; + maxToken: number; +}) { + const textDecoder = new TextDecoder(); + const graphemer = new Graphemer(); + + function getChatGPTEncodingText(messages: ChatCompletionRequestMessage[], model: ModelType) { + const isGpt3 = model === 'gpt-3.5-turbo'; + + const msgSep = isGpt3 ? '\n' : ''; + const roleSep = isGpt3 ? '\n' : '<|im_sep|>'; + + return [ + messages + .map(({ name = '', role, content }) => { + return `<|im_start|>${name || role}${roleSep}${content}<|im_end|>`; + }) + .join(msgSep), + `<|im_start|>assistant${roleSep}` + ].join(msgSep); + } + function text2TokensLen(encoder: Tiktoken, inputText: string) { + const encoding = encoder.encode(inputText, 'all'); + const segments: { text: string; tokens: { id: number; idx: number }[] }[] = []; + + let byteAcc: number[] = []; + let tokenAcc: { id: number; idx: number }[] = []; + let inputGraphemes = graphemer.splitGraphemes(inputText); + + for (let idx = 0; idx < encoding.length; idx++) { + const token = encoding[idx]!; + byteAcc.push(...encoder.decode_single_token_bytes(token)); + tokenAcc.push({ id: token, idx }); + + const segmentText = textDecoder.decode(new Uint8Array(byteAcc)); + const graphemes = graphemer.splitGraphemes(segmentText); + + if (graphemes.every((item, idx) => inputGraphemes[idx] === item)) { + segments.push({ text: segmentText, tokens: tokenAcc }); + + byteAcc = []; + tokenAcc = []; + inputGraphemes = inputGraphemes.slice(graphemes.length); + } + } + + return segments.reduce((memo, i) => memo + i.tokens.length, 0) ?? 0; + } + const OpenAiEncMap = getOpenAiEncMap(); + const enc = OpenAiEncMap[model]; + + let result: ChatItemSimpleType[] = []; + + for (let i = 0; i < messages.length; i++) { + const msgs = [...result, messages[i]]; + const tokens = text2TokensLen( + enc, + getChatGPTEncodingText(adaptChatItem_openAI({ messages }), model) + ); + if (tokens < maxToken) { + result = msgs; + } else { + break; + } + } + + return result; +} diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index 65affb71e..811d71f53 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -3,10 +3,10 @@ import { useRouter } from 'next/router'; import { getInitChatSiteInfo, delChatRecordByIndex, - postSaveChat, + getChatResult, delChatHistoryById } from '@/api/chat'; -import type { ChatSiteItemType, ExportChatType } from '@/types/chat'; +import type { ChatItemType, ChatSiteItemType, ExportChatType } from '@/types/chat'; import { Textarea, Box, @@ -29,13 +29,14 @@ import { Card, Tooltip, useOutsideClick, - useTheme + useTheme, + ModalHeader } from '@chakra-ui/react'; import { useToast } from '@/hooks/useToast'; import { useGlobalStore } from '@/store/global'; import { useQuery } from '@tanstack/react-query'; import dynamic from 'next/dynamic'; -import { useCopyData, voiceBroadcast, hasVoiceApi } from '@/utils/tools'; +import { useCopyData, voiceBroadcast, hasVoiceApi, delay } from '@/utils/tools'; import { streamFetch } from '@/api/fetch'; import MyIcon from '@/components/Icon'; import { throttle } from 'lodash'; @@ -47,6 +48,7 @@ import { useLoading } from '@/hooks/useLoading'; import { fileDownload } from '@/utils/file'; import { htmlTemplate } from '@/constants/common'; import { useUserStore } from '@/store/user'; +import type { QuoteItemType } from '@/pages/api/openapi/kb/appKbSearch'; import Loading from '@/components/Loading'; import Markdown from '@/components/Markdown'; import SideBar from '@/components/SideBar'; @@ -78,7 +80,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { const controller = useRef(new AbortController()); const isLeavePage = useRef(false); - const [showSystemPrompt, setShowSystemPrompt] = useState(''); + const [showQuote, setShowQuote] = useState([]); const [messageContextMenuData, setMessageContextMenuData] = useState<{ // message messageContextMenuData left: number; @@ -173,13 +175,14 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { controller.current = abortSignal; isLeavePage.current = false; - const prompt = { - obj: prompts[0].obj, - value: prompts[0].value - }; + const prompt: ChatItemType[] = prompts.map((item) => ({ + _id: item._id, + obj: item.obj, + value: item.value + })); // 流请求,获取数据 - let { responseText, systemPrompt, newChatId } = await streamFetch({ + const { newChatId } = await streamFetch({ url: '/api/chat/chat', data: { prompt, @@ -207,39 +210,16 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { return; } - // save chat record - try { - newChatId = await postSaveChat({ - newChatId, // 如果有newChatId,会自动以这个Id创建对话框 - modelId, - chatId, - prompts: [ - { - _id: prompts[0]._id, - obj: 'Human', - value: prompt.value - }, - { - _id: prompts[1]._id, - obj: 'AI', - value: responseText, - systemPrompt - } - ] - }); - if (newChatId) { - setForbidLoadChatData(true); - router.replace(`/chat?modelId=${modelId}&chatId=${newChatId}`); - } - } catch (err) { - toast({ - title: '对话出现异常, 继续对话会导致上下文丢失,请刷新页面', - status: 'warning', - duration: 3000, - isClosable: true - }); + if (newChatId) { + setForbidLoadChatData(true); + router.replace(`/chat?modelId=${modelId}&chatId=${newChatId}`); } + abortSignal.signal.aborted && (await delay(600)); + + // get chat result + const { quote } = await getChatResult(chatId || newChatId); + // 设置聊天内容为完成状态 setChatData((state) => ({ ...state, @@ -249,7 +229,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { return { ...item, status: 'finish', - systemPrompt + quote }; }) })); @@ -260,16 +240,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { generatingMessage(); }, 100); }, - [ - chatId, - setForbidLoadChatData, - generatingMessage, - loadHistory, - modelId, - router, - setChatData, - toast - ] + [chatId, setForbidLoadChatData, generatingMessage, loadHistory, modelId, router, setChatData] ); /** @@ -717,24 +688,24 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { {item.obj === 'Human' && } {/* avatar */} - - isPc && - chatData.model.canUse && - router.push(`/model?modelId=${chatData.modelId}`) - } - : { - order: 3, - ml: ['6px', 2] - })} - > - + + + isPc && + chatData.model.canUse && + router.push(`/model?modelId=${chatData.modelId}`) + } + : { + order: 3, + ml: ['6px', 2] + })} + > { w={['20px', '34px']} h={['20px', '34px']} /> - - + + {!isPc && } {/* message */} @@ -764,7 +735,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { isChatting={isChatting && index === chatData.history.length - 1} formatLink /> - {item.systemPrompt && ( + {item.quote && item.quote.length > 0 && ( )} @@ -907,12 +878,24 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { )} {/* system prompt show modal */} { - setShowSystemPrompt('')}> + 0} onClose={() => setShowQuote([])}> - + + 知识库引用({showQuote.length}条) - - {showSystemPrompt} + + {showQuote.map((item) => ( + + {item.q} + {item.a} + + ))} diff --git a/src/pages/chat/share.tsx b/src/pages/chat/share.tsx index 8b1f72256..546c37105 100644 --- a/src/pages/chat/share.tsx +++ b/src/pages/chat/share.tsx @@ -73,7 +73,6 @@ const Chat = ({ shareId, historyId }: { shareId: string; historyId: string }) => const isLeavePage = useRef(false); const [inputVal, setInputVal] = useState(''); // user input prompt - const [showSystemPrompt, setShowSystemPrompt] = useState(''); const [messageContextMenuData, setMessageContextMenuData] = useState<{ // message messageContextMenuData left: number; @@ -178,7 +177,7 @@ const Chat = ({ shareId, historyId }: { shareId: string; historyId: string }) => })); // 流请求,获取数据 - const { responseText, systemPrompt } = await streamFetch({ + const { responseText } = await streamFetch({ url: '/api/chat/shareChat/chat', data: { prompts: formatPrompts.slice(-shareChatData.maxContext - 1, -1), @@ -215,8 +214,7 @@ const Chat = ({ shareId, historyId }: { shareId: string; historyId: string }) => if (index !== state.history.length - 1) return item; return { ...item, - status: 'finish', - systemPrompt + status: 'finish' }; }); @@ -614,19 +612,19 @@ const Chat = ({ shareId, historyId }: { shareId: string; historyId: string }) => {item.obj === 'Human' && } {/* avatar */} - - + + w={['20px', '34px']} h={['20px', '34px']} /> - - + + {!isPc && } {/* message */} @@ -656,19 +654,6 @@ const Chat = ({ shareId, historyId }: { shareId: string; historyId: string }) => isChatting={isChatting && index === shareChatData.history.length - 1} formatLink /> - {item.systemPrompt && ( - - )} ) : ( @@ -796,18 +781,6 @@ const Chat = ({ shareId, historyId }: { shareId: string; historyId: string }) => )} - {/* system prompt show modal */} - { - setShowSystemPrompt('')}> - - - - - {showSystemPrompt} - - - - } {/* context menu */} {messageContextMenuData && ( diff --git a/src/service/events/generateVector.ts b/src/service/events/generateVector.ts index fef40efab..e34cdaa3e 100644 --- a/src/service/events/generateVector.ts +++ b/src/service/events/generateVector.ts @@ -1,8 +1,8 @@ -import { openaiCreateEmbedding } from '../utils/chat/openai'; import { getApiKey } from '../utils/auth'; import { openaiError2 } from '../errorCode'; import { PgClient } from '@/service/pg'; import { getErrText } from '@/utils/tools'; +import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding'; export async function generateVector(next = false): Promise { if (process.env.queueTask !== '1') { @@ -42,24 +42,20 @@ export async function generateVector(next = false): Promise { dataId = dataItem.id; // 获取 openapi Key - let userOpenAiKey; try { - const res = await getApiKey({ model: 'gpt-3.5-turbo', userId: dataItem.userId }); - userOpenAiKey = res.userOpenAiKey; + await getApiKey({ model: 'gpt-3.5-turbo', userId: dataItem.userId }); } catch (err: any) { await PgClient.delete('modelData', { where: [['id', dataId]] }); - generateVector(true); getErrText(err, '获取 OpenAi Key 失败'); - return; + return generateVector(true); } // 生成词向量 - const { vectors } = await openaiCreateEmbedding({ - textArr: [dataItem.q], - userId: dataItem.userId, - userOpenAiKey + const vectors = await openaiEmbedding({ + input: [dataItem.q], + userId: dataItem.userId }); // 更新 pg 向量和状态数据 diff --git a/src/service/models/chat.ts b/src/service/models/chat.ts index 0c0b58ce6..2c70af9b0 100644 --- a/src/service/models/chat.ts +++ b/src/service/models/chat.ts @@ -47,10 +47,14 @@ const ChatSchema = new Schema({ type: String, required: true }, - systemPrompt: { - type: String, - default: '' + quote: { + type: [{ id: String, q: String, a: String }], + default: [] } + // systemPrompt: { + // type: String, + // default: '' + // } } ], default: [] diff --git a/src/service/plugins/searchKb.ts b/src/service/plugins/searchKb.ts deleted file mode 100644 index 0c5c7442f..000000000 --- a/src/service/plugins/searchKb.ts +++ /dev/null @@ -1,175 +0,0 @@ -import { PgClient } from '@/service/pg'; -import { ModelDataStatusEnum, ModelVectorSearchModeEnum, ChatModelMap } from '@/constants/model'; -import { ModelSchema } from '@/types/mongoSchema'; -import { openaiCreateEmbedding } from '../utils/chat/openai'; -import { ChatRoleEnum } from '@/constants/chat'; -import { modelToolMap } from '@/utils/chat'; -import { ChatItemSimpleType } from '@/types/chat'; - -/** - * use openai embedding search kb - */ -export const searchKb = async ({ - userOpenAiKey, - prompts, - similarity = 0.2, - model, - userId -}: { - userOpenAiKey?: string; - prompts: ChatItemSimpleType[]; - model: ModelSchema; - userId: string; - similarity?: number; -}): Promise<{ - code: 200 | 201; - searchPrompts: { - obj: ChatRoleEnum; - value: string; - }[]; -}> => { - async function search(textArr: string[] = []) { - const limitMap: Record = { - [ModelVectorSearchModeEnum.hightSimilarity]: 15, - [ModelVectorSearchModeEnum.noContext]: 15, - [ModelVectorSearchModeEnum.lowSimilarity]: 20 - }; - // 获取提示词的向量 - const { vectors: promptVectors } = await openaiCreateEmbedding({ - userOpenAiKey, - userId, - textArr - }); - - const searchRes = await Promise.all( - promptVectors.map((promptVector) => - PgClient.select<{ id: string; q: string; a: string }>('modelData', { - fields: ['id', 'q', 'a'], - where: [ - ['status', ModelDataStatusEnum.ready], - 'AND', - `kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`, - 'AND', - `vector <=> '[${promptVector}]' < ${similarity}` - ], - order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }], - limit: limitMap[model.chat.searchMode] - }).then((res) => res.rows) - ) - ); - - // Remove repeat record - const idSet = new Set(); - const filterSearch = searchRes.map((search) => - search.filter((item) => { - if (idSet.has(item.id)) { - return false; - } - idSet.add(item.id); - return true; - }) - ); - - return filterSearch.map((item) => item.map((item) => `${item.q}\n${item.a}`).join('\n')); - } - const modelConstantsData = ChatModelMap[model.chat.chatModel]; - - // search three times - const userPrompts = prompts.filter((item) => item.obj === 'Human'); - - const searchArr: string[] = [ - userPrompts[userPrompts.length - 1].value, - userPrompts[userPrompts.length - 2]?.value - ].filter((item) => item); - const systemPrompts = await search(searchArr); - - // filter system prompts. - const filterRateMap: Record = { - 1: [1], - 2: [0.7, 0.3] - }; - const filterRate = filterRateMap[systemPrompts.length] || filterRateMap[0]; - - // 计算固定提示词的 token 数量 - const fixedPrompts = [ - ...(model.chat.systemPrompt - ? [ - { - obj: ChatRoleEnum.System, - value: model.chat.systemPrompt - } - ] - : []), - ...(model.chat.searchMode === ModelVectorSearchModeEnum.noContext - ? [ - { - obj: ChatRoleEnum.System, - value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.` - } - ] - : [ - { - obj: ChatRoleEnum.System, - value: `玩一个问答游戏,规则为: -1.你完全忘记你已有的知识 -2.你只回答关于"${model.name}"的问题 -3.你只从知识库中选择内容进行回答 -4.如果问题不在知识库中,你会回答:"我不知道。" -请务必遵守规则` - } - ]) - ]; - const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({ - messages: fixedPrompts - }); - const maxTokens = modelConstantsData.systemMaxToken - fixedSystemTokens; - - const filterSystemPrompt = filterRate - .map((rate, i) => - modelToolMap[model.chat.chatModel].sliceText({ - text: systemPrompts[i], - length: Math.floor(maxTokens * rate) - }) - ) - .join('\n') - .trim(); - - /* 高相似度+不回复 */ - if (!filterSystemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity) { - return { - code: 201, - searchPrompts: [ - { - obj: ChatRoleEnum.System, - value: '对不起,你的问题不在知识库中。' - } - ] - }; - } - /* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */ - if (!filterSystemPrompt && model.chat.searchMode === ModelVectorSearchModeEnum.noContext) { - return { - code: 200, - searchPrompts: model.chat.systemPrompt - ? [ - { - obj: ChatRoleEnum.System, - value: model.chat.systemPrompt - } - ] - : [] - }; - } - - /* 有匹配 */ - return { - code: 200, - searchPrompts: [ - { - obj: ChatRoleEnum.System, - value: `知识库:${filterSystemPrompt}` - }, - ...fixedPrompts - ] - }; -}; diff --git a/src/service/utils/auth.ts b/src/service/utils/auth.ts index c22505cea..47861721c 100644 --- a/src/service/utils/auth.ts +++ b/src/service/utils/auth.ts @@ -38,12 +38,14 @@ export const authUser = async ({ req, authToken = false, authOpenApi = false, - authRoot = false + authRoot = false, + authBalance = false }: { req: NextApiRequest; authToken?: boolean; authOpenApi?: boolean; authRoot?: boolean; + authBalance?: boolean; }) => { const parseOpenApiKey = async (apiKey?: string) => { if (!apiKey) { @@ -99,6 +101,17 @@ export const authUser = async ({ return Promise.reject(ERROR_ENUM.unAuthorization); } + if (authBalance) { + const user = await User.findById(uid); + if (!user) { + return Promise.reject(ERROR_ENUM.unAuthorization); + } + + if (!user.openaiKey && formatPrice(user.balance) <= 0) { + return Promise.reject(ERROR_ENUM.insufficientQuota); + } + } + return { userId: uid }; @@ -226,7 +239,7 @@ export const authChat = async ({ req }: { modelId: string; - chatId: '' | string; + chatId?: string; req: NextApiRequest; }) => { const { userId } = await authUser({ req, authToken: true }); diff --git a/src/service/utils/chat/claude.ts b/src/service/utils/chat/claude.ts index 522f7f5b2..1a2c73341 100644 --- a/src/service/utils/chat/claude.ts +++ b/src/service/utils/chat/claude.ts @@ -1,17 +1,9 @@ import { ChatCompletionType, StreamResponseType } from './index'; import { ChatRoleEnum } from '@/constants/chat'; import axios from 'axios'; -import mongoose from 'mongoose'; -import { NEW_CHATID_HEADER } from '@/constants/chat'; /* 模型对话 */ -export const claudChat = async ({ apiKey, messages, stream, chatId, res }: ChatCompletionType) => { - const conversationId = chatId || String(new mongoose.Types.ObjectId()); - // create a new chat - !chatId && - messages.filter((item) => item.obj === 'Human').length === 1 && - res?.setHeader(NEW_CHATID_HEADER, conversationId); - +export const claudChat = async ({ apiKey, messages, stream, chatId }: ChatCompletionType) => { // get system prompt const systemPrompt = messages .filter((item) => item.obj === 'System') @@ -26,7 +18,7 @@ export const claudChat = async ({ apiKey, messages, stream, chatId, res }: ChatC { prompt, stream, - conversationId + conversationId: chatId }, { headers: { @@ -55,8 +47,7 @@ export const claudStreamResponse = async ({ res, chatResponse, prompts }: Stream try { const decoder = new TextDecoder(); for await (const chunk of chatResponse.data as any) { - if (!res.writable) { - // 流被中断了,直接忽略后面的内容 + if (res.closed) { break; } const content = decoder.decode(chunk); diff --git a/src/service/utils/chat/index.ts b/src/service/utils/chat/index.ts index 895956e46..7b6073798 100644 --- a/src/service/utils/chat/index.ts +++ b/src/service/utils/chat/index.ts @@ -1,7 +1,7 @@ import { ChatItemSimpleType } from '@/types/chat'; -import { modelToolMap } from '@/utils/chat'; +import { modelToolMap } from '@/utils/plugin'; import type { ChatModelType } from '@/constants/model'; -import { ChatRoleEnum, SYSTEM_PROMPT_HEADER } from '@/constants/chat'; +import { ChatRoleEnum } from '@/constants/chat'; import { OpenAiChatEnum, ClaudeEnum } from '@/constants/model'; import { chatResponse, openAiStreamResponse } from './openai'; import { claudChat, claudStreamResponse } from './claude'; @@ -11,6 +11,7 @@ export type ChatCompletionType = { apiKey: string; temperature: number; messages: ChatItemSimpleType[]; + chatId?: string; [key: string]: any; }; export type ChatCompletionResponseType = { @@ -23,7 +24,6 @@ export type StreamResponseType = { chatResponse: any; prompts: ChatItemSimpleType[]; res: NextApiResponse; - systemPrompt?: string; [key: string]: any; }; export type StreamResponseReturnType = { @@ -129,7 +129,6 @@ export const resStreamResponse = async ({ model, res, chatResponse, - systemPrompt, prompts }: StreamResponseType & { model: ChatModelType; @@ -139,18 +138,14 @@ export const resStreamResponse = async ({ res.setHeader('Access-Control-Allow-Origin', '*'); res.setHeader('X-Accel-Buffering', 'no'); res.setHeader('Cache-Control', 'no-cache, no-transform'); - systemPrompt && res.setHeader(SYSTEM_PROMPT_HEADER, encodeURIComponent(systemPrompt)); const { responseContent, totalTokens, finishMessages } = await modelServiceToolMap[ model ].streamResponse({ chatResponse, prompts, - res, - systemPrompt + res }); - res.end(); - return { responseContent, totalTokens, finishMessages }; }; diff --git a/src/service/utils/chat/openai.ts b/src/service/utils/chat/openai.ts index 7e4559e46..918387ec9 100644 --- a/src/service/utils/chat/openai.ts +++ b/src/service/utils/chat/openai.ts @@ -1,13 +1,11 @@ import { Configuration, OpenAIApi } from 'openai'; import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser'; import { axiosConfig } from '../tools'; -import { ChatModelMap, embeddingModel, OpenAiChatEnum } from '@/constants/model'; -import { pushGenerateVectorBill } from '../../events/pushBill'; -import { adaptChatItem_openAI } from '@/utils/chat/openai'; -import { modelToolMap } from '@/utils/chat'; +import { ChatModelMap, OpenAiChatEnum } from '@/constants/model'; +import { adaptChatItem_openAI } from '@/utils/plugin/openai'; +import { modelToolMap } from '@/utils/plugin'; import { ChatCompletionType, ChatContextFilter, StreamResponseType } from './index'; import { ChatRoleEnum } from '@/constants/chat'; -import { getSystemOpenAiKey } from '../auth'; export const getOpenAIApi = () => new OpenAIApi( @@ -16,51 +14,6 @@ export const getOpenAIApi = () => }) ); -/* 获取向量 */ -export const openaiCreateEmbedding = async ({ - userOpenAiKey, - userId, - textArr -}: { - userOpenAiKey?: string; - userId: string; - textArr: string[]; -}) => { - const systemAuthKey = getSystemOpenAiKey(); - - // 获取 chatAPI - const chatAPI = getOpenAIApi(); - - // 把输入的内容转成向量 - const res = await chatAPI - .createEmbedding( - { - model: embeddingModel, - input: textArr - }, - { - timeout: 60000, - ...axiosConfig(userOpenAiKey || systemAuthKey) - } - ) - .then((res) => ({ - tokenLen: res.data.usage.total_tokens || 0, - vectors: res.data.data.map((item) => item.embedding) - })); - - pushGenerateVectorBill({ - isPay: !userOpenAiKey, - userId, - text: textArr.join(''), - tokenLen: res.tokenLen - }); - - return { - vectors: res.vectors, - chatAPI - }; -}; - /* 模型对话 */ export const chatResponse = async ({ model, @@ -127,7 +80,7 @@ export const openAiStreamResponse = async ({ const content: string = json?.choices?.[0].delta.content || ''; responseContent += content; - res.writable && content && res.write(content); + !res.closed && content && res.write(content); } catch (error) { error; } @@ -137,8 +90,7 @@ export const openAiStreamResponse = async ({ const decoder = new TextDecoder(); const parser = createParser(onParse); for await (const chunk of chatResponse.data as any) { - if (!res.writable) { - // 流被中断了,直接忽略后面的内容 + if (res.closed) { break; } parser.feed(decoder.decode(chunk, { stream: true })); diff --git a/src/types/chat.d.ts b/src/types/chat.d.ts index 32053123a..988ce858b 100644 --- a/src/types/chat.d.ts +++ b/src/types/chat.d.ts @@ -1,12 +1,13 @@ import { ChatRoleEnum } from '@/constants/chat'; import type { InitChatResponse, InitShareChatResponse } from '@/api/response/chat'; +import { QuoteItemType } from '@/pages/api/openapi/kb/appKbSearch'; export type ExportChatType = 'md' | 'pdf' | 'html'; export type ChatItemSimpleType = { obj: `${ChatRoleEnum}`; value: string; - systemPrompt?: string; + quote?: QuoteItemType[]; }; export type ChatItemType = { _id: string; diff --git a/src/utils/chat/claude.ts b/src/utils/chat/claude.ts deleted file mode 100644 index 75058d93a..000000000 --- a/src/utils/chat/claude.ts +++ /dev/null @@ -1,3 +0,0 @@ -export const ClaudeSliceTextByToken = ({ text, length }: { text: string; length: number }) => { - return text.slice(0, length); -}; diff --git a/src/utils/file.ts b/src/utils/file.ts index e18e7b887..3f9627642 100644 --- a/src/utils/file.ts +++ b/src/utils/file.ts @@ -1,6 +1,6 @@ import mammoth from 'mammoth'; import Papa from 'papaparse'; -import { getOpenAiEncMap } from './chat/openai'; +import { getOpenAiEncMap } from './plugin/openai'; /** * 读取 txt 文件内容 diff --git a/src/utils/chat/index.ts b/src/utils/plugin/index.ts similarity index 64% rename from src/utils/chat/index.ts rename to src/utils/plugin/index.ts index 5a9df8ed2..a714e2005 100644 --- a/src/utils/chat/index.ts +++ b/src/utils/plugin/index.ts @@ -2,29 +2,37 @@ import { ClaudeEnum, OpenAiChatEnum } from '@/constants/model'; import type { ChatModelType } from '@/constants/model'; import type { ChatItemSimpleType } from '@/types/chat'; import { countOpenAIToken, openAiSliceTextByToken } from './openai'; -import { ClaudeSliceTextByToken } from './claude'; +import { gpt_chatItemTokenSlice } from '@/pages/api/openapi/text/gptMessagesSlice'; export const modelToolMap: Record< ChatModelType, { countTokens: (data: { messages: ChatItemSimpleType[] }) => number; sliceText: (data: { text: string; length: number }) => string; + tokenSlice: (data: { + messages: ChatItemSimpleType[]; + maxToken: number; + }) => ChatItemSimpleType[]; } > = { [OpenAiChatEnum.GPT35]: { countTokens: ({ messages }) => countOpenAIToken({ model: OpenAiChatEnum.GPT35, messages }), - sliceText: (data) => openAiSliceTextByToken({ model: OpenAiChatEnum.GPT35, ...data }) + sliceText: (data) => openAiSliceTextByToken({ model: OpenAiChatEnum.GPT35, ...data }), + tokenSlice: (data) => gpt_chatItemTokenSlice({ model: OpenAiChatEnum.GPT35, ...data }) }, [OpenAiChatEnum.GPT4]: { countTokens: ({ messages }) => countOpenAIToken({ model: OpenAiChatEnum.GPT4, messages }), - sliceText: (data) => openAiSliceTextByToken({ model: OpenAiChatEnum.GPT4, ...data }) + sliceText: (data) => openAiSliceTextByToken({ model: OpenAiChatEnum.GPT4, ...data }), + tokenSlice: (data) => gpt_chatItemTokenSlice({ model: OpenAiChatEnum.GPT4, ...data }) }, [OpenAiChatEnum.GPT432k]: { countTokens: ({ messages }) => countOpenAIToken({ model: OpenAiChatEnum.GPT432k, messages }), - sliceText: (data) => openAiSliceTextByToken({ model: OpenAiChatEnum.GPT432k, ...data }) + sliceText: (data) => openAiSliceTextByToken({ model: OpenAiChatEnum.GPT432k, ...data }), + tokenSlice: (data) => gpt_chatItemTokenSlice({ model: OpenAiChatEnum.GPT432k, ...data }) }, [ClaudeEnum.Claude]: { countTokens: ({ messages }) => countOpenAIToken({ model: OpenAiChatEnum.GPT35, messages }), - sliceText: (data) => openAiSliceTextByToken({ model: OpenAiChatEnum.GPT35, ...data }) + sliceText: (data) => openAiSliceTextByToken({ model: OpenAiChatEnum.GPT35, ...data }), + tokenSlice: (data) => gpt_chatItemTokenSlice({ model: OpenAiChatEnum.GPT35, ...data }) } }; diff --git a/src/utils/chat/openai.ts b/src/utils/plugin/openai.ts similarity index 100% rename from src/utils/chat/openai.ts rename to src/utils/plugin/openai.ts diff --git a/src/utils/tools.ts b/src/utils/tools.ts index 94468353d..2baff1110 100644 --- a/src/utils/tools.ts +++ b/src/utils/tools.ts @@ -126,3 +126,10 @@ export const getErrText = (err: any, def = '') => { msg && console.log('error =>', msg); return msg; }; + +export const delay = (ms: number) => + new Promise((resolve) => { + setTimeout(() => { + resolve(''); + }, ms); + });