From ce687b28efd829c2e4ce66da657ae20dd5509645 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Mon, 31 Jul 2023 12:51:30 +1200 Subject: [PATCH] fix: Model Manager Tab Issues --- .../subpanels/ModelManagerPanel/ModelList.tsx | 131 +++++++++++++----- 1 file changed, 98 insertions(+), 33 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index 3f100d9072..f29b9c1086 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -1,4 +1,4 @@ -import { ButtonGroup, Flex, Text } from '@chakra-ui/react'; +import { ButtonGroup, Flex, Spinner, Text } from '@chakra-ui/react'; import { EntityState } from '@reduxjs/toolkit'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; @@ -6,23 +6,23 @@ import { forEach } from 'lodash-es'; import type { ChangeEvent, PropsWithChildren } from 'react'; import { useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { ALL_BASE_MODELS } from 'services/api/constants'; import { + LoRAModelConfigEntity, MainModelConfigEntity, OnnxModelConfigEntity, + useGetLoRAModelsQuery, useGetMainModelsQuery, useGetOnnxModelsQuery, - useGetLoRAModelsQuery, - LoRAModelConfigEntity, } from 'services/api/endpoints/models'; import ModelListItem from './ModelListItem'; -import { ALL_BASE_MODELS } from 'services/api/constants'; type ModelListProps = { selectedModelId: string | undefined; setSelectedModelId: (name: string | undefined) => void; }; -type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx'; +type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx'; type ModelType = 'main' | 'lora' | 'onnx'; @@ -33,35 +33,43 @@ const ModelList = (props: ModelListProps) => { const { t } = useTranslation(); const [nameFilter, setNameFilter] = useState(''); const [modelFormatFilter, setModelFormatFilter] = - useState('images'); + useState('all'); - const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { - selectFromResult: ({ data }) => ({ - filteredDiffusersModels: modelsFilter( - data, - 'main', - 'diffusers', - nameFilter - ), - }), - }); + const { filteredDiffusersModels, isDiffusersModelLoading } = + useGetMainModelsQuery(ALL_BASE_MODELS, { + selectFromResult: ({ data, isLoading }) => ({ + filteredDiffusersModels: modelsFilter( + data, + 'main', + 'diffusers', + nameFilter + ), + isDiffusersModelLoading: isLoading, + }), + }); - const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { - selectFromResult: ({ data }) => ({ - filteredCheckpointModels: modelsFilter( - data, - 'main', - 'checkpoint', - nameFilter - ), - }), - }); + const { filteredCheckpointModels, isCheckpointModelLoading } = + useGetMainModelsQuery(ALL_BASE_MODELS, { + selectFromResult: ({ data, isLoading }) => ({ + filteredCheckpointModels: modelsFilter( + data, + 'main', + 'checkpoint', + nameFilter + ), + isCheckpointModelLoading: isLoading, + }), + }); - const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, { - selectFromResult: ({ data }) => ({ - filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter), - }), - }); + const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery( + undefined, + { + selectFromResult: ({ data, isLoading }) => ({ + filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter), + isLoadingLoraModels: isLoading, + }), + } + ); const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, { selectFromResult: ({ data }) => ({ @@ -79,13 +87,47 @@ const ModelList = (props: ModelListProps) => { setNameFilter(e.target.value); }, []); + const renderModelList = ( + filterArray: Partial[], + isLoading: boolean, + loadingMessage: string, + title: string, + modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[] + ) => { + if (!filterArray.includes(modelFormatFilter)) return; + + if (isLoading) { + return ; + } + + if (modelList.length === 0) return; + + return ( + + + + {title} + + {modelList.map((model) => ( + + ))} + + + ); + }; + return ( setModelFormatFilter('images')} - isChecked={modelFormatFilter === 'images'} + onClick={() => setModelFormatFilter('all')} + isChecked={modelFormatFilter === 'all'} size="sm" > {t('modelManager.allModels')} @@ -287,3 +329,26 @@ const StyledModelContainer = (props: PropsWithChildren) => { ); }; + +const FetchingModelsLoader = ({ + loadingMessage, +}: { + loadingMessage?: string; +}) => { + return ( + + + + + {loadingMessage ? loadingMessage : 'Fetching...'} + + + + ); +};