diff --git a/src/api/model.ts b/src/api/model.ts index 4a092c67d..b4ec38617 100644 --- a/src/api/model.ts +++ b/src/api/model.ts @@ -1,7 +1,7 @@ import { GET, POST, DELETE, PUT } from './request'; import type { ModelSchema, ModelDataSchema } from '@/types/mongoSchema'; -import { ModelUpdateParams } from '@/types/model'; -import { PagingData, RequestPaging } from '../types/index'; +import { ModelUpdateParams, ShareModelItem } from '@/types/model'; +import { RequestPaging } from '../types/index'; import { Obj2Query } from '@/utils/tools'; /** @@ -99,4 +99,13 @@ export const delOneModelData = (dataId: string) => * 获取共享市场模型 */ export const getShareModelList = (data: { searchText?: string } & RequestPaging) => - POST(`/model/share/getModels`, data); + POST(`/model/share/getModels`, data); +/** + * 获取收藏的模型 + */ +export const getCollectionModels = () => GET(`/model/share/getCollection`); +/** + * 收藏/取消收藏模型 + */ +export const triggerModelCollection = (modelId: string) => + POST(`/model/share/collection?modelId=${modelId}`); diff --git a/src/components/Icon/icons/collectionLight.svg b/src/components/Icon/icons/collectionLight.svg index 72fb923ff..97d460b02 100644 --- a/src/components/Icon/icons/collectionLight.svg +++ b/src/components/Icon/icons/collectionLight.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/src/components/Icon/icons/collectionSolid.svg b/src/components/Icon/icons/collectionSolid.svg index 140c32999..a4aa121c5 100644 --- a/src/components/Icon/icons/collectionSolid.svg +++ b/src/components/Icon/icons/collectionSolid.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/src/pages/api/model/share/collection.ts b/src/pages/api/model/share/collection.ts new file mode 100644 index 000000000..8d2130502 --- /dev/null +++ b/src/pages/api/model/share/collection.ts @@ -0,0 +1,44 @@ +import type { NextApiRequest, NextApiResponse } from 'next'; +import { jsonRes } from '@/service/response'; +import { connectToDatabase, Collection, Model } from '@/service/mongo'; +import { authToken } from '@/service/utils/tools'; + +/* 模型收藏切换 */ +export default async function handler(req: NextApiRequest, res: NextApiResponse) { + try { + const { modelId } = req.query as { modelId: string }; + + if (!modelId) { + throw new Error('缺少参数'); + } + // 凭证校验 + const userId = await authToken(req.headers.authorization); + + await connectToDatabase(); + + const collectionRecord = await Collection.findOne({ + userId, + modelId + }); + + if (collectionRecord) { + await Collection.findByIdAndRemove(collectionRecord._id); + } else { + await Collection.create({ + userId, + modelId + }); + } + + await Model.findByIdAndUpdate(modelId, { + 'share.collection': await Collection.countDocuments({ modelId }) + }); + + jsonRes(res); + } catch (err) { + jsonRes(res, { + code: 500, + error: err + }); + } +} diff --git a/src/pages/api/model/share/getCollection.ts b/src/pages/api/model/share/getCollection.ts new file mode 100644 index 000000000..40c89098b --- /dev/null +++ b/src/pages/api/model/share/getCollection.ts @@ -0,0 +1,38 @@ +import type { NextApiRequest, NextApiResponse } from 'next'; +import { jsonRes } from '@/service/response'; +import { connectToDatabase, Collection } from '@/service/mongo'; +import { authToken } from '@/service/utils/tools'; +import type { ShareModelItem } from '@/types/model'; + +/* 获取模型列表 */ +export default async function handler(req: NextApiRequest, res: NextApiResponse) { + try { + // 凭证校验 + const userId = await authToken(req.headers.authorization); + + await connectToDatabase(); + + // get my collections + const collections = await Collection.find({ + userId + }).populate('modelId', '_id avatar name userId share'); + + jsonRes(res, { + data: collections + .map((item: any) => ({ + _id: item.modelId?._id, + avatar: item.modelId?.avatar || '', + name: item.modelId?.name || '', + userId: item.modelId?.userId || '', + share: item.modelId?.share || {}, + isCollection: true + })) + .filter((item) => item.share.isShare) + }); + } catch (err) { + jsonRes(res, { + code: 500, + error: err + }); + } +} diff --git a/src/pages/api/model/share/getModels.ts b/src/pages/api/model/share/getModels.ts index c5a4233fd..13ae72f53 100644 --- a/src/pages/api/model/share/getModels.ts +++ b/src/pages/api/model/share/getModels.ts @@ -1,8 +1,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { connectToDatabase } from '@/service/mongo'; +import { connectToDatabase, Collection, Model } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; -import { Model } from '@/service/models/model'; import type { PagingData } from '@/types'; import type { ShareModelItem } from '@/types/model'; @@ -30,19 +29,29 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< }; // 根据分享的模型 - const models = await Model.find(where, '_id avatar name userId share') - .sort({ - 'share.collection': -1 - }) - .limit(pageSize) - .skip((pageNum - 1) * pageSize); + const [models, total] = await Promise.all([ + Model.find(where, '_id avatar name userId share') + .sort({ + 'share.collection': -1 + }) + .limit(pageSize) + .skip((pageNum - 1) * pageSize), + Model.countDocuments(where) + ]); jsonRes>(res, { data: { pageNum, pageSize, - data: models, - total: await Model.countDocuments(where) + data: models.map((item) => ({ + _id: item._id, + avatar: item.avatar, + name: item.name, + userId: item.userId, + share: item.share, + isCollection: false + })), + total } }); } catch (err) { diff --git a/src/pages/chat/components/SlideBar.tsx b/src/pages/chat/components/SlideBar.tsx index 6aa2c9093..303deb127 100644 --- a/src/pages/chat/components/SlideBar.tsx +++ b/src/pages/chat/components/SlideBar.tsx @@ -1,4 +1,4 @@ -import React, { useRef, useEffect } from 'react'; +import React, { useRef, useEffect, useMemo } from 'react'; import { AddIcon, ChatIcon, DeleteIcon, MoonIcon, SunIcon } from '@chakra-ui/icons'; import { Box, @@ -22,6 +22,7 @@ import { getToken } from '@/utils/user'; import MyIcon from '@/components/Icon'; import WxConcat from '@/components/WxConcat'; import { getChatHistory, delChatHistoryById } from '@/api/chat'; +import { getCollectionModels } from '@/api/model'; import { modelList } from '@/constants/model'; const SlideBar = ({ @@ -45,10 +46,30 @@ const SlideBar = ({ cacheTime: 5 * 60 * 1000 }); + const { data: collectionModels = [] } = useQuery([getCollectionModels], getCollectionModels); + + const models = useMemo(() => { + const myModelList = myModels.map((item) => ({ + id: item._id, + name: item.name, + icon: modelList.find((model) => model.model === item?.service?.modelName)?.icon || 'model' + })); + const collectionList = collectionModels + .map((item) => ({ + id: item._id, + name: item.name, + icon: 'collectionSolid' as any + })) + .filter((model) => !myModelList.find((item) => item.id === model.id)); + + return myModelList.concat(collectionList); + }, [collectionModels, myModels]); + const { data: chatHistory = [], mutate: loadChatHistory } = useMutation({ mutationFn: getChatHistory }); + // update history useEffect(() => { if (chatId && preChatId.current === '') { loadChatHistory(); @@ -56,8 +77,11 @@ const SlideBar = ({ preChatId.current = chatId; }, [chatId, loadChatHistory]); + // init history useEffect(() => { - loadChatHistory(); + setTimeout(() => { + loadChatHistory(); + }, 1000); }, [loadChatHistory]); const RenderHistory = () => ( @@ -165,9 +189,9 @@ const SlideBar = ({ {isSuccess && ( <> - {myModels.map((item) => ( + {models.map((item) => ( { - if (item._id === modelId) return; - resetChat(item._id); + if (item.id === modelId) return; + resetChat(item.id); onClose(); }} > - model.model === item.service.modelName)?.icon || - 'model' - } - mr={2} - color={'white'} - w={'16px'} - h={'16px'} - /> + {item.name} diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index 387574606..81212c768 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -467,7 +467,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { borderBottom={'1px solid rgba(0,0,0,0.1)'} > - + {'avatar'} { +const ShareModelList = ({ + models = [], + onclickCollection +}: { + models: ShareModelItem[]; + onclickCollection: (modelId: string) => void; +}) => { const router = useRouter(); return ( <> {models.map((model) => ( - + {'avatar'} {model.name} - + {model.share.intro || '这个模型没有介绍~'} - - + onclickCollection(model._id)} + > + {model.share.collection} @@ -53,7 +74,7 @@ const ShareModelList = ({ models }: { models: ShareModelItem[] }) => { )} - + ))} ); diff --git a/src/pages/model/share/index.tsx b/src/pages/model/share/index.tsx index 78934b5b4..653103f8a 100644 --- a/src/pages/model/share/index.tsx +++ b/src/pages/model/share/index.tsx @@ -1,24 +1,20 @@ -import React, { useState, useRef } from 'react'; +import React, { useState, useRef, useCallback, useMemo } from 'react'; import { Box, Flex, Card, Grid, Input } from '@chakra-ui/react'; import { useLoading } from '@/hooks/useLoading'; -import { getShareModelList } from '@/api/model'; +import { getShareModelList, triggerModelCollection, getCollectionModels } from '@/api/model'; import { usePagination } from '@/hooks/usePagination'; import type { ShareModelItem } from '@/types/model'; import ShareModelList from './components/list'; +import { useQuery } from '@tanstack/react-query'; const modelList = () => { - const { Loading, setIsLoading } = useLoading(); + const { Loading } = useLoading(); const lastSearch = useRef(''); const [searchText, setSearchText] = useState(''); /* 加载模型 */ - const { - data: models, - isLoading, - Pagination, - getData - } = usePagination({ + const { data, isLoading, Pagination, getData, pageNum } = usePagination({ api: getShareModelList, pageSize: 20, params: { @@ -26,47 +22,90 @@ const modelList = () => { } }); + const { data: collectionModels = [], refetch: refetchCollection } = useQuery( + [getCollectionModels], + getCollectionModels + ); + + const models = useMemo(() => { + if (!collectionModels) return []; + return data.map((model) => ({ + ...model, + isCollection: !!collectionModels.find((item) => item._id === model._id) + })); + }, [collectionModels, data]); + + const onclickCollection = useCallback( + async (modelId: string) => { + try { + await triggerModelCollection(modelId); + getData(pageNum); + refetchCollection(); + } catch (error) { + console.log(error); + } + }, + [getData, pageNum, refetchCollection] + ); + return ( - - {/* 头部 */} + <> - 模型共享市场 + 我收藏的模型 - (Beta) - setSearchText(e.target.value)} - onBlur={() => { - if (searchText === lastSearch.current) return; - getData(1); - lastSearch.current = searchText; - }} - onKeyDown={(e) => { - if (searchText === lastSearch.current) return; - if (e.key === 'Enter') { - getData(1); - lastSearch.current = searchText; - } - }} - /> + {collectionModels.length == 0 && ( + + 还没有收藏模型~ + + )} + + + - - - - - - - + + + + 模型共享市场{' '} + + (Beta) + + + + setSearchText(e.target.value)} + onBlur={() => { + if (searchText === lastSearch.current) return; + getData(1); + lastSearch.current = searchText; + }} + onKeyDown={(e) => { + if (searchText === lastSearch.current) return; + if (e.key === 'Enter') { + getData(1); + lastSearch.current = searchText; + } + }} + /> + + + + + + + + + - + ); }; diff --git a/src/service/models/collection.ts b/src/service/models/collection.ts new file mode 100644 index 000000000..3094a50e1 --- /dev/null +++ b/src/service/models/collection.ts @@ -0,0 +1,18 @@ +import { Schema, model, models, Model as MongoModel } from 'mongoose'; +import { CollectionSchema as CollectionType } from '@/types/mongoSchema'; + +const CollectionSchema = new Schema({ + userId: { + type: Schema.Types.ObjectId, + ref: 'user', + required: true + }, + modelId: { + type: Schema.Types.ObjectId, + ref: 'model', + required: true + } +}); + +export const Collection: MongoModel = + models['collection'] || model('collection', CollectionSchema); diff --git a/src/service/mongo.ts b/src/service/mongo.ts index ae1de9e99..2645eced3 100644 --- a/src/service/mongo.ts +++ b/src/service/mongo.ts @@ -50,3 +50,4 @@ export * from './models/pay'; export * from './models/splitData'; export * from './models/openapi'; export * from './models/promotionRecord'; +export * from './models/collection'; diff --git a/src/types/model.d.ts b/src/types/model.d.ts index 5462be14d..9315f858d 100644 --- a/src/types/model.d.ts +++ b/src/types/model.d.ts @@ -26,4 +26,5 @@ export interface ShareModelItem { name: string; userId: string; share: ModelSchema['share']; + isCollection: boolean; } diff --git a/src/types/mongoSchema.d.ts b/src/types/mongoSchema.d.ts index c8513e099..6e09257b3 100644 --- a/src/types/mongoSchema.d.ts +++ b/src/types/mongoSchema.d.ts @@ -64,6 +64,11 @@ export interface ModelPopulate extends ModelSchema { userId: UserModelSchema; } +export interface CollectionSchema { + modelId: string; + userId: string; +} + export type ModelDataType = 0 | 1; export interface ModelDataSchema { _id: string;