From 52d00d0562445af859e3360849371a275cffe61b Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Sat, 8 Apr 2023 20:27:43 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E7=9F=A5=E8=AF=86=E5=BA=93=E5=AF=B9?= =?UTF-8?q?=E5=A4=96api?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/components/Layout/index.tsx | 12 +- src/pages/api/chat/vectorGpt.ts | 4 +- src/pages/api/openapi/chat/lafGpt.ts | 2 +- src/pages/api/openapi/chat/vectorGpt.ts | 210 ++++++++++++++++++++++++ src/pages/chat/index.tsx | 9 +- src/service/errorCode.ts | 3 +- src/service/response.ts | 7 +- 7 files changed, 230 insertions(+), 17 deletions(-) create mode 100644 src/pages/api/openapi/chat/vectorGpt.ts diff --git a/src/components/Layout/index.tsx b/src/components/Layout/index.tsx index 6c5f6f5c7..3f3afdf44 100644 --- a/src/components/Layout/index.tsx +++ b/src/components/Layout/index.tsx @@ -31,13 +31,13 @@ const navbarList = [ icon: 'user', link: '/number/setting', activeLink: ['/number/setting'] + }, + { + label: '开发', + icon: 'develop', + link: '/openapi', + activeLink: ['/openapi'] } - // { - // label: '开发', - // icon: 'develop', - // link: '/openapi', - // activeLink: ['/openapi'] - // } ]; const Layout = ({ children }: { children: JSX.Element }) => { diff --git a/src/pages/api/chat/vectorGpt.ts b/src/pages/api/chat/vectorGpt.ts index 2494d6fda..64a70d16d 100644 --- a/src/pages/api/chat/vectorGpt.ts +++ b/src/pages/api/chat/vectorGpt.ts @@ -82,14 +82,14 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) vectorToBuffer(promptVector), 'LIMIT', '0', - '20', + '30', 'DIALECT', '2' ]); const formatRedisPrompt: string[] = []; // 格式化响应值,获取 qa - for (let i = 2; i < 42; i += 2) { + for (let i = 2; i < 61; i += 2) { const text = redisData[i]?.[1]; if (text) { formatRedisPrompt.push(text); diff --git a/src/pages/api/openapi/chat/lafGpt.ts b/src/pages/api/openapi/chat/lafGpt.ts index ad0e1e841..6c3bd830d 100644 --- a/src/pages/api/openapi/chat/lafGpt.ts +++ b/src/pages/api/openapi/chat/lafGpt.ts @@ -126,7 +126,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // 获取提示词的向量 const { vector: promptVector } = await openaiCreateEmbedding({ isPay: true, - apiKey: apiKey, + apiKey, userId, text: prompt.value }); diff --git a/src/pages/api/openapi/chat/vectorGpt.ts b/src/pages/api/openapi/chat/vectorGpt.ts new file mode 100644 index 000000000..934131379 --- /dev/null +++ b/src/pages/api/openapi/chat/vectorGpt.ts @@ -0,0 +1,210 @@ +import type { NextApiRequest, NextApiResponse } from 'next'; +import { connectToDatabase, Model } from '@/service/mongo'; +import { + httpsAgent, + openaiChatFilter, + systemPromptFilter, + authOpenApiKey +} from '@/service/utils/tools'; +import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; +import { ChatItemType } from '@/types/chat'; +import { jsonRes } from '@/service/response'; +import { PassThrough } from 'stream'; +import { modelList } from '@/constants/model'; +import { pushChatBill } from '@/service/events/pushBill'; +import { connectRedis } from '@/service/redis'; +import { VecModelDataPrefix } from '@/constants/redis'; +import { vectorToBuffer } from '@/utils/tools'; +import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai'; + +/* 发送提示词 */ +export default async function handler(req: NextApiRequest, res: NextApiResponse) { + let step = 0; // step=1时,表示开始了流响应 + const stream = new PassThrough(); + stream.on('error', () => { + console.log('error: ', 'stream error'); + stream.destroy(); + }); + res.on('close', () => { + stream.destroy(); + }); + res.on('error', () => { + console.log('error: ', 'request error'); + stream.destroy(); + }); + + try { + const { + prompts, + modelId, + isStream = true + } = req.body as { + prompts: ChatItemType[]; + modelId: string; + isStream: boolean; + }; + + if (!prompts || !modelId) { + throw new Error('缺少参数'); + } + if (!Array.isArray(prompts)) { + throw new Error('prompts is not array'); + } + if (prompts.length > 30 || prompts.length === 0) { + throw new Error('prompts length range 1-30'); + } + + await connectToDatabase(); + const redis = await connectRedis(); + let startTime = Date.now(); + + /* 凭证校验 */ + const { apiKey, userId } = await authOpenApiKey(req); + + const model = await Model.findOne({ + _id: modelId, + userId + }); + + if (!model) { + throw new Error('无权使用该模型'); + } + + const modelConstantsData = modelList.find((item) => item.model === model?.service?.modelName); + if (!modelConstantsData) { + throw new Error('模型初始化异常'); + } + + // 获取提示词的向量 + const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({ + isPay: true, + apiKey, + userId, + text: prompts[prompts.length - 1].value // 取最后一个 + }); + + // 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text + const redisData: any[] = await redis.sendCommand([ + 'FT.SEARCH', + `idx:${VecModelDataPrefix}:hash`, + `@modelId:{${modelId}} @vector:[VECTOR_RANGE 0.24 $blob]=>{$YIELD_DISTANCE_AS: score}`, + 'RETURN', + '1', + 'text', + 'SORTBY', + 'score', + 'PARAMS', + '2', + 'blob', + vectorToBuffer(promptVector), + 'LIMIT', + '0', + '30', + 'DIALECT', + '2' + ]); + + const formatRedisPrompt: string[] = []; + + // 格式化响应值,获取 qa + for (let i = 2; i < 61; i += 2) { + const text = redisData[i]?.[1]; + if (text) { + formatRedisPrompt.push(text); + } + } + + if (formatRedisPrompt.length === 0) { + throw new Error('对不起,我没有找到你的问题'); + } + + // system 合并 + if (prompts[0].obj === 'SYSTEM') { + formatRedisPrompt.unshift(prompts.shift()?.value || ''); + } + + // textArr 筛选,最多 2800 tokens + const systemPrompt = systemPromptFilter(formatRedisPrompt, 2800); + + prompts.unshift({ + obj: 'SYSTEM', + value: `${model.systemPrompt} 知识库内容是最新的,知识库内容为: "${systemPrompt}"` + }); + + // 控制在 tokens 数量,防止超出 + const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken); + + // 格式化文本内容成 chatgpt 格式 + const map = { + Human: ChatCompletionRequestMessageRoleEnum.User, + AI: ChatCompletionRequestMessageRoleEnum.Assistant, + SYSTEM: ChatCompletionRequestMessageRoleEnum.System + }; + const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map( + (item: ChatItemType) => ({ + role: map[item.obj], + content: item.value + }) + ); + // console.log(formatPrompts); + // 计算温度 + const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); + + // 发出请求 + const chatResponse = await chatAPI.createChatCompletion( + { + model: model.service.chatModel, + temperature: temperature, + messages: formatPrompts, + frequency_penalty: 0.5, // 越大,重复内容越少 + presence_penalty: -0.5, // 越大,越容易出现新内容 + stream: isStream + }, + { + timeout: 120000, + responseType: isStream ? 'stream' : 'json', + httpsAgent + } + ); + + console.log('api response time:', `${(Date.now() - startTime) / 1000}s`); + + step = 1; + let responseContent = ''; + + if (isStream) { + const streamResponse = await gpt35StreamResponse({ + res, + stream, + chatResponse + }); + responseContent = streamResponse.responseContent; + } else { + responseContent = chatResponse.data.choices?.[0]?.message?.content || ''; + jsonRes(res, { + data: responseContent + }); + } + + const promptsContent = formatPrompts.map((item) => item.content).join(''); + pushChatBill({ + isPay: true, + modelName: model.service.modelName, + userId, + text: promptsContent + responseContent + }); + // jsonRes(res); + } catch (err: any) { + if (step === 1) { + // 直接结束流 + console.log('error,结束'); + stream.destroy(); + } else { + res.status(500); + jsonRes(res, { + code: 500, + error: err + }); + } + } +} diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index ffa171aee..8ae7a6286 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -300,10 +300,9 @@ const Chat = ({ chatId }: { chatId: string }) => { // 复制内容 const onclickCopy = useCallback( - (chatId: string) => { - const dom = document.getElementById(chatId); - const innerText = dom?.innerText; - innerText && copyData(innerText); + (value: string) => { + const val = value.replace(/\n+/g, '\n'); + copyData(val); }, [copyData] ); @@ -434,7 +433,7 @@ const Chat = ({ chatId }: { chatId: string }) => { /> - onclickCopy(`chat${index}`)}>复制 + onclickCopy(item.value)}>复制 delChatRecord(index)}>删除该行 diff --git a/src/service/errorCode.ts b/src/service/errorCode.ts index aba2cb5d8..73d087ff2 100644 --- a/src/service/errorCode.ts +++ b/src/service/errorCode.ts @@ -7,7 +7,8 @@ export const openaiError: Record = { 'Bad Gateway': '网关异常,请重试' }; export const openaiError2: Record = { - insufficient_quota: 'API 余额不足' + insufficient_quota: 'API 余额不足', + invalid_request_error: '输入参数异常' }; export const proxyError: Record = { ECONNABORTED: true, diff --git a/src/service/response.ts b/src/service/response.ts index f1b8ca27e..4897f89de 100644 --- a/src/service/response.ts +++ b/src/service/response.ts @@ -25,8 +25,11 @@ export const jsonRes = ( msg = error; } else if (proxyError[error?.code]) { msg = '服务器代理出错'; - } else if (openaiError2[error?.response?.data?.error?.type]) { - msg = openaiError2[error?.response?.data?.error?.type]; + } else if (error?.response?.data?.error) { + msg = + openaiError2[error?.response?.data?.error?.type] || + error?.response?.data?.error?.message || + 'openai 错误'; } else if (openaiError[error?.response?.statusText]) { msg = openaiError[error.response.statusText]; }