From 05b2e9e99c0e57a709bf2482ec77a96296ba5926 Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Sun, 2 Apr 2023 23:38:28 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=8B=86=E5=88=86=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=8E=AF=E5=A2=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/model.ts | 5 ++ src/pages/api/model/data/exportModelData.ts | 61 +++++++++++++++++++ src/pages/login/index.tsx | 1 - .../model/detail/components/ModelDataCard.tsx | 40 ++++++++++-- src/service/events/generateAbstract.ts | 2 - src/service/events/pushBill.ts | 57 +++++++++++++++++ src/service/mongo.ts | 2 +- src/service/redis.ts | 2 +- src/utils/tools.ts | 19 +++++- 9 files changed, 177 insertions(+), 12 deletions(-) create mode 100644 src/pages/api/model/data/exportModelData.ts diff --git a/src/api/model.ts b/src/api/model.ts index ff11ce8f4..34de39102 100644 --- a/src/api/model.ts +++ b/src/api/model.ts @@ -38,6 +38,11 @@ type GetModelDataListProps = RequestPaging & { export const getModelDataList = (props: GetModelDataListProps) => GET(`/model/data/getModelData?${Obj2Query(props)}`); +export const getExportDataList = (modelId: string) => + GET<{ prompt: string; completion: string; vector: number[] }>( + `/model/data/exportModelData?modelId=${modelId}` + ); + export const getModelSplitDataList = (modelId: string) => GET(`/model/data/getSplitData?modelId=${modelId}`); diff --git a/src/pages/api/model/data/exportModelData.ts b/src/pages/api/model/data/exportModelData.ts new file mode 100644 index 000000000..32782a83e --- /dev/null +++ b/src/pages/api/model/data/exportModelData.ts @@ -0,0 +1,61 @@ +import type { NextApiRequest, NextApiResponse } from 'next'; +import { jsonRes } from '@/service/response'; +import { connectToDatabase } from '@/service/mongo'; +import { authToken } from '@/service/utils/tools'; +import { connectRedis } from '@/service/redis'; +import { VecModelDataIdx } from '@/constants/redis'; +import { 
BufferToVector } from '@/utils/tools'; + +export default async function handler(req: NextApiRequest, res: NextApiResponse) { + try { + let { modelId } = req.query as { + modelId: string; + }; + + const { authorization } = req.headers; + + if (!authorization) { + throw new Error('无权操作'); + } + + if (!modelId) { + throw new Error('缺少参数'); + } + + // 凭证校验 + const userId = await authToken(authorization); + + await connectToDatabase(); + const redis = await connectRedis(); + + // 从 redis 中获取数据 + const searchRes = await redis.ft.search( + VecModelDataIdx, + `@modelId:{${modelId}} @userId:{${userId}}`, + { + RETURN: ['q', 'text', 'vector'], + LIMIT: { + from: 0, + size: 10000 + } + } + ); + + const data = searchRes.documents + .filter((item) => item?.value?.vector) + .map((item: any) => ({ + prompt: item.value.q, + completion: item.value.text, + vector: BufferToVector(item.value.vector) + })); + + jsonRes(res, { + data + }); + } catch (err) { + jsonRes(res, { + code: 500, + error: err + }); + } +} diff --git a/src/pages/login/index.tsx b/src/pages/login/index.tsx index 22e25cc34..93bbc960b 100644 --- a/src/pages/login/index.tsx +++ b/src/pages/login/index.tsx @@ -71,7 +71,6 @@ const Login = () => { order={1} flex={`0 0 ${isPc ? 
'400px' : '100%'}`} height={'100%'} - maxH={'450px'} border="1px" borderColor="gray.200" py={5} diff --git a/src/pages/model/detail/components/ModelDataCard.tsx b/src/pages/model/detail/components/ModelDataCard.tsx index 0718384ed..a2e1b7bb1 100644 --- a/src/pages/model/detail/components/ModelDataCard.tsx +++ b/src/pages/model/detail/components/ModelDataCard.tsx @@ -26,13 +26,14 @@ import { getModelDataList, delOneModelData, putModelDataById, - getModelSplitDataList + getModelSplitDataList, + getExportDataList } from '@/api/model'; import { DeleteIcon, RepeatIcon } from '@chakra-ui/icons'; import { useToast } from '@/hooks/useToast'; import { useLoading } from '@/hooks/useLoading'; import dynamic from 'next/dynamic'; -import { useQuery } from '@tanstack/react-query'; +import { useMutation, useQuery } from '@tanstack/react-query'; const InputModel = dynamic(() => import('./InputDataModal')); const SelectFileModel = dynamic(() => import('./SelectFileModal')); @@ -99,10 +100,29 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => { [getData, refetch] ); + // 获取所有的数据,并导出 json + const { mutate: onclickExport, isLoading: isLoadingExport } = useMutation({ + mutationFn: () => getExportDataList(model._id), + onSuccess(res) { + // 导出为文件 + const blob = new Blob([JSON.stringify(res)], { type: 'application/json;charset=utf-8' }); + + // 创建下载链接 + const downloadLink = document.createElement('a'); + downloadLink.href = window.URL.createObjectURL(blob); + downloadLink.download = `data.json`; + + // 添加链接到页面并触发下载 + document.body.appendChild(downloadLink); + downloadLink.click(); + document.body.removeChild(downloadLink); + } + }); + return ( <> - + 模型数据: {total}组{' '} (测试版本) @@ -113,10 +133,22 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => { aria-label={'refresh'} variant={'outline'} mr={4} + size={'sm'} onClick={() => refetchData(pageNum)} /> + {/* */} - 导入 + + 导入 + 手动输入 文件导入 diff --git a/src/service/events/generateAbstract.ts 
b/src/service/events/generateAbstract.ts index 1314a10a6..4039a2527 100644 --- a/src/service/events/generateAbstract.ts +++ b/src/service/events/generateAbstract.ts @@ -7,8 +7,6 @@ import { ChatModelNameEnum } from '@/constants/model'; import { pushSplitDataBill } from '@/service/events/pushBill'; export async function generateAbstract(next = false): Promise { - if (process.env.NODE_ENV === 'development') return; - if (global.generatingAbstract && !next) return; global.generatingAbstract = true; diff --git a/src/service/events/pushBill.ts b/src/service/events/pushBill.ts index f9ad5e50e..089b61299 100644 --- a/src/service/events/pushBill.ts +++ b/src/service/events/pushBill.ts @@ -119,3 +119,60 @@ export const pushSplitDataBill = async ({ console.log(error); } }; + +export const pushGenerateVectorBill = async ({ + isPay, + userId, + text, + type +}: { + isPay: boolean; + userId: string; + text: string; + type: DataType; +}) => { + await connectToDatabase(); + + let billId; + + try { + // count the tokens in the text + const tokens = encode(text); + + console.log('text len: ', text.length); + console.log('token len:', tokens.length); + + if (isPay) { + try { + // unit price of the model; GPT-3.5 is always the one billed + const modelItem = modelList.find((item) => item.model === ChatModelNameEnum.GPT35); + const unitPrice = modelItem?.price || 5; + // compute the price + const price = unitPrice * tokens.length; + + console.log(`splitData bill, price: ${formatPrice(price)}元`); + + // insert the Bill record + const res = await Bill.create({ + userId, + type, + modelName: ChatModelNameEnum.GPT35, + textLen: text.length, + tokenLen: tokens.length, + price + }); + billId = res._id; + + // deduct the price from the account balance + await User.findByIdAndUpdate(userId, { + $inc: { balance: -price } + }); + } catch (error) { + console.log('创建账单失败:', error); + billId && (await Bill.findByIdAndDelete(billId)); // must be awaited: an un-awaited Mongoose query never executes, leaving an orphan bill + } + } + } catch (error) { + console.log(error); + } +}; diff --git a/src/service/mongo.ts b/src/service/mongo.ts index fda7f46f2..f2bf070ae 100644 --- a/src/service/mongo.ts +++ 
b/src/service/mongo.ts @@ -17,7 +17,7 @@ export async function connectToDatabase(): Promise { mongoose.set('strictQuery', true); global.mongodb = await mongoose.connect(process.env.MONGODB_URI as string, { bufferCommands: true, - dbName: 'doc_gpt', + dbName: process.env.NODE_ENV === 'development' ? 'doc_gpt_test' : 'doc_gpt', maxPoolSize: 5, minPoolSize: 1, maxConnecting: 5 diff --git a/src/service/redis.ts b/src/service/redis.ts index 9a1c6c783..8b79957be 100644 --- a/src/service/redis.ts +++ b/src/service/redis.ts @@ -30,7 +30,7 @@ export const connectRedis = async () => { await global.redisClient.connect(); // 1 - 测试库,0 - 正式 - await global.redisClient.select(process.env.NODE_ENV === 'development' ? 0 : 0); + await global.redisClient.SELECT(0); return global.redisClient; } catch (error) { diff --git a/src/utils/tools.ts b/src/utils/tools.ts index 254236769..0e8229b6e 100644 --- a/src/utils/tools.ts +++ b/src/utils/tools.ts @@ -123,13 +123,26 @@ export const readDocContent = (file: File) => }); export const vectorToBuffer = (vector: number[]) => { - let npVector = new Float32Array(vector); + const npVector = new Float32Array(vector); - return Buffer.from(npVector.buffer); + const buffer = Buffer.from(npVector.buffer); + + return buffer; +}; +export const BufferToVector = (bufferStr: string) => { + const buffer = Buffer.from(bufferStr, 'binary'); // decode the argument itself (not the literal text "bufferStr") into raw bytes + const npVector = new Float32Array( + buffer.buffer, // view the underlying ArrayBuffer; passing the Buffer would copy each byte as one float + buffer.byteOffset, + buffer.byteLength / Float32Array.BYTES_PER_ELEMENT + ); + return Array.from(npVector); }; export function formatVector(vector: number[]) { let formattedVector = vector.slice(0, 1536); // 截取前1536个元素 - formattedVector = formattedVector.concat(Array(1536 - formattedVector.length).fill(0)); // 在后面添加0 + if (formattedVector.length < 1536) { + formattedVector = formattedVector.concat(Array(1536 - formattedVector.length).fill(0)); // pad short vectors with trailing zeros up to 1536 + } return formattedVector; }