perf: api调用和余额校验

This commit is contained in:
archer
2023-03-31 11:20:45 +08:00
parent ed1f93d836
commit 56dab7abba
10 changed files with 104 additions and 62 deletions

View File

@@ -45,6 +45,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
} }
}); });
if (splitText) {
textList.push(splitText);
}
// 批量插入数据 // 批量插入数据
await SplitData.create({ await SplitData.create({
userId, userId,
@@ -55,9 +59,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
generateQA(); generateQA();
jsonRes(res, { jsonRes(res);
data: { chunks, replaceText }
});
} catch (err) { } catch (err) {
jsonRes(res, { jsonRes(res, {
code: 500, code: 500,

View File

@@ -1,7 +1,7 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo'; import { Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo';
import { authToken, getUserOpenaiKey } from '@/service/utils/tools'; import { authToken, getUserApiOpenai } from '@/service/utils/tools';
import { TrainingStatusEnum } from '@/constants/model'; import { TrainingStatusEnum } from '@/constants/model';
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/chat';
import { TrainingItemType } from '@/types/training'; import { TrainingItemType } from '@/types/training';
@@ -67,7 +67,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
// 如果正在训练需要删除openai上的相关信息 // 如果正在训练需要删除openai上的相关信息
if (training) { if (training) {
const openai = getOpenAIApi(await getUserOpenaiKey(userId)); const { openai } = await getUserApiOpenai(userId);
// 获取训练记录 // 获取训练记录
const tuneRecord = await openai.retrieveFineTune(training.tuneId, { httpsAgent }); const tuneRecord = await openai.retrieveFineTune(training.tuneId, { httpsAgent });

View File

@@ -2,7 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, Model, Training } from '@/service/mongo'; import { connectToDatabase, Model, Training } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/chat';
import { authToken, getUserOpenaiKey } from '@/service/utils/tools'; import { authToken, getUserApiOpenai } from '@/service/utils/tools';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { TrainingItemType } from '@/types/training'; import { TrainingItemType } from '@/types/training';
import { ModelStatusEnum, TrainingStatusEnum } from '@/constants/model'; import { ModelStatusEnum, TrainingStatusEnum } from '@/constants/model';
@@ -43,7 +43,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
} }
// 用户的 openai 实例 // 用户的 openai 实例
const openai = getOpenAIApi(await getUserOpenaiKey(userId)); const { openai } = await getUserApiOpenai(userId);
// 获取 openai 的训练情况 // 获取 openai 的训练情况
const { data } = await openai.retrieveFineTune(training.tuneId, { httpsAgent }); const { data } = await openai.retrieveFineTune(training.tuneId, { httpsAgent });

View File

@@ -2,9 +2,8 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, Model, Training } from '@/service/mongo'; import { connectToDatabase, Model, Training } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat';
import formidable from 'formidable'; import formidable from 'formidable';
import { authToken, getUserOpenaiKey } from '@/service/utils/tools'; import { authToken, getUserApiOpenai } from '@/service/utils/tools';
import { join } from 'path'; import { join } from 'path';
import fs from 'fs'; import fs from 'fs';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
@@ -49,7 +48,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const trainingType = model.service.trainId; // 目前都默认是 openai text-davinci-03 const trainingType = model.service.trainId; // 目前都默认是 openai text-davinci-03
// 获取用户的 API Key 实例化后的对象 // 获取用户的 API Key 实例化后的对象
openai = getOpenAIApi(await getUserOpenaiKey(userId)); const user = await getUserApiOpenai(userId);
openai = user.openai;
// 接收文件并保存 // 接收文件并保存
const form = formidable({ const form = formidable({

View File

@@ -5,7 +5,7 @@ import { connectToDatabase, Training, Model } from '@/service/mongo';
import type { TrainingItemType } from '@/types/training'; import type { TrainingItemType } from '@/types/training';
import { TrainingStatusEnum, ModelStatusEnum } from '@/constants/model'; import { TrainingStatusEnum, ModelStatusEnum } from '@/constants/model';
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/chat';
import { getUserOpenaiKey } from '@/service/utils/tools'; import { getUserApiOpenai } from '@/service/utils/tools';
import { OpenAiTuneStatusEnum } from '@/service/constants/training'; import { OpenAiTuneStatusEnum } from '@/service/constants/training';
import { sendTrainSucceed } from '@/service/utils/sendEmail'; import { sendTrainSucceed } from '@/service/utils/sendEmail';
import { httpsAgent } from '@/service/utils/tools'; import { httpsAgent } from '@/service/utils/tools';
@@ -23,7 +23,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
status: TrainingStatusEnum.pending status: TrainingStatusEnum.pending
}); });
const openai = getOpenAIApi(await getUserOpenaiKey('63f9a14228d2a688d8dc9e1b')); const { openai } = await getUserApiOpenai('63f9a14228d2a688d8dc9e1b');
const response = await Promise.all( const response = await Promise.all(
trainingRecords.map(async (item) => { trainingRecords.map(async (item) => {

View File

@@ -43,12 +43,15 @@ export async function generateAbstract(next = false): Promise<any> {
const key = await getOpenApiKey(dataItem.userId); const key = await getOpenApiKey(dataItem.userId);
userApiKey = key.userApiKey; userApiKey = key.userApiKey;
systemKey = key.systemKey; systemKey = key.systemKey;
} catch (error) { } catch (error: any) {
// 余额不够了, 把用户所有记录改成闲置 if (error?.code === 501) {
await DataItem.updateMany({ // 余额不够了, 把用户所有记录改成闲置
userId: dataItem.userId, await DataItem.updateMany({
status: 0 userId: dataItem.userId,
}); status: 0
});
}
throw new Error('获取 openai key 失败'); throw new Error('获取 openai key 失败');
} }

View File

@@ -40,12 +40,15 @@ export async function generateQA(next = false): Promise<any> {
const key = await getOpenApiKey(dataItem.userId); const key = await getOpenApiKey(dataItem.userId);
userApiKey = key.userApiKey; userApiKey = key.userApiKey;
systemKey = key.systemKey; systemKey = key.systemKey;
} catch (error) { } catch (error: any) {
// 余额不够了, 清空该记录 if (error?.code === 501) {
await SplitData.findByIdAndUpdate(dataItem._id, { // 余额不够了, 清空该记录
textList: [], await SplitData.findByIdAndUpdate(dataItem._id, {
errorText: '余额不足,生成数据集任务终止' textList: [],
}); errorText: error.message
});
}
throw new Error('获取 openai key 失败'); throw new Error('获取 openai key 失败');
} }
@@ -121,7 +124,7 @@ export async function generateQA(next = false): Promise<any> {
setTimeout(() => { setTimeout(() => {
generateQA(true); generateQA(true);
}, 10000); }, 5000);
} }
} }

View File

@@ -85,5 +85,9 @@ export async function generateVector(next = false): Promise<any> {
generateVector(true); generateVector(true);
}, 60000); }, 60000);
} }
setTimeout(() => {
generateVector(true);
}, 3000);
} }
} }

