diff --git a/src/constants/model.ts b/src/constants/model.ts index 61fd8b9a4..c5a87011c 100644 --- a/src/constants/model.ts +++ b/src/constants/model.ts @@ -4,7 +4,8 @@ import type { RedisModelDataItemType } from '@/types/redis'; export enum ChatModelNameEnum { GPT35 = 'gpt-3.5-turbo', VECTOR_GPT = 'VECTOR_GPT', - GPT3 = 'text-davinci-003' + GPT3 = 'text-davinci-003', + VECTOR = 'text-embedding-ada-002' } export const ChatModelNameMap = { diff --git a/src/constants/user.ts b/src/constants/user.ts index 329adeb10..11e039135 100644 --- a/src/constants/user.ts +++ b/src/constants/user.ts @@ -3,6 +3,7 @@ export enum BillTypeEnum { splitData = 'splitData', QA = 'QA', abstract = 'abstract', + vector = 'vector', return = 'return' } export enum PageTypeEnum { @@ -16,5 +17,6 @@ export const BillTypeMap: Record<`${BillTypeEnum}`, string> = { [BillTypeEnum.splitData]: 'QA拆分', [BillTypeEnum.QA]: 'QA拆分', [BillTypeEnum.abstract]: '摘要总结', + [BillTypeEnum.vector]: '索引生成', [BillTypeEnum.return]: '退款' }; diff --git a/src/pages/api/chat/vectorGpt.ts b/src/pages/api/chat/vectorGpt.ts index 387baeee0..ed4a1689e 100644 --- a/src/pages/api/chat/vectorGpt.ts +++ b/src/pages/api/chat/vectorGpt.ts @@ -13,6 +13,7 @@ import { pushChatBill } from '@/service/events/pushBill'; import { connectRedis } from '@/service/redis'; import { VecModelDataPrefix } from '@/constants/redis'; import { vectorToBuffer } from '@/utils/tools'; +import { openaiCreateEmbedding } from '@/service/utils/openai'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -57,21 +58,12 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const prompts = [...chat.content, prompt]; // 获取 chatAPI - const chatAPI = getOpenAIApi(userApiKey || systemKey); - - // 把输入的内容转成向量 - const promptVector = await chatAPI - .createEmbedding( - { - model: 'text-embedding-ada-002', - input: prompt.value - }, - { - timeout: 120000, - httpsAgent - } - ) - .then((res) => 
res?.data?.data?.[0]?.embedding || []); + const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({ + isPay: !userApiKey, + apiKey: userApiKey || systemKey, + userId, + text: prompt.value + }); // 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text const redisData: any[] = await redis.sendCommand([ @@ -79,7 +71,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) `idx:${VecModelDataPrefix}:hash`, `@modelId:{${String( chat.modelId._id - )}} @vector:[VECTOR_RANGE 0.15 $blob]=>{$YIELD_DISTANCE_AS: score}`, + )}} @vector:[VECTOR_RANGE 0.2 $blob]=>{$YIELD_DISTANCE_AS: score}`, // `@modelId:{${String(chat.modelId._id)}}=>[KNN 10 @vector $blob AS score]`, 'RETURN', '1', diff --git a/src/service/events/generateAbstract.ts b/src/service/events/generateAbstract.ts index a6b50c6b8..d53b4f762 100644 --- a/src/service/events/generateAbstract.ts +++ b/src/service/events/generateAbstract.ts @@ -84,36 +84,6 @@ export async function generateAbstract(next = false): Promise { const rawContent: string = abstractResponse?.data.choices[0].message?.content || ''; // 从 content 中提取摘要内容 const splitContents = splitText(rawContent); - // console.log(rawContent); - // 生成词向量 - // const vectorResponse = await Promise.allSettled( - // splitContents.map((item) => - // chatAPI.createEmbedding( - // { - // model: 'text-embedding-ada-002', - // input: item.abstract - // }, - // { - // timeout: 120000, - // httpsAgent - // } - // ) - // ) - // ); - // 筛选成功的向量请求 - // const vectorSuccessResponse = vectorResponse - // .map((item: any, i) => { - // if (item.status !== 'fulfilled') { - // // 没有词向量的【摘要】不要 - // console.log('获取词向量错误: ', item); - // return ''; - // } - // return { - // abstract: splitContents[i].abstract, - // abstractVector: item?.value?.data?.data?.[0]?.embedding - // }; - // }) - // .filter((item) => item); // 插入数据库,并修改状态 await DataItem.findByIdAndUpdate(dataItem._id, { diff --git a/src/service/events/generateQA.ts b/src/service/events/generateQA.ts index 
3ec568de8..71153ac29 100644 --- a/src/service/events/generateQA.ts +++ b/src/service/events/generateQA.ts @@ -83,7 +83,7 @@ export async function generateQA(next = false): Promise { ] }, { - timeout: 120000, + timeout: 180000, httpsAgent } ) diff --git a/src/service/events/generateVector.ts b/src/service/events/generateVector.ts index 806a12382..c3343cda3 100644 --- a/src/service/events/generateVector.ts +++ b/src/service/events/generateVector.ts @@ -4,6 +4,7 @@ import { connectRedis } from '../redis'; import { VecModelDataIdx } from '@/constants/redis'; import { vectorToBuffer } from '@/utils/tools'; import { ModelDataStatusEnum } from '@/constants/redis'; +import { openaiCreateEmbedding, getOpenApiKey } from '../utils/openai'; export async function generateVector(next = false): Promise { if (global.generatingVector && !next) return; @@ -17,7 +18,7 @@ export async function generateVector(next = false): Promise { VecModelDataIdx, `@status:{${ModelDataStatusEnum.waiting}}`, { - RETURN: ['q'], + RETURN: ['q', 'userId'], LIMIT: { from: 0, size: 1 @@ -31,30 +32,22 @@ export async function generateVector(next = false): Promise { return; } - const dataItem: { id: string; q: string } = { + const dataItem: { id: string; q: string; userId: string } = { id: searchRes.documents[0].id, - q: String(searchRes.documents[0]?.value?.q || '') + q: String(searchRes.documents[0]?.value?.q || ''), + userId: String(searchRes.documents[0]?.value?.userId || '') }; // 获取 openapi Key - const openAiKey = process.env.OPENAIKEY as string; - - // 获取 openai 请求实例 - const chatAPI = getOpenAIApi(openAiKey); + const { userApiKey, systemKey } = await getOpenApiKey(dataItem.userId); // 生成词向量 - const vector = await chatAPI - .createEmbedding( - { - model: 'text-embedding-ada-002', - input: dataItem.q - }, - { - timeout: 120000, - httpsAgent - } - ) - .then((res) => res?.data?.data?.[0]?.embedding || []); + const { vector } = await openaiCreateEmbedding({ + text: dataItem.q, + userId: dataItem.userId, + 
isPay: !userApiKey, + apiKey: userApiKey || systemKey + }); // 更新 redis 向量和状态数据 await redis.sendCommand([ diff --git a/src/service/events/pushBill.ts b/src/service/events/pushBill.ts index 089b61299..269fa50ed 100644 --- a/src/service/events/pushBill.ts +++ b/src/service/events/pushBill.ts @@ -2,6 +2,7 @@ import { connectToDatabase, Bill, User } from '../mongo'; import { modelList, ChatModelNameEnum } from '@/constants/model'; import { encode } from 'gpt-token-utils'; import { formatPrice } from '@/utils/user'; +import { BillTypeEnum } from '@/constants/user'; import type { DataType } from '@/types/data'; export const pushChatBill = async ({ @@ -23,8 +24,7 @@ export const pushChatBill = async ({ // 计算 token 数量 const tokens = encode(text); - console.log('text len: ', text.length); - console.log('token len:', tokens.length); + console.log(`chat generate success. text len: ${text.length}. token len: ${tokens.length}`); if (isPay) { await connectToDatabase(); @@ -34,7 +34,7 @@ export const pushChatBill = async ({ // 计算价格 const unitPrice = modelItem?.price || 5; const price = unitPrice * tokens.length; - console.log(`chat bill, unit price: ${unitPrice}, price: ${formatPrice(price)}元`); + console.log(`unit price: ${unitPrice}, price: ${formatPrice(price)}元`); try { // 插入 Bill 记录 @@ -82,8 +82,9 @@ export const pushSplitDataBill = async ({ // 计算 token 数量 const tokens = encode(text); - console.log('text len: ', text.length); - console.log('token len:', tokens.length); + console.log( + `splitData generate success. text len: ${text.length}. 
token len: ${tokens.length}` + ); if (isPay) { try { @@ -93,7 +94,7 @@ export const pushSplitDataBill = async ({ // 计算价格 const price = unitPrice * tokens.length; - console.log(`splitData bill, price: ${formatPrice(price)}元`); + console.log(`price: ${formatPrice(price)}元`); // 插入 Bill 记录 const res = await Bill.create({ @@ -123,13 +124,11 @@ export const pushSplitDataBill = async ({ export const pushGenerateVectorBill = async ({ isPay, userId, - text, - type + text }: { isPay: boolean; userId: string; text: string; - type: DataType; }) => { await connectToDatabase(); @@ -139,24 +138,21 @@ export const pushGenerateVectorBill = async ({ // 计算 token 数量 const tokens = encode(text); - console.log('text len: ', text.length); - console.log('token len:', tokens.length); + console.log(`vector generate success. text len: ${text.length}. token len: ${tokens.length}`); if (isPay) { try { - // 获取模型单价格, 都是用 gpt35 拆分 - const modelItem = modelList.find((item) => item.model === ChatModelNameEnum.GPT35); - const unitPrice = modelItem?.price || 5; + const unitPrice = 1; // 计算价格 const price = unitPrice * tokens.length; - console.log(`splitData bill, price: ${formatPrice(price)}元`); + console.log(`price: ${formatPrice(price)}元`); // 插入 Bill 记录 const res = await Bill.create({ userId, - type, - modelName: ChatModelNameEnum.GPT35, + type: BillTypeEnum.vector, + modelName: ChatModelNameEnum.VECTOR, textLen: text.length, tokenLen: tokens.length, price diff --git a/src/service/models/bill.ts b/src/service/models/bill.ts index 82b57f840..25ecda4a6 100644 --- a/src/service/models/bill.ts +++ b/src/service/models/bill.ts @@ -16,7 +16,7 @@ const BillSchema = new Schema({ }, modelName: { type: String, - enum: modelList.map((item) => item.model), + enum: [...modelList.map((item) => item.model), 'text-embedding-ada-002'], required: true }, chatId: { diff --git a/src/service/utils/openai.ts b/src/service/utils/openai.ts index f373bc965..26b3df39d 100644 --- a/src/service/utils/openai.ts +++ 
/**
 * Create an embedding vector for `text` through the OpenAI embeddings endpoint
 * and record a vector bill for the request.
 *
 * The bill is pushed without being awaited so that billing never delays the
 * caller; a failure while billing is logged instead of surfacing as an
 * unhandled promise rejection.
 *
 * @param isPay - forwarded to `pushGenerateVectorBill`; visible callers pass
 *   `!userApiKey`, i.e. `true` when the system key is being used
 * @param userId - id of the user the bill is recorded for
 * @param apiKey - OpenAI key used to build the client
 * @param text - input text to embed
 * @returns `vector`: the first embedding of the response, or an empty array
 *   when the response carries none; `chatAPI`: the client instance, returned
 *   so callers can reuse it for follow-up requests
 * @throws whatever `createEmbedding` rejects with (network error, timeout,
 *   API error) — the embedding request itself is not caught here
 */
export const openaiCreateEmbedding = async ({
  isPay,
  userId,
  apiKey,
  text
}: {
  isPay: boolean;
  userId: string;
  apiKey: string;
  text: string;
}) => {
  // Build the OpenAI client for the supplied key
  const chatAPI = getOpenAIApi(apiKey);

  // Convert the input text into an embedding vector
  const response = await chatAPI.createEmbedding(
    {
      model: ChatModelNameEnum.VECTOR,
      input: text
    },
    {
      timeout: 60000,
      httpsAgent
    }
  );
  const vector = response?.data?.data?.[0]?.embedding ?? [];

  // Billing is intentionally fire-and-forget, but its promise can reject
  // (e.g. while connecting to the database), so the rejection is handled
  // here rather than left floating.
  void pushGenerateVectorBill({
    isPay,
    userId,
    text
  }).catch((err: unknown) => {
    console.log('push vector bill error:', err);
  });

  return {
    vector,
    chatAPI
  };
};