From 56dab7abba23dbdd34d1612194892e251bbbdc9f Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Fri, 31 Mar 2023 11:20:45 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20api=E8=B0=83=E7=94=A8=E5=92=8C=E4=BD=99?= =?UTF-8?q?=E9=A2=9D=E6=A0=A1=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pages/api/model/data/splitData.ts | 8 ++- src/pages/api/model/del.ts | 4 +- src/pages/api/model/train/putTrainStatus.ts | 4 +- src/pages/api/model/train/train.ts | 6 +- src/pages/api/timer/updateTraining.ts | 4 +- src/service/events/generateAbstract.ts | 15 +++-- src/service/events/generateQA.ts | 17 +++-- src/service/events/generateVector.ts | 4 ++ src/service/utils/chat.ts | 33 +++------- src/service/utils/tools.ts | 71 +++++++++++++++++---- 10 files changed, 104 insertions(+), 62 deletions(-) diff --git a/src/pages/api/model/data/splitData.ts b/src/pages/api/model/data/splitData.ts index 0f0e846ff..cb262db20 100644 --- a/src/pages/api/model/data/splitData.ts +++ b/src/pages/api/model/data/splitData.ts @@ -45,6 +45,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } }); + if (splitText) { + textList.push(splitText); + } + // 批量插入数据 await SplitData.create({ userId, @@ -55,9 +59,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) generateQA(); - jsonRes(res, { - data: { chunks, replaceText } - }); + jsonRes(res); } catch (err) { jsonRes(res, { code: 500, diff --git a/src/pages/api/model/del.ts b/src/pages/api/model/del.ts index a79610237..2d6e1729f 100644 --- a/src/pages/api/model/del.ts +++ b/src/pages/api/model/del.ts @@ -1,7 +1,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo'; -import { authToken, getUserOpenaiKey } from '@/service/utils/tools'; +import { authToken, getUserApiOpenai } from '@/service/utils/tools'; import { TrainingStatusEnum } from '@/constants/model'; import { getOpenAIApi } from '@/service/utils/chat'; import { TrainingItemType } from '@/types/training'; @@ -67,7 +67,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< // 如果正在训练,需要删除openai上的相关信息 if (training) { - const openai = getOpenAIApi(await getUserOpenaiKey(userId)); + const { openai } = await getUserApiOpenai(userId); // 获取训练记录 const tuneRecord = await openai.retrieveFineTune(training.tuneId, { httpsAgent }); diff --git a/src/pages/api/model/train/putTrainStatus.ts b/src/pages/api/model/train/putTrainStatus.ts index b754705df..922bab885 100644 --- a/src/pages/api/model/train/putTrainStatus.ts +++ b/src/pages/api/model/train/putTrainStatus.ts @@ -2,7 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { connectToDatabase, Model, Training } from '@/service/mongo'; import { getOpenAIApi } from '@/service/utils/chat'; -import { authToken, getUserOpenaiKey } from '@/service/utils/tools'; +import { authToken, getUserApiOpenai } from '@/service/utils/tools'; import type { ModelSchema } from '@/types/mongoSchema'; import { TrainingItemType } from '@/types/training'; import { ModelStatusEnum, TrainingStatusEnum } from '@/constants/model'; @@ -43,7 +43,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } // 用户的 openai 实例 - const openai = getOpenAIApi(await getUserOpenaiKey(userId)); + const { openai } = await getUserApiOpenai(userId); // 获取 openai 的训练情况 const { data } = await openai.retrieveFineTune(training.tuneId, { httpsAgent }); diff --git a/src/pages/api/model/train/train.ts b/src/pages/api/model/train/train.ts index 9efe71341..8e2e76cab 100644 --- a/src/pages/api/model/train/train.ts +++ b/src/pages/api/model/train/train.ts @@ -2,9 +2,8 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { connectToDatabase, Model, Training } from '@/service/mongo'; -import { getOpenAIApi } from '@/service/utils/chat'; import formidable from 'formidable'; -import { authToken, getUserOpenaiKey } from '@/service/utils/tools'; +import { authToken, getUserApiOpenai } from '@/service/utils/tools'; import { join } from 'path'; import fs from 'fs'; import type { ModelSchema } from '@/types/mongoSchema'; @@ -49,7 +48,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const trainingType = model.service.trainId; // 目前都默认是 openai text-davinci-03 // 获取用户的 API Key 实例化后的对象 - openai = getOpenAIApi(await getUserOpenaiKey(userId)); + const user = await getUserApiOpenai(userId); + openai = user.openai; // 接收文件并保存 const form = formidable({ diff --git a/src/pages/api/timer/updateTraining.ts b/src/pages/api/timer/updateTraining.ts index 63f1c9535..dc418f704 100644 --- a/src/pages/api/timer/updateTraining.ts +++ b/src/pages/api/timer/updateTraining.ts @@ -5,7 +5,7 @@ import { connectToDatabase, Training, Model } from '@/service/mongo'; import type { TrainingItemType } from '@/types/training'; import { TrainingStatusEnum, ModelStatusEnum } from '@/constants/model'; import { getOpenAIApi } from '@/service/utils/chat'; -import { getUserOpenaiKey } from '@/service/utils/tools'; +import { getUserApiOpenai } from '@/service/utils/tools'; import { OpenAiTuneStatusEnum } from '@/service/constants/training'; import { sendTrainSucceed } from '@/service/utils/sendEmail'; import { httpsAgent } from '@/service/utils/tools'; @@ -23,7 +23,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) status: TrainingStatusEnum.pending }); - const openai = getOpenAIApi(await getUserOpenaiKey('63f9a14228d2a688d8dc9e1b')); + const { openai } = await getUserApiOpenai('63f9a14228d2a688d8dc9e1b'); const response = await Promise.all( trainingRecords.map(async (item) => { diff --git a/src/service/events/generateAbstract.ts b/src/service/events/generateAbstract.ts index c865152c8..61847063e 100644 --- a/src/service/events/generateAbstract.ts +++ b/src/service/events/generateAbstract.ts @@ -43,12 +43,15 @@ export async function generateAbstract(next = false): Promise { const key = await getOpenApiKey(dataItem.userId); userApiKey = key.userApiKey; systemKey = key.systemKey; - } catch (error) { - // 余额不够了, 把用户所有记录改成闲置 - await DataItem.updateMany({ - userId: dataItem.userId, - status: 0 - }); + } catch (error: any) { + if (error?.code === 501) { + // 余额不够了, 把用户所有记录改成闲置 + await DataItem.updateMany({ + userId: dataItem.userId, + status: 0 + }); + } + throw new Error('获取 openai key 失败'); } diff --git a/src/service/events/generateQA.ts b/src/service/events/generateQA.ts index 3f60d7594..7f38f362f 100644 --- a/src/service/events/generateQA.ts +++ b/src/service/events/generateQA.ts @@ -40,12 +40,15 @@ export async function generateQA(next = false): Promise { const key = await getOpenApiKey(dataItem.userId); userApiKey = key.userApiKey; systemKey = key.systemKey; - } catch (error) { - // 余额不够了, 清空该记录 - await SplitData.findByIdAndUpdate(dataItem._id, { - textList: [], - errorText: '余额不足,生成数据集任务终止' - }); + } catch (error: any) { + if (error?.code === 501) { + // 余额不够了, 清空该记录 + await SplitData.findByIdAndUpdate(dataItem._id, { + textList: [], + errorText: error.message + }); + } + throw new Error('获取 openai key 失败'); } @@ -121,7 +124,7 @@ export async function generateQA(next = false): Promise { setTimeout(() => { generateQA(true); - }, 10000); + }, 5000); } } diff --git a/src/service/events/generateVector.ts b/src/service/events/generateVector.ts index 02d68b0d2..407429251 100644 --- a/src/service/events/generateVector.ts +++ b/src/service/events/generateVector.ts @@ -85,5 +85,9 @@ export async function generateVector(next = false): Promise { generateVector(true); }, 60000); } + + setTimeout(() => { + generateVector(true); + }, 3000); } } diff --git a/src/service/utils/chat.ts b/src/service/utils/chat.ts index 0cdf62f4c..3c0846333 100644 --- a/src/service/utils/chat.ts +++ b/src/service/utils/chat.ts @@ -1,8 +1,7 @@ import { Configuration, OpenAIApi } from 'openai'; import { Chat } from '../mongo'; import type { ChatPopulate } from '@/types/mongoSchema'; -import { formatPrice } from '@/utils/user'; -import { authToken } from './tools'; +import { authToken, getOpenApiKey } from './tools'; export const getOpenAIApi = (apiKey: string) => { const configuration = new Configuration({ @@ -14,19 +13,12 @@ export const getOpenAIApi = (apiKey: string) => { export const authChat = async (chatId: string, authorization?: string) => { // 获取 chat 数据 - const chat = await Chat.findById(chatId) - .populate({ - path: 'modelId', - options: { - strictPopulate: false - } - }) - .populate({ - path: 'userId', - options: { - strictPopulate: false - } - }); + const chat = await Chat.findById(chatId).populate({ + path: 'modelId', + options: { + strictPopulate: false + } + }); if (!chat || !chat.modelId || !chat.userId) { return Promise.reject('模型不存在'); @@ -43,21 +35,14 @@ export const authChat = async (chatId: string, authorization?: string) => { } // 获取 user 的 apiKey - const user = chat.userId; - - const userApiKey = user.accounts?.find((item: any) => item.type === 'openai')?.value; - - // 没有 apikey ,校验余额 - if (!userApiKey && formatPrice(user.balance) <= 0) { - return Promise.reject('该账号余额不足'); - } + const { user, userApiKey, systemKey } = await getOpenApiKey(chat.userId as unknown as string); // filter 掉被 deleted 的内容 chat.content = chat.content.filter((item) => item.deleted !== true); return { userApiKey, - systemKey: process.env.OPENAIKEY as string, + systemKey, chat, userId: user._id }; diff --git a/src/service/utils/tools.ts b/src/service/utils/tools.ts index 830b09402..268f4be13 100644 --- a/src/service/utils/tools.ts +++ b/src/service/utils/tools.ts @@ -2,10 +2,12 @@ import crypto from 'crypto'; import jwt from 'jsonwebtoken'; import { User } from '../models/user'; import tunnel from 'tunnel'; -import type { UserModelSchema } from '@/types/mongoSchema'; import { formatPrice } from '@/utils/user'; import { ChatItemType } from '@/types/chat'; import { encode } from 'gpt-token-utils'; +import { getOpenAIApi } from '@/service/utils/chat'; +import axios from 'axios'; +import { UserModelSchema } from '@/types/mongoSchema'; /* 密码加密 */ export const hashPassword = (psw: string) => { @@ -44,40 +46,83 @@ export const authToken = (token?: string): Promise => { }); }; -/* 获取用户的 openai APIkey */ -export const getUserOpenaiKey = async (userId: string) => { +/* 判断 apikey 是否还有余额 */ +export const checkKeyGrant = async (apiKey: string) => { + const grant = await axios.get('https://api.openai.com/dashboard/billing/credit_grants', { + headers: { + Authorization: `Bearer ${apiKey}` + } + }); + console.log(grant.data?.total_available); + if (grant.data?.total_available <= 0.2) { + return false; + } + return true; +}; + +/* 获取用户 api 的 openai 信息 */ +export const getUserApiOpenai = async (userId: string) => { const user = await User.findById(userId); const userApiKey = user?.accounts?.find((item: any) => item.type === 'openai')?.value; + if (!userApiKey) { return Promise.reject('缺少ApiKey, 无法请求'); } - return Promise.resolve(userApiKey as string); + // 余额校验 + const hasGrant = await checkKeyGrant(userApiKey); + if (!hasGrant) { + return Promise.reject({ + code: 501, + message: 'API 余额不足' + }); + } + + return { + user, + openai: getOpenAIApi(userApiKey), + apiKey: userApiKey + }; }; -/* 获取key,如果没有就用平台的,用平台记得加账单 */ +/* 获取 open api key,如果用户没有自己的key,就用平台的,用平台记得加账单 */ export const getOpenApiKey = async (userId: string) => { - const user = await User.findById(userId); + const user = await User.findById(userId); + if (!user) { + return Promise.reject('找不到用户'); + } - if (!user) return Promise.reject('用户不存在'); + const userApiKey = user?.accounts?.find((item: any) => item.type === 'openai')?.value; - const userApiKey = user.accounts?.find((item: any) => item.type === 'openai')?.value; - - // 有自己的key, 直接使用 + // 有自己的key if (userApiKey) { + // api 余额校验 + const hasGrant = await checkKeyGrant(userApiKey); + if (!hasGrant) { + return Promise.reject({ + code: 501, + message: 'API 余额不足' + }); + } + return { - userApiKey: await getUserOpenaiKey(userId), + user, + userApiKey, systemKey: '' }; } - // 余额校验 + // 平台账号余额校验 if (formatPrice(user.balance) <= 0) { - return Promise.reject('该账号余额不足'); + return Promise.reject({ + code: 501, + message: '账号余额不足' + }); } return { + user, userApiKey: '', systemKey: process.env.OPENAIKEY as string };