From 78762498eb4cfbb46cfe722fdc8d3b0d95e7ef6b Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Sat, 29 Apr 2023 15:55:47 +0800 Subject: [PATCH] perf: model framwork --- src/api/model.ts | 3 +- src/api/response/chat.d.ts | 3 +- src/constants/model.ts | 54 ++--- src/pages/api/chat/{chatGpt.ts => chat.ts} | 82 +++++-- src/pages/api/chat/init.ts | 3 +- src/pages/api/chat/vectorGpt.ts | 189 ---------------- src/pages/api/model/create.ts | 28 +-- src/pages/api/model/update.ts | 9 +- src/pages/api/openapi/chat/chat.ts | 202 ++++++++++++++++++ src/pages/api/openapi/chat/chatGpt.ts | 23 +- src/pages/api/openapi/chat/lafGpt.ts | 71 +++--- src/pages/api/openapi/chat/vectorGpt.ts | 92 ++++---- src/pages/chat/components/SlideBar.tsx | 2 +- src/pages/chat/index.tsx | 27 +-- .../model/detail/components/ModelDataCard.tsx | 5 +- .../model/detail/components/ModelEditForm.tsx | 82 +++---- src/pages/model/detail/index.tsx | 29 +-- .../model/list/components/CreateModel.tsx | 138 ------------ .../model/list/components/ModelPhoneList.tsx | 10 +- .../model/list/components/ModelTable.tsx | 13 +- src/pages/model/list/index.tsx | 39 ++-- src/service/events/pushBill.ts | 20 +- src/service/models/bill.ts | 4 +- src/service/models/model.ts | 60 +++--- src/service/tools/searchKb.ts | 47 ++++ src/service/utils/auth.ts | 87 +++++++- src/service/utils/sendNote.ts | 20 -- src/service/utils/tools.ts | 42 ---- src/types/model.d.ts | 6 +- src/types/mongoSchema.d.ts | 16 +- 30 files changed, 649 insertions(+), 757 deletions(-) rename src/pages/api/chat/{chatGpt.ts => chat.ts} (54%) delete mode 100644 src/pages/api/chat/vectorGpt.ts create mode 100644 src/pages/api/openapi/chat/chat.ts delete mode 100644 src/pages/model/list/components/CreateModel.tsx create mode 100644 src/service/tools/searchKb.ts diff --git a/src/api/model.ts b/src/api/model.ts index b4ec38617..05e0b1ca4 100644 --- a/src/api/model.ts +++ b/src/api/model.ts @@ -12,8 +12,7 @@ export const getMyModels = () => GET('/model/list'); /** * 创建一个模型 */ -export const postCreateModel = (data: { name: string; serviceModelName: string }) => - POST('/model/create', data); +export const postCreateModel = (data: { name: string }) => POST('/model/create', data); /** * 根据 ID 删除模型 diff --git a/src/api/response/chat.d.ts b/src/api/response/chat.d.ts index 21b5725bd..bc8906dc8 100644 --- a/src/api/response/chat.d.ts +++ b/src/api/response/chat.d.ts @@ -7,7 +7,6 @@ export type InitChatResponse = { name: string; avatar: string; intro: string; - chatModel: ModelSchema.service.chatModel; // 对话模型名 - modelName: ModelSchema.service.modelName; // 底层模型 + chatModel: ModelSchema['chat']['chatModel']; // 对话模型名 history: ChatItemType[]; }; diff --git a/src/constants/model.ts b/src/constants/model.ts index d71f0fe46..bf14d404a 100644 --- a/src/constants/model.ts +++ b/src/constants/model.ts @@ -1,50 +1,32 @@ import type { ModelSchema } from '@/types/mongoSchema'; export const embeddingModel = 'text-embedding-ada-002'; + export enum ChatModelEnum { 'GPT35' = 'gpt-3.5-turbo', 'GPT4' = 'gpt-4', 'GPT432k' = 'gpt-4-32k' } - -export enum ModelNameEnum { - GPT35 = 'gpt-3.5-turbo', - VECTOR_GPT = 'VECTOR_GPT' -} - -export const Model2ChatModelMap: Record<`${ModelNameEnum}`, `${ChatModelEnum}`> = { - [ModelNameEnum.GPT35]: 'gpt-3.5-turbo', - [ModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo' +export const ChatModelMap = { + // ui name + [ChatModelEnum.GPT35]: 'ChatGpt', + [ChatModelEnum.GPT4]: 'Gpt4', + [ChatModelEnum.GPT432k]: 'Gpt4-32k' }; -export type ModelConstantsData = { - icon: 'model' | 'dbModel'; - name: string; - model: `${ModelNameEnum}`; - trainName: string; // 空字符串代表不能训练 +export type ChatModelConstantType = { + chatModel: `${ChatModelEnum}`; contextMaxToken: number; maxTemperature: number; price: number; // 多少钱 / 1token,单位: 0.00001元 }; -export const modelList: ModelConstantsData[] = [ +export const modelList: ChatModelConstantType[] = [ { - icon: 'model', - name: 'chatGPT', - model: ModelNameEnum.GPT35, - trainName: '', + chatModel: ChatModelEnum.GPT35, contextMaxToken: 4096, maxTemperature: 1.5, price: 3 - }, - { - icon: 'dbModel', - name: '知识库', - model: ModelNameEnum.VECTOR_GPT, - trainName: 'vector', - contextMaxToken: 4096, - maxTemperature: 1, - price: 3 } ]; @@ -115,14 +97,16 @@ export const ModelVectorSearchModeMap: Record< export const defaultModel: ModelSchema = { _id: 'modelId', userId: 'userId', - name: 'modelName', + name: '模型名称', avatar: '/icon/logo.png', status: ModelStatusEnum.pending, updateTime: Date.now(), - systemPrompt: '', - temperature: 5, - search: { - mode: ModelVectorSearchModeEnum.hightSimilarity + chat: { + useKb: false, + searchMode: ModelVectorSearchModeEnum.hightSimilarity, + systemPrompt: '', + temperature: 0, + chatModel: ChatModelEnum.GPT35 }, share: { isShare: false, @@ -130,10 +114,6 @@ export const defaultModel: ModelSchema = { intro: '', collection: 0 }, - service: { - chatModel: ModelNameEnum.GPT35, - modelName: ModelNameEnum.GPT35 - }, security: { domain: ['*'], contextMaxLen: 1, diff --git a/src/pages/api/chat/chatGpt.ts b/src/pages/api/chat/chat.ts similarity index 54% rename from src/pages/api/chat/chatGpt.ts rename to src/pages/api/chat/chat.ts index b4c2c241b..48302dc47 100644 --- a/src/pages/api/chat/chatGpt.ts +++ b/src/pages/api/chat/chat.ts @@ -1,13 +1,14 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase } from '@/service/mongo'; import { getOpenAIApi, authChat } from '@/service/utils/auth'; -import { axiosConfig, openaiChatFilter } from '@/service/utils/tools'; +import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; import { ChatItemType } from '@/types/chat'; import { jsonRes } from '@/service/response'; import { PassThrough } from 'stream'; -import { modelList } from '@/constants/model'; +import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; import { gpt35StreamResponse } from '@/service/utils/openai'; +import { searchKb_openai } from '@/service/tools/searchKb'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -46,7 +47,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) authorization }); - const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); + const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel); if (!modelConstantsData) { throw new Error('模型加载异常'); } @@ -54,31 +55,84 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // 读取对话内容 const prompts = [...content, prompt]; - // 如果有系统提示词,自动插入 - if (model.systemPrompt) { - prompts.unshift({ - obj: 'SYSTEM', - value: model.systemPrompt + // 使用了知识库搜索 + if (model.chat.useKb) { + const { systemPrompts } = await searchKb_openai({ + apiKey: userApiKey || systemKey, + isPay: !userApiKey, + text: prompt.value, + similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22, + modelId, + userId }); + + // filter system prompt + if ( + systemPrompts.length === 0 && + model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity + ) { + return res.send('对不起,你的问题不在知识库中。'); + } + /* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */ + if ( + systemPrompts.length === 0 && + model.chat.searchMode === ModelVectorSearchModeEnum.noContext + ) { + prompts.unshift({ + obj: 'SYSTEM', + value: model.chat.systemPrompt + }); + } else { + // 有匹配情况下,system 添加知识库内容。 + // 系统提示词过滤,最多 2500 tokens + const filterSystemPrompt = systemPromptFilter({ + model: model.chat.chatModel, + prompts: systemPrompts, + maxTokens: 2500 + }); + + prompts.unshift({ + obj: 'SYSTEM', + value: ` + ${model.chat.systemPrompt} + ${ + model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity + ? `不回答知识库外的内容.` + : '' + } + 知识库内容为: ${filterSystemPrompt}' + ` + }); + } + } else { + // 没有用知识库搜索,仅用系统提示词 + if (model.chat.systemPrompt) { + prompts.unshift({ + obj: 'SYSTEM', + value: model.chat.systemPrompt + }); + } } - // 控制在 tokens 数量,防止超出 + // 控制总 tokens 数量,防止超出 const filterPrompts = openaiChatFilter({ - model: model.service.chatModel, + model: model.chat.chatModel, prompts, maxTokens: modelConstantsData.contextMaxToken - 500 }); // 计算温度 - const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); + const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed( + 2 + ); // console.log(filterPrompts); // 获取 chatAPI const chatAPI = getOpenAIApi(userApiKey || systemKey); // 发出请求 const chatResponse = await chatAPI.createChatCompletion( { - model: model.service.chatModel, - temperature, + model: model.chat.chatModel, + temperature: Number(temperature) || 0, messages: filterPrompts, frequency_penalty: 0.5, // 越大,重复内容越少 presence_penalty: -0.5, // 越大,越容易出现新内容 @@ -105,7 +159,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // 只有使用平台的 key 才计费 pushChatBill({ isPay: !userApiKey, - modelName: model.service.modelName, + chatModel: model.chat.chatModel, userId, chatId, messages: filterPrompts.concat({ role: 'assistant', content: responseContent }) diff --git a/src/pages/api/chat/init.ts b/src/pages/api/chat/init.ts index f44bc42da..ba1569007 100644 --- a/src/pages/api/chat/init.ts +++ b/src/pages/api/chat/init.ts @@ -59,8 +59,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) name: model.name, avatar: model.avatar, intro: model.share.intro, - modelName: model.service.modelName, - chatModel: model.service.chatModel, + chatModel: model.chat.chatModel, history } }); diff --git a/src/pages/api/chat/vectorGpt.ts b/src/pages/api/chat/vectorGpt.ts deleted file mode 100644 index a088dca27..000000000 --- a/src/pages/api/chat/vectorGpt.ts +++ /dev/null @@ -1,189 +0,0 @@ -import type { NextApiRequest, NextApiResponse } from 'next'; -import { connectToDatabase } from '@/service/mongo'; -import { authChat } from '@/service/utils/auth'; -import { axiosConfig, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools'; -import { ChatItemType } from '@/types/chat'; -import { jsonRes } from '@/service/response'; -import { PassThrough } from 'stream'; -import { - modelList, - ModelVectorSearchModeMap, - ModelVectorSearchModeEnum, - ModelDataStatusEnum -} from '@/constants/model'; -import { pushChatBill } from '@/service/events/pushBill'; -import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai'; -import dayjs from 'dayjs'; -import { PgClient } from '@/service/pg'; - -/* 发送提示词 */ -export default async function handler(req: NextApiRequest, res: NextApiResponse) { - let step = 0; // step=1时,表示开始了流响应 - const stream = new PassThrough(); - stream.on('error', () => { - console.log('error: ', 'stream error'); - stream.destroy(); - }); - res.on('close', () => { - stream.destroy(); - }); - res.on('error', () => { - console.log('error: ', 'request error'); - stream.destroy(); - }); - - try { - const { modelId, chatId, prompt } = req.body as { - modelId: string; - chatId: '' | string; - prompt: ChatItemType; - }; - - const { authorization } = req.headers; - if (!modelId || !prompt) { - throw new Error('缺少参数'); - } - - await connectToDatabase(); - let startTime = Date.now(); - - const { model, content, userApiKey, systemKey, userId } = await authChat({ - modelId, - chatId, - authorization - }); - - const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); - if (!modelConstantsData) { - throw new Error('模型加载异常'); - } - - // 读取对话内容 - const prompts = [...content, prompt]; - - // 获取提示词的向量 - const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({ - isPay: !userApiKey, - apiKey: userApiKey || systemKey, - userId, - text: prompt.value - }); - - // 相似度搜素 - const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22; - const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', { - fields: ['id', 'q', 'a'], - where: [ - ['status', ModelDataStatusEnum.ready], - 'AND', - ['model_id', model._id], - 'AND', - `vector <=> '[${promptVector}]' < ${similarity}` - ], - order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }], - limit: 20 - }); - - const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`); - - /* 高相似度+退出,无法匹配时直接退出 */ - if ( - formatRedisPrompt.length === 0 && - model.search.mode === ModelVectorSearchModeEnum.hightSimilarity - ) { - return res.send('对不起,你的问题不在知识库中。'); - } - /* 高相似度+无上下文,不添加额外知识 */ - if ( - formatRedisPrompt.length === 0 && - model.search.mode === ModelVectorSearchModeEnum.noContext - ) { - prompts.unshift({ - obj: 'SYSTEM', - value: model.systemPrompt - }); - } else { - // 有匹配情况下,system 添加知识库内容。 - // 系统提示词过滤,最多 2500 tokens - const systemPrompt = systemPromptFilter({ - model: model.service.chatModel, - prompts: formatRedisPrompt, - maxTokens: 2500 - }); - - prompts.unshift({ - obj: 'SYSTEM', - value: ` -${model.systemPrompt} -${ - model.search.mode === ModelVectorSearchModeEnum.hightSimilarity - ? `你只能从知识库选择内容回答.不在知识库内容拒绝回复` - : '' -} -知识库内容为: 当前时间为${dayjs().format('YYYY/MM/DD HH:mm:ss')}\n${systemPrompt}' -` - }); - } - - // 控制在 tokens 数量,防止超出 - const filterPrompts = openaiChatFilter({ - model: model.service.chatModel, - prompts, - maxTokens: modelConstantsData.contextMaxToken - 500 - }); - - // console.log(filterPrompts); - // 计算温度 - const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); - - // 发出请求 - const chatResponse = await chatAPI.createChatCompletion( - { - model: model.service.chatModel, - temperature, - messages: filterPrompts, - frequency_penalty: 0.5, // 越大,重复内容越少 - presence_penalty: -0.5, // 越大,越容易出现新内容 - stream: true, - stop: ['.!?。'] - }, - { - timeout: 40000, - responseType: 'stream', - ...axiosConfig() - } - ); - - console.log('api response time:', `${(Date.now() - startTime) / 1000}s`); - - step = 1; - - const { responseContent } = await gpt35StreamResponse({ - res, - stream, - chatResponse - }); - - // 只有使用平台的 key 才计费 - pushChatBill({ - isPay: !userApiKey, - modelName: model.service.modelName, - userId, - chatId, - messages: filterPrompts.concat({ role: 'assistant', content: responseContent }) - }); - // jsonRes(res); - } catch (err: any) { - if (step === 1) { - // 直接结束流 - console.log('error,结束'); - stream.destroy(); - } else { - res.status(500); - jsonRes(res, { - code: 500, - error: err - }); - } - } -} diff --git a/src/pages/api/model/create.ts b/src/pages/api/model/create.ts index ad2e08035..f62abf2eb 100644 --- a/src/pages/api/model/create.ts +++ b/src/pages/api/model/create.ts @@ -3,14 +3,13 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { connectToDatabase } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; -import { ModelStatusEnum, modelList, ModelNameEnum, Model2ChatModelMap } from '@/constants/model'; +import { ModelStatusEnum } from '@/constants/model'; import { Model } from '@/service/models/model'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { name, serviceModelName } = req.body as { + const { name } = req.body as { name: string; - serviceModelName: `${ModelNameEnum}`; }; const { authorization } = req.headers; @@ -18,45 +17,32 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< throw new Error('无权操作'); } - if (!name || !serviceModelName) { + if (!name) { throw new Error('缺少参数'); } // 凭证校验 const userId = await authToken(authorization); - const modelItem = modelList.find((item) => item.model === serviceModelName); - - if (!modelItem) { - throw new Error('模型不存在'); - } - await connectToDatabase(); // 上限校验 const authCount = await Model.countDocuments({ userId }); - if (authCount >= 20) { - throw new Error('上限 20 个模型'); + if (authCount >= 30) { + throw new Error('上限 30 个模型'); } // 创建模型 const response = await Model.create({ name, userId, - status: ModelStatusEnum.running, - service: { - chatModel: Model2ChatModelMap[modelItem.model], // 聊天时用的模型 - modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作 - } + status: ModelStatusEnum.running }); - // 根据 id 获取模型信息 - const model = await Model.findById(response._id); - jsonRes(res, { - data: model + data: response._id }); } catch (err) { jsonRes(res, { diff --git a/src/pages/api/model/update.ts b/src/pages/api/model/update.ts index 31ef0eb3e..5c9645ab9 100644 --- a/src/pages/api/model/update.ts +++ b/src/pages/api/model/update.ts @@ -9,8 +9,7 @@ import { authModel } from '@/service/utils/auth'; /* 获取我的模型 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { name, avatar, search, share, service, security, systemPrompt, temperature } = - req.body as ModelUpdateParams; + const { name, avatar, chat, share, security } = req.body as ModelUpdateParams; const { modelId } = req.query as { modelId: string }; const { authorization } = req.headers; @@ -18,7 +17,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< throw new Error('无权操作'); } - if (!name || !service || !security || !modelId) { + if (!name || !chat || !security || !modelId) { throw new Error('参数错误'); } @@ -41,12 +40,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< { name, avatar, - systemPrompt, - temperature, + chat, 'share.isShare': share.isShare, 'share.isShareDetail': share.isShareDetail, 'share.intro': share.intro, - search, security } ); diff --git a/src/pages/api/openapi/chat/chat.ts b/src/pages/api/openapi/chat/chat.ts new file mode 100644 index 000000000..d7461a65b --- /dev/null +++ b/src/pages/api/openapi/chat/chat.ts @@ -0,0 +1,202 @@ +import type { NextApiRequest, NextApiResponse } from 'next'; +import { connectToDatabase } from '@/service/mongo'; +import { getOpenAIApi, authOpenApiKey, authModel } from '@/service/utils/auth'; +import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; +import { ChatItemType } from '@/types/chat'; +import { jsonRes } from '@/service/response'; +import { PassThrough } from 'stream'; +import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model'; +import { pushChatBill } from '@/service/events/pushBill'; +import { gpt35StreamResponse } from '@/service/utils/openai'; +import { searchKb_openai } from '@/service/tools/searchKb'; + +/* 发送提示词 */ +export default async function handler(req: NextApiRequest, res: NextApiResponse) { + let step = 0; // step=1时,表示开始了流响应 + const stream = new PassThrough(); + stream.on('error', () => { + console.log('error: ', 'stream error'); + stream.destroy(); + }); + res.on('close', () => { + stream.destroy(); + }); + res.on('error', () => { + console.log('error: ', 'request error'); + stream.destroy(); + }); + + try { + const { + prompts, + modelId, + isStream = true + } = req.body as { + prompts: ChatItemType[]; + modelId: string; + isStream: boolean; + }; + + if (!prompts || !modelId) { + throw new Error('缺少参数'); + } + if (!Array.isArray(prompts)) { + throw new Error('prompts is not array'); + } + if (prompts.length > 30 || prompts.length === 0) { + throw new Error('prompts length range 1-30'); + } + + await connectToDatabase(); + let startTime = Date.now(); + + /* 凭证校验 */ + const { apiKey, userId } = await authOpenApiKey(req); + + const { model } = await authModel({ + userId, + modelId + }); + + const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel); + if (!modelConstantsData) { + throw new Error('模型加载异常'); + } + + // 使用了知识库搜索 + if (model.chat.useKb) { + const similarity = ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22; + + const { systemPrompts } = await searchKb_openai({ + apiKey, + isPay: true, + text: prompts[prompts.length - 1].value, + similarity, + modelId, + userId + }); + + // filter system prompt + if ( + systemPrompts.length === 0 && + model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity + ) { + return jsonRes(res, { + code: 500, + message: '对不起,你的问题不在知识库中。', + data: '对不起,你的问题不在知识库中。' + }); + } + /* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */ + if ( + systemPrompts.length === 0 && + model.chat.searchMode === ModelVectorSearchModeEnum.noContext + ) { + prompts.unshift({ + obj: 'SYSTEM', + value: model.chat.systemPrompt + }); + } else { + // 有匹配情况下,system 添加知识库内容。 + // 系统提示词过滤,最多 2500 tokens + const filterSystemPrompt = systemPromptFilter({ + model: model.chat.chatModel, + prompts: systemPrompts, + maxTokens: 2500 + }); + + prompts.unshift({ + obj: 'SYSTEM', + value: ` + ${model.chat.systemPrompt} + ${ + model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity + ? `不回答知识库外的内容.` + : '' + } + 知识库内容为: ${filterSystemPrompt}' + ` + }); + } + } else { + // 没有用知识库搜索,仅用系统提示词 + if (model.chat.systemPrompt) { + prompts.unshift({ + obj: 'SYSTEM', + value: model.chat.systemPrompt + }); + } + } + + // 控制总 tokens 数量,防止超出 + const filterPrompts = openaiChatFilter({ + model: model.chat.chatModel, + prompts, + maxTokens: modelConstantsData.contextMaxToken - 500 + }); + + // 计算温度 + const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed( + 2 + ); + // console.log(filterPrompts); + // 获取 chatAPI + const chatAPI = getOpenAIApi(apiKey); + // 发出请求 + const chatResponse = await chatAPI.createChatCompletion( + { + model: model.chat.chatModel, + temperature: Number(temperature) || 0, + messages: filterPrompts, + frequency_penalty: 0.5, // 越大,重复内容越少 + presence_penalty: -0.5, // 越大,越容易出现新内容 + stream: isStream, + stop: ['.!?。'] + }, + { + timeout: 180000, + responseType: isStream ? 'stream' : 'json', + ...axiosConfig() + } + ); + + console.log('api response time:', `${(Date.now() - startTime) / 1000}s`); + + let responseContent = ''; + + if (isStream) { + step = 1; + const streamResponse = await gpt35StreamResponse({ + res, + stream, + chatResponse + }); + responseContent = streamResponse.responseContent; + } else { + responseContent = chatResponse.data.choices?.[0]?.message?.content || ''; + jsonRes(res, { + data: responseContent + }); + } + + // 只有使用平台的 key 才计费 + pushChatBill({ + isPay: true, + chatModel: model.chat.chatModel, + userId, + messages: filterPrompts.concat({ role: 'assistant', content: responseContent }) + }); + } catch (err: any) { + if (step === 1) { + // 直接结束流 + console.log('error,结束'); + stream.destroy(); + } else { + res.status(500); + jsonRes(res, { + code: 500, + error: err + }); + } + } +} diff --git a/src/pages/api/openapi/chat/chatGpt.ts b/src/pages/api/openapi/chat/chatGpt.ts index e346021d7..080fb6577 100644 --- a/src/pages/api/openapi/chat/chatGpt.ts +++ b/src/pages/api/openapi/chat/chatGpt.ts @@ -1,7 +1,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase, Model } from '@/service/mongo'; -import { getOpenAIApi } from '@/service/utils/auth'; -import { axiosConfig, openaiChatFilter, authOpenApiKey } from '@/service/utils/tools'; +import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth'; +import { axiosConfig, openaiChatFilter } from '@/service/utils/tools'; import { ChatItemType } from '@/types/chat'; import { jsonRes } from '@/service/response'; import { PassThrough } from 'stream'; @@ -60,37 +60,38 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) throw new Error('无权使用该模型'); } - const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); + const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel); if (!modelConstantsData) { throw new Error('模型加载异常'); } // 如果有系统提示词,自动插入 - if (model.systemPrompt) { + if (model.chat.systemPrompt) { prompts.unshift({ obj: 'SYSTEM', - value: model.systemPrompt + value: model.chat.systemPrompt }); } // 控制在 tokens 数量,防止超出 const filterPrompts = openaiChatFilter({ - model: model.service.chatModel, + model: model.chat.chatModel, prompts, maxTokens: modelConstantsData.contextMaxToken - 500 }); // console.log(filterPrompts); // 计算温度 - const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); - + const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed( + 2 + ); // 获取 chatAPI const chatAPI = getOpenAIApi(apiKey); // 发出请求 const chatResponse = await chatAPI.createChatCompletion( { - model: model.service.chatModel, - temperature, + model: model.chat.chatModel, + temperature: Number(temperature) || 0, messages: filterPrompts, frequency_penalty: 0.5, // 越大,重复内容越少 presence_penalty: -0.5, // 越大,越容易出现新内容 @@ -126,7 +127,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // 只有使用平台的 key 才计费 pushChatBill({ isPay: true, - modelName: model.service.modelName, + chatModel: model.chat.chatModel, userId, messages: filterPrompts.concat({ role: 'assistant', content: responseContent }) }); diff --git a/src/pages/api/openapi/chat/lafGpt.ts b/src/pages/api/openapi/chat/lafGpt.ts index 549cb08cd..a48798942 100644 --- a/src/pages/api/openapi/chat/lafGpt.ts +++ b/src/pages/api/openapi/chat/lafGpt.ts @@ -1,20 +1,14 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase, Model } from '@/service/mongo'; -import { getOpenAIApi } from '@/service/utils/auth'; -import { authOpenApiKey } from '@/service/utils/tools'; +import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth'; import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; import { ChatItemType } from '@/types/chat'; import { jsonRes } from '@/service/response'; import { PassThrough } from 'stream'; -import { - ModelNameEnum, - modelList, - ModelVectorSearchModeMap, - ChatModelEnum -} from '@/constants/model'; +import { modelList, ModelVectorSearchModeMap, ChatModelEnum } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; -import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai'; -import { PgClient } from '@/service/pg'; +import { gpt35StreamResponse } from '@/service/utils/openai'; +import { searchKb_openai } from '@/service/tools/searchKb'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -59,10 +53,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) throw new Error('找不到模型'); } - const modelConstantsData = modelList.find((item) => item.model === ModelNameEnum.VECTOR_GPT); + const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel); if (!modelConstantsData) { - throw new Error('模型已下架'); + throw new Error('model is undefined'); } + console.log('laf gpt start'); // 获取 chatAPI @@ -132,62 +127,48 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) prompt.value += ` ${promptResolve}`; console.log('prompt resolve success, time:', `${(Date.now() - startTime) / 1000}s`); - // 获取提示词的向量 - const { vector: promptVector } = await openaiCreateEmbedding({ - isPay: true, - apiKey, - userId, - text: prompt.value - }); - // 读取对话内容 const prompts = [prompt]; - // 相似度搜索 - const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22; - const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', { - fields: ['id', 'q', 'a'], - order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }], - where: [ - ['model_id', model._id], - 'AND', - ['user_id', userId], - 'AND', - `vector <=> '[${promptVector}]' < ${similarity}` - ], - limit: 30 + // 获取向量匹配到的提示词 + const { systemPrompts } = await searchKb_openai({ + isPay: true, + apiKey, + similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22, + text: prompt.value, + modelId, + userId }); - const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`); - // system 筛选,最多 2500 tokens - const systemPrompt = systemPromptFilter({ - model: model.service.chatModel, - prompts: formatRedisPrompt, + const filterSystemPrompt = systemPromptFilter({ + model: model.chat.chatModel, + prompts: systemPrompts, maxTokens: 2500 }); prompts.unshift({ obj: 'SYSTEM', - value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:${systemPrompt}` + value: `${model.chat.systemPrompt} 知识库是最新的,下面是知识库内容:${filterSystemPrompt}` }); // 控制上下文 tokens 数量,防止超出 const filterPrompts = openaiChatFilter({ - model: model.service.chatModel, + model: model.chat.chatModel, prompts, maxTokens: modelConstantsData.contextMaxToken - 500 }); // console.log(filterPrompts); // 计算温度 - const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); - + const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed( + 2 + ); // 发出请求 const chatResponse = await chatAPI.createChatCompletion( { - model: model.service.chatModel, - temperature, + model: model.chat.chatModel, + temperature: Number(temperature) || 0, messages: filterPrompts, frequency_penalty: 0.5, // 越大,重复内容越少 presence_penalty: -0.5, // 越大,越容易出现新内容 @@ -223,7 +204,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) pushChatBill({ isPay: true, - modelName: model.service.modelName, + chatModel: model.chat.chatModel, userId, messages: filterPrompts.concat({ role: 'assistant', content: responseContent }) }); diff --git a/src/pages/api/openapi/chat/vectorGpt.ts b/src/pages/api/openapi/chat/vectorGpt.ts index 1b9997ddf..9431d538b 100644 --- a/src/pages/api/openapi/chat/vectorGpt.ts +++ b/src/pages/api/openapi/chat/vectorGpt.ts @@ -1,24 +1,14 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase, Model } from '@/service/mongo'; -import { - axiosConfig, - systemPromptFilter, - authOpenApiKey, - openaiChatFilter -} from '@/service/utils/tools'; +import { axiosConfig, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools'; +import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth'; import { ChatItemType } from '@/types/chat'; import { jsonRes } from '@/service/response'; import { PassThrough } from 'stream'; -import { - modelList, - ModelVectorSearchModeMap, - ModelVectorSearchModeEnum, - ModelDataStatusEnum -} from '@/constants/model'; +import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; -import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai'; -import dayjs from 'dayjs'; -import { PgClient } from '@/service/pg'; +import { gpt35StreamResponse } from '@/service/utils/openai'; +import { searchKb_openai } from '@/service/tools/searchKb'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -72,96 +62,86 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) throw new Error('无权使用该模型'); } - const modelConstantsData = modelList.find((item) => item.model === model?.service?.modelName); + const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel); if (!modelConstantsData) { throw new Error('模型初始化异常'); } - // 获取提示词的向量 - const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({ + // 获取向量匹配到的提示词 + const { systemPrompts } = await searchKb_openai({ isPay: true, apiKey, - userId, - text: prompts[prompts.length - 1].value // 取最后一个 + similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22, + text: prompts[prompts.length - 1].value, + modelId, + userId }); - // 相似度搜素 - const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22; - const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', { - fields: ['id', 'q', 'a'], - where: [ - ['status', ModelDataStatusEnum.ready], - 'AND', - ['model_id', model._id], - 'AND', - `vector <=> '[${promptVector}]' < ${similarity}` - ], - order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }], - limit: 20 - }); - - const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`); - // system 合并 if (prompts[0].obj === 'SYSTEM') { - formatRedisPrompt.unshift(prompts.shift()?.value || ''); + systemPrompts.unshift(prompts.shift()?.value || ''); } /* 高相似度+退出,无法匹配时直接退出 */ if ( - formatRedisPrompt.length === 0 && - model.search.mode === ModelVectorSearchModeEnum.hightSimilarity + systemPrompts.length === 0 && + model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity ) { - return res.send('对不起,你的问题不在知识库中。'); + return jsonRes(res, { + code: 500, + message: '对不起,你的问题不在知识库中。', + data: '对不起,你的问题不在知识库中。' + }); } /* 高相似度+无上下文,不添加额外知识 */ if ( - formatRedisPrompt.length === 0 && - model.search.mode === ModelVectorSearchModeEnum.noContext + systemPrompts.length === 0 && + model.chat.searchMode === ModelVectorSearchModeEnum.noContext ) { prompts.unshift({ obj: 'SYSTEM', - value: model.systemPrompt + value: model.chat.systemPrompt }); } else { // 有匹配或者低匹配度模式情况下,添加知识库内容。 // 系统提示词过滤,最多 2500 tokens const systemPrompt = systemPromptFilter({ - model: model.service.chatModel, - prompts: formatRedisPrompt, + model: model.chat.chatModel, + prompts: systemPrompts, maxTokens: 2500 }); prompts.unshift({ obj: 'SYSTEM', value: ` -${model.systemPrompt} +${model.chat.systemPrompt} ${ - model.search.mode === ModelVectorSearchModeEnum.hightSimilarity - ? `你只能从知识库选择内容回答.不在知识库内容拒绝回复` - : '' + model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity ? `不回答知识库外的内容.` : '' } -知识库内容为: 当前时间为${dayjs().format('YYYY/MM/DD HH:mm:ss')}\n${systemPrompt}' +知识库内容为: ${systemPrompt}' ` }); } // 控制在 tokens 数量,防止超出 const filterPrompts = openaiChatFilter({ - model: model.service.chatModel, + model: model.chat.chatModel, prompts, maxTokens: modelConstantsData.contextMaxToken - 500 }); // console.log(filterPrompts); // 计算温度 - const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); + const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed( + 2 + ); + const chatAPI = getOpenAIApi(apiKey); // 发出请求 const chatResponse = await chatAPI.createChatCompletion( { - model: model.service.chatModel, - temperature, + model: model.chat.chatModel, + temperature: Number(temperature) || 0, messages: filterPrompts, frequency_penalty: 0.5, // 越大,重复内容越少 presence_penalty: -0.5, // 越大,越容易出现新内容 @@ -196,7 +176,7 @@ ${ pushChatBill({ isPay: true, - modelName: model.service.modelName, + chatModel: model.chat.chatModel, userId, messages: filterPrompts.concat({ role: 'assistant', content: responseContent }) }); diff --git a/src/pages/chat/components/SlideBar.tsx b/src/pages/chat/components/SlideBar.tsx index 303deb127..7bdf0ac7b 100644 --- a/src/pages/chat/components/SlideBar.tsx +++ b/src/pages/chat/components/SlideBar.tsx @@ -52,7 +52,7 @@ const SlideBar = ({ const myModelList = myModels.map((item) => ({ id: item._id, name: item.name, - icon: modelList.find((model) => model.model === item?.service?.modelName)?.icon || 'model' + icon: 'model' as any })); const collectionList = collectionModels .map((item) => ({ diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index 7bee7fd1a..759ffdcd7 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -1,6 +1,5 @@ import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react'; import { useRouter } from 'next/router'; -import Image from 'next/image'; import { getInitChatSiteInfo, delChatRecordByIndex, postSaveChat } from '@/api/chat'; import type { InitChatResponse } from '@/api/response/chat'; import type { ChatItemType } from '@/types/chat'; @@ -16,12 +15,13 @@ import { Menu, MenuButton, MenuList, - MenuItem + MenuItem, + Image } from '@chakra-ui/react'; import { useToast } from '@/hooks/useToast'; import { useScreen } from '@/hooks/useScreen'; import { useQuery } from '@tanstack/react-query'; -import { ModelNameEnum } from '@/constants/model'; +import { ChatModelEnum } from '@/constants/model'; import dynamic from 'next/dynamic'; import { useGlobalStore } from '@/store/global'; import { useCopyData } from '@/utils/tools'; @@ -65,8 +65,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { name: '', avatar: '/icon/logo.png', intro: '', - chatModel: '', - modelName: '', + chatModel: ChatModelEnum.GPT35, history: [] }); // 聊天框整体数据 @@ -193,13 +192,6 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { // gpt 对话 const gptChatPrompt = useCallback( async (prompts: ChatSiteItemType) => { - const urlMap: Record = { - [ModelNameEnum.GPT35]: '/api/chat/chatGpt', - [ModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt' - }; - - if (!urlMap[chatData.modelName]) return Promise.reject('找不到模型'); - // create abort obj const abortSignal = new AbortController(); controller.current = abortSignal; @@ -212,7 +204,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { // 流请求,获取数据 const responseText = await streamFetch({ - url: urlMap[chatData.modelName], + url: '/api/chat/chat', data: { prompt, chatId, @@ -278,7 +270,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { }) })); }, - [chatData.modelName, chatId, generatingMessage, modelId, router, toast] + [chatId, generatingMessage, modelId, router, toast] ); /** @@ -393,7 +385,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { // 更新流中断对象 useEffect(() => { return () => { - // eslint-disable-next-line react-hooks/exhaustive-deps + isResetPage.current = true; controller.current?.abort(); }; }, []); @@ -476,8 +468,9 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { : chatData.avatar || '/icon/logo.png' } alt="avatar" - width={media(30, 20)} - height={media(30, 20)} + w={['20px', '30px']} + maxH={'50px'} + objectFit={'contain'} /> diff --git a/src/pages/model/detail/components/ModelDataCard.tsx b/src/pages/model/detail/components/ModelDataCard.tsx index c4ca40137..edaef086d 100644 --- a/src/pages/model/detail/components/ModelDataCard.tsx +++ b/src/pages/model/detail/components/ModelDataCard.tsx @@ -45,9 +45,10 @@ const ModelDataCard = ({ modelId, isOwner }: { modelId: string; isOwner: boolean const [searchText, setSearchText] = useState(''); const tdStyles = useRef({ fontSize: 'xs', + minW: '150px', maxW: '500px', - whiteSpace: 'pre-wrap', maxH: '250px', + whiteSpace: 'pre-wrap', overflowY: 'auto' }); const { @@ -132,7 +133,7 @@ const ModelDataCard = ({ modelId, isOwner }: { modelId: string; isOwner: boolean <> - 模型数据: {total}组 + 知识库数据: {total}组 {isOwner && ( <> diff --git a/src/pages/model/detail/components/ModelEditForm.tsx b/src/pages/model/detail/components/ModelEditForm.tsx index 46c8e78d5..4c17f8f57 100644 --- a/src/pages/model/detail/components/ModelEditForm.tsx +++ b/src/pages/model/detail/components/ModelEditForm.tsx @@ -21,7 +21,7 @@ import { import { QuestionOutlineIcon } from '@chakra-ui/icons'; import type { ModelSchema } from '@/types/mongoSchema'; import { UseFormReturn } from 'react-hook-form'; -import { modelList, ModelVectorSearchModeMap } from '@/constants/model'; +import { ChatModelMap, modelList, ModelVectorSearchModeMap } from '@/constants/model'; import { formatPrice } from '@/utils/user'; import { useConfirm } from '@/hooks/useConfirm'; import { useSelectFile } from '@/hooks/useSelectFile'; @@ -30,12 +30,10 @@ import { fileToBase64 } from '@/utils/file'; const ModelEditForm = ({ formHooks, - canTrain, isOwner, handleDelModel }: { formHooks: UseFormReturn; - canTrain: boolean; isOwner: boolean; handleDelModel: () => void; }) => { @@ -73,6 +71,12 @@ const ModelEditForm = ({ <> 基本信息 + + + modelId: + + {getValues('_id')} + 头像: @@ -101,17 +105,12 @@ const ModelEditForm = ({ > + - modelId: + 对话模型: - {getValues('_id')} - - - - 模型类型: - - {modelList.find((item) => item.model === getValues('service.modelName'))?.name} + {ChatModelMap[getValues('chat.chatModel')]} @@ -119,7 +118,7 @@ const ModelEditForm = ({ {formatPrice( - modelList.find((item) => item.model === getValues('service.modelName'))?.price || 0, + modelList.find((item) => item.chatModel === getValues('chat.chatModel'))?.price || 0, 1000 )} 元/1K tokens(包括上下文和回答) @@ -163,15 +162,15 @@ const ModelEditForm = ({ min={0} max={10} step={1} - value={getValues('temperature')} + value={getValues('chat.temperature')} isDisabled={!isOwner} onChange={(e) => { - setValue('temperature', e); + setValue('chat.temperature', e); setRefresh(!refresh); }} > - {getValues('temperature')} + {getValues('chat.temperature')} @@ -190,35 +189,42 @@ const ModelEditForm = ({ - {canTrain && ( - - - 搜索模式 - - - + + 知识库搜索 + { + setValue('chat.useKb', !getValues('chat.useKb')); + setRefresh(!refresh); + }} + /> + + {getValues('chat.useKb') && ( + + + 搜索模式  + + + )} + 系统提示词