diff --git a/src/pages/api/model/data/delModelDataById.ts b/src/pages/api/model/data/delModelDataById.ts index d3e04ec09..1d8f5b930 100644 --- a/src/pages/api/model/data/delModelDataById.ts +++ b/src/pages/api/model/data/delModelDataById.ts @@ -1,7 +1,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { authToken } from '@/service/utils/tools'; -import { connectPg } from '@/service/pg'; +import { PgClient } from '@/service/pg'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -21,8 +21,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< // 凭证校验 const userId = await authToken(authorization); - const pg = await connectPg(); - await pg.query(`DELETE FROM modelData WHERE user_id = '${userId}' AND id = '${dataId}'`); + await PgClient.delete('modelData', { + where: [['user_id', userId], 'AND', ['id', dataId]] + }); jsonRes(res); } catch (err) { diff --git a/src/pages/api/model/data/getModelData.ts b/src/pages/api/model/data/getModelData.ts index 4ec0783ac..38b400fa7 100644 --- a/src/pages/api/model/data/getModelData.ts +++ b/src/pages/api/model/data/getModelData.ts @@ -2,7 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { connectToDatabase } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; -import { connectPg } from '@/service/pg'; +import { PgClient } from '@/service/pg'; import type { PgModelDataItemType } from '@/types/pg'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -35,21 +35,23 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< const userId = await authToken(authorization); await connectToDatabase(); - const pg = await connectPg(); - const searchRes = await pg.query(`SELECT id, q, a, status - FROM modelData - WHERE user_id='${userId}' AND model_id='${modelId}' - ORDER BY id DESC - LIMIT ${pageSize} OFFSET ${pageSize * (pageNum - 1)} - `); + const searchRes = await PgClient.select('modelData', { + field: ['id', 'q', 'a', 'status'], + where: [['user_id', userId], 'AND', ['model_id', modelId]], + order: [{ field: 'id', mode: 'DESC' }], + limit: pageSize, + offset: pageSize * (pageNum - 1) + }); jsonRes(res, { data: { pageNum, pageSize, data: searchRes.rows, - total: searchRes.rowCount + total: await PgClient.count('modelData', { + where: [['user_id', userId], 'AND', ['model_id', modelId]] + }) } }); } catch (err) { diff --git a/src/pages/api/model/data/pushModelDataInput.ts b/src/pages/api/model/data/pushModelDataInput.ts index 6fab7572a..9240a1fd3 100644 --- a/src/pages/api/model/data/pushModelDataInput.ts +++ b/src/pages/api/model/data/pushModelDataInput.ts @@ -4,7 +4,7 @@ import { connectToDatabase, Model } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; import { ModelDataSchema } from '@/types/mongoSchema'; import { generateVector } from '@/service/events/generateVector'; -import { connectPg } from '@/service/pg'; +import { connectPg, PgClient } from '@/service/pg'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -39,17 +39,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< } // 插入记录 - await pg.query( - `INSERT INTO modelData (user_id, model_id, q, a, status) VALUES ${data - .map( - (item) => - `('${userId}', '${modelId}', '${item.q.replace(/\'/g, '"')}', '${item.a.replace( - /\'/g, - '"' - )}', 'waiting')` - ) - .join(',')}` - ); + await PgClient.insert('modelData', { + values: data.map((item) => [ + { key: 'user_id', value: userId }, + { key: 'model_id', value: modelId }, + { key: 'q', value: item.q }, + { key: 'a', value: item.a }, + { key: 'status', value: 'waiting' } + ]) + }); generateVector(); diff --git a/src/service/events/generateQA.ts b/src/service/events/generateQA.ts index 8d0826d71..d6ae0c7e4 100644 --- a/src/service/events/generateQA.ts +++ b/src/service/events/generateQA.ts @@ -7,7 +7,7 @@ import { ChatModelNameEnum } from '@/constants/model'; import { pushSplitDataBill } from '@/service/events/pushBill'; import { generateVector } from './generateVector'; import { openaiError2 } from '../errorCode'; -import { connectPg } from '@/service/pg'; +import { PgClient } from '@/service/pg'; import { ModelSplitDataSchema } from '@/types/mongoSchema'; export async function generateQA(next = false): Promise { @@ -22,7 +22,6 @@ export async function generateQA(next = false): Promise { let dataId = null; try { - const pg = await connectPg(); // 找出一个需要生成的 dataItem const data = await SplitData.aggregate([ { $match: { textList: { $exists: true, $ne: [] } } }, @@ -115,7 +114,7 @@ export async function generateQA(next = false): Promise { }; }) .catch((err) => { - console.log('QA 拆分错误'); + console.log('QA 拆分错误', err); return Promise.reject(err); }) ) @@ -137,18 +136,16 @@ export async function generateQA(next = false): Promise { textList: dataItem.textList.slice(0, -5) }), // 删掉后5个数据 // 生成的内容插入 pg - pg.query(`INSERT INTO modelData (user_id, model_id, q, a, status) VALUES ${resultList - .map( - (item) => - `('${String(dataItem.userId)}', '${String(dataItem.modelId)}', '${item.q.replace( - /\'/g, - '"' - )}', '${item.a.replace(/\'/g, '"')}', 'waiting')` - ) - .join(',')} - `) + PgClient.insert('modelData', { + values: resultList.map((item) => [ + { key: 'user_id', value: dataItem.userId }, + { key: 'model_id', value: dataItem.modelId }, + { key: 'q', value: item.q }, + { key: 'a', value: item.a }, + { key: 'status', value: 'waiting' } + ]) + }) ]); - console.log('生成QA成功,time:', `${(Date.now() - startTime) / 1000}s`); generateQA(true); diff --git a/src/service/events/generateVector.ts b/src/service/events/generateVector.ts index 58b6f3c51..e137ffbda 100644 --- a/src/service/events/generateVector.ts +++ b/src/service/events/generateVector.ts @@ -1,8 +1,7 @@ import { connectRedis } from '../redis'; import { openaiCreateEmbedding, getOpenApiKey } from '../utils/openai'; import { openaiError2 } from '../errorCode'; -import { connectPg } from '@/service/pg'; -import type { PgModelDataItemType } from '@/types/pg'; +import { connectPg, PgClient } from '@/service/pg'; export async function generateVector(next = false): Promise { if (process.env.queueTask !== '1') { @@ -16,14 +15,12 @@ export async function generateVector(next = false): Promise { let dataId = null; try { - const pg = await connectPg(); - // 从找出一个 status = waiting 的数据 - const searchRes = await pg.query(`SELECT id, q, user_id - FROM modelData - WHERE status='waiting' - LIMIT 1 - `); + const searchRes = await PgClient.select('modelData', { + field: ['id', 'q', 'user_id'], + where: [['status', 'waiting']], + limit: 1 + }); if (searchRes.rowCount === 0) { console.log('没有需要生成 【向量】 的数据'); @@ -47,7 +44,9 @@ export async function generateVector(next = false): Promise { systemKey = res.systemKey; } catch (error: any) { if (error?.code === 501) { - await pg.query(`DELETE FROM modelData WHERE id = '${dataId}'`); + await PgClient.delete('modelData', { + where: [['id', dataId]] + }); generateVector(true); return; } @@ -64,9 +63,13 @@ export async function generateVector(next = false): Promise { }); // 更新 pg 向量和状态数据 - await pg.query( - `UPDATE modelData SET vector = '[${vector}]', status = 'ready' WHERE id = ${dataId}` - ); + await PgClient.update('modelData', { + values: [ + { key: 'vector', value: `[${vector}]` }, + { key: 'status', value: `ready` } + ], + where: [['id', dataId]] + }); console.log(`生成向量成功: ${dataItem.id}`); diff --git a/src/service/mongo.ts b/src/service/mongo.ts index 610a1c5e8..e413ac952 100644 --- a/src/service/mongo.ts +++ b/src/service/mongo.ts @@ -12,7 +12,6 @@ export async function connectToDatabase(): Promise { } global.mongodb = 'connecting'; - console.log('connect mongo'); try { mongoose.set('strictQuery', true); global.mongodb = await mongoose.connect(process.env.MONGODB_URI as string, { @@ -22,6 +21,7 @@ export async function connectToDatabase(): Promise { minPoolSize: 1, maxConnecting: 5 }); + console.log('mongo connected'); } catch (error) { console.log('error->', 'mongo connect error'); global.mongodb = null; diff --git a/src/service/pg.ts b/src/service/pg.ts index f1d142bc3..53c51e937 100644 --- a/src/service/pg.ts +++ b/src/service/pg.ts @@ -1,4 +1,5 @@ import { Pool } from 'pg'; +import type { QueryResultRow } from 'pg'; export const connectPg = async () => { if (global.pgClient) { @@ -30,3 +31,111 @@ export const connectPg = async () => { return Promise.reject(error); } }; + +type WhereProps = (string | [string, string | number])[]; +type GetProps = { + field?: string[]; + where?: WhereProps; + order?: { field: string; mode: 'DESC' | 'ASC' }[]; + limit?: number; + offset?: number; +}; + +type DeleteProps = { + where: WhereProps; +}; + +type ValuesProps = { key: string; value: string | number }[]; +type UpdateProps = { + values: ValuesProps; + where: WhereProps; +}; +type InsertProps = { + values: ValuesProps[]; +}; + +class Pg { + private getWhereStr(where?: WhereProps) { + return where + ? `WHERE ${where + .map((item) => { + if (typeof item === 'string') { + return item; + } + const val = typeof item[1] === 'string' ? `'${item[1]}'` : item[1]; + return `${item[0]}=${val}`; + }) + .join(' ')}` + : ''; + } + private getUpdateValStr(values: ValuesProps) { + return values + .map((item) => { + const val = + typeof item.value === 'number' + ? item.value + : `'${String(item.value).replace(/\'/g, '"')}'`; + + return `${item.key}=${val}`; + }) + .join(','); + } + private getInsertValStr(values: ValuesProps[]) { + return values + .map( + (items) => + `(${items + .map((item) => + typeof item.value === 'number' + ? item.value + : `'${String(item.value).replace(/\'/g, '"')}'` + ) + .join(',')})` + ) + .join(','); + } + async select(table: string, props: GetProps) { + const sql = `SELECT ${!props.field || props.field?.length === 0 ? '*' : props.field?.join(',')} + FROM ${table} + ${this.getWhereStr(props.where)} + ${ + props.order + ? `ORDER BY ${props.order.map((item) => `${item.field} ${item.mode}`).join(',')}` + : '' + } + LIMIT ${props.limit || 10} OFFSET ${props.offset || 0} + `; + + const pg = await connectPg(); + return pg.query(sql); + } + async count(table: string, props: GetProps) { + const sql = `SELECT COUNT(*) + FROM ${table} + ${this.getWhereStr(props.where)} + `; + const pg = await connectPg(); + return pg.query(sql).then((res) => Number(res.rows[0]?.count || 0)); + } + async delete(table: string, props: DeleteProps) { + const sql = `DELETE FROM ${table} ${this.getWhereStr(props.where)}`; + const pg = await connectPg(); + return pg.query(sql); + } + async update(table: string, props: UpdateProps) { + const sql = `UPDATE ${table} SET ${this.getUpdateValStr(props.values)} ${this.getWhereStr( + props.where + )}`; + + const pg = await connectPg(); + return pg.query(sql); + } + async insert(table: string, props: InsertProps) { + const fields = props.values[0].map((item) => item.key).join(','); + const sql = `INSERT INTO ${table} (${fields}) VALUES ${this.getInsertValStr(props.values)} `; + const pg = await connectPg(); + return pg.query(sql); + } +} + +export const PgClient = new Pg();