View File

@@ -1,8 +1,7 @@
import { Configuration, OpenAIApi } from 'openai'; import { Configuration, OpenAIApi } from 'openai';
import { Chat } from '../mongo'; import { Chat } from '../mongo';
import type { ChatPopulate } from '@/types/mongoSchema'; import type { ChatPopulate } from '@/types/mongoSchema';
import { formatPrice } from '@/utils/user'; import { authToken, getOpenApiKey } from './tools';
import { authToken } from './tools';
export const getOpenAIApi = (apiKey: string) => { export const getOpenAIApi = (apiKey: string) => {
const configuration = new Configuration({ const configuration = new Configuration({
@@ -14,19 +13,12 @@ export const getOpenAIApi = (apiKey: string) => {
export const authChat = async (chatId: string, authorization?: string) => { export const authChat = async (chatId: string, authorization?: string) => {
// 获取 chat 数据 // 获取 chat 数据
const chat = await Chat.findById<ChatPopulate>(chatId) const chat = await Chat.findById<ChatPopulate>(chatId).populate({
.populate({ path: 'modelId',
path: 'modelId', options: {
options: { strictPopulate: false
strictPopulate: false }
} });
})
.populate({
path: 'userId',
options: {
strictPopulate: false
}
});
if (!chat || !chat.modelId || !chat.userId) { if (!chat || !chat.modelId || !chat.userId) {
return Promise.reject('模型不存在'); return Promise.reject('模型不存在');
@@ -43,21 +35,14 @@ export const authChat = async (chatId: string, authorization?: string) => {
} }
// 获取 user 的 apiKey // 获取 user 的 apiKey
const user = chat.userId; const { user, userApiKey, systemKey } = await getOpenApiKey(chat.userId as unknown as string);
const userApiKey = user.accounts?.find((item: any) => item.type === 'openai')?.value;
// 没有 apikey ,校验余额
if (!userApiKey && formatPrice(user.balance) <= 0) {
return Promise.reject('该账号余额不足');
}
// filter 掉被 deleted 的内容 // filter 掉被 deleted 的内容
chat.content = chat.content.filter((item) => item.deleted !== true); chat.content = chat.content.filter((item) => item.deleted !== true);
return { return {
userApiKey, userApiKey,
systemKey: process.env.OPENAIKEY as string, systemKey,
chat, chat,
userId: user._id userId: user._id
}; };

View File

@@ -2,10 +2,12 @@ import crypto from 'crypto';
import jwt from 'jsonwebtoken'; import jwt from 'jsonwebtoken';
import { User } from '../models/user'; import { User } from '../models/user';
import tunnel from 'tunnel'; import tunnel from 'tunnel';
import type { UserModelSchema } from '@/types/mongoSchema';
import { formatPrice } from '@/utils/user'; import { formatPrice } from '@/utils/user';
import { ChatItemType } from '@/types/chat'; import { ChatItemType } from '@/types/chat';
import { encode } from 'gpt-token-utils'; import { encode } from 'gpt-token-utils';
import { getOpenAIApi } from '@/service/utils/chat';
import axios from 'axios';
import { UserModelSchema } from '@/types/mongoSchema';
/* 密码加密 */ /* 密码加密 */
export const hashPassword = (psw: string) => { export const hashPassword = (psw: string) => {
@@ -44,40 +46,83 @@ export const authToken = (token?: string): Promise<string> => {
}); });
}; };
/* 获取用户的 openai APIkey */ /* 判断 apikey 是否还有余额 */
export const getUserOpenaiKey = async (userId: string) => { export const checkKeyGrant = async (apiKey: string) => {
const grant = await axios.get('https://api.openai.com/dashboard/billing/credit_grants', {
headers: {
Authorization: `Bearer ${apiKey}`
}
});
console.log(grant.data?.total_available);
if (grant.data?.total_available <= 0.2) {
return false;
}
return true;
};
/* 获取用户 api 的 openai 信息 */
export const getUserApiOpenai = async (userId: string) => {
const user = await User.findById(userId); const user = await User.findById(userId);
const userApiKey = user?.accounts?.find((item: any) => item.type === 'openai')?.value; const userApiKey = user?.accounts?.find((item: any) => item.type === 'openai')?.value;
if (!userApiKey) { if (!userApiKey) {
return Promise.reject('缺少ApiKey, 无法请求'); return Promise.reject('缺少ApiKey, 无法请求');
} }
return Promise.resolve(userApiKey as string); // 余额校验
const hasGrant = await checkKeyGrant(userApiKey);
if (!hasGrant) {
return Promise.reject({
code: 501,
message: 'API 余额不足'
});
}
return {
user,
openai: getOpenAIApi(userApiKey),
apiKey: userApiKey
};
}; };
/* 获取key如果没有就用平台的,用平台记得加账单 */ /* 获取 open api key如果用户没有自己的key就用平台的用平台记得加账单 */
export const getOpenApiKey = async (userId: string) => { export const getOpenApiKey = async (userId: string) => {
const user = await User.findById<UserModelSchema>(userId); const user = await User.findById(userId);
if (!user) {
return Promise.reject('找不到用户');
}
if (!user) return Promise.reject('用户不存在'); const userApiKey = user?.accounts?.find((item: any) => item.type === 'openai')?.value;
const userApiKey = user.accounts?.find((item: any) => item.type === 'openai')?.value; // 有自己的key
// 有自己的key 直接使用
if (userApiKey) { if (userApiKey) {
// api 余额校验
const hasGrant = await checkKeyGrant(userApiKey);
if (!hasGrant) {
return Promise.reject({
code: 501,
message: 'API 余额不足'
});
}
return { return {
userApiKey: await getUserOpenaiKey(userId), user,
userApiKey,
systemKey: '' systemKey: ''
}; };
} }
// 余额校验 // 平台账号余额校验
if (formatPrice(user.balance) <= 0) { if (formatPrice(user.balance) <= 0) {
return Promise.reject('该账号余额不足'); return Promise.reject({
code: 501,
message: '账号余额不足'
});
} }
return { return {
user,
userApiKey: '', userApiKey: '',
systemKey: process.env.OPENAIKEY as string systemKey: process.env.OPENAIKEY as string
}; };