From 3a661bac3496eba67fd248156d30358187e87dc0 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Fri, 1 Nov 2024 14:34:20 -0400 Subject: [PATCH] fix(ui): exclude submodels from model manager --- .../listeners/modelsLoaded.ts | 8 +-- .../subpanels/ModelManagerPanel/ModelList.tsx | 6 +- .../src/services/api/hooks/modelsByType.ts | 22 +++++-- .../frontend/web/src/services/api/types.ts | 59 +++++++++++++------ 4 files changed, 63 insertions(+), 32 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index 43e08bc211..a3cc9c31ac 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -164,7 +164,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => { // We have a VAE selected, need to check if it is available // Grab just the VAE models - const vaeModels = models.filter(isNonFluxVAEModelConfig); + const vaeModels = models.filter((m) => isNonFluxVAEModelConfig(m)); // If the current VAE model is available, we don't need to do anything if (vaeModels.some((m) => m.key === selectedVAEModel.key)) { @@ -297,7 +297,7 @@ const handleUpscaleModel: ModelHandler = (models, state, dispatch, log) => { const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => { const selectedT5EncoderModel = state.params.t5EncoderModel; - const t5EncoderModels = models.filter(isT5EncoderModelConfig); + const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m)); // If the currently selected model is available, we don't need to do anything if (selectedT5EncoderModel && t5EncoderModels.some((m) => m.key === selectedT5EncoderModel.key)) { @@ -325,7 +325,7 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => { const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => { const selectedCLIPEmbedModel = state.params.clipEmbedModel; - const CLIPEmbedModels = models.filter(isCLIPEmbedModelConfig); + const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfig(m)); // If the currently selected model is available, we don't need to do anything if (selectedCLIPEmbedModel && CLIPEmbedModels.some((m) => m.key === selectedCLIPEmbedModel.key)) { @@ -353,7 +353,7 @@ const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => { const handleFLUXVAEModels: ModelHandler = (models, state, dispatch, log) => { const selectedFLUXVAEModel = state.params.fluxVAE; - const fluxVAEModels = models.filter(isFluxVAEModelConfig); + const fluxVAEModels = models.filter((m) => isFluxVAEModelConfig(m)); // If the currently selected model is available, we don't need to do anything if (selectedFLUXVAEModel && fluxVAEModels.some((m) => m.key === selectedFLUXVAEModel.key)) { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx index 78e5012484..8f546f9d0c 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx @@ -80,19 +80,19 @@ const ModelList = () => { [clipVisionModels, searchTerm, filteredModelType] ); - const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels(); + const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels({ excludeSubmodels: true }); const filteredVAEModels = useMemo( () => modelsFilter(vaeModels, searchTerm, filteredModelType), [vaeModels, searchTerm, filteredModelType] ); - const [t5EncoderModels, { isLoading: isLoadingT5EncoderModels }] = useT5EncoderModels(); + const [t5EncoderModels, { isLoading: isLoadingT5EncoderModels }] = useT5EncoderModels({ excludeSubmodels: true }); const filteredT5EncoderModels = useMemo( () => modelsFilter(t5EncoderModels, searchTerm, filteredModelType), [t5EncoderModels, searchTerm, filteredModelType] ); - const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels(); + const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels({ excludeSubmodels: true }); const filteredClipEmbedModels = useMemo( () => modelsFilter(clipEmbedModels, searchTerm, filteredModelType), [clipEmbedModels, searchTerm, filteredModelType] diff --git a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts index 84d6be1336..7185b01f3b 100644 --- a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts +++ b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts @@ -30,8 +30,13 @@ import { isVAEModelConfig, } from 'services/api/types'; +type ModelHookArgs = { excludeSubmodels?: boolean }; + const buildModelsHook = - (typeGuard: (config: AnyModelConfig) => config is T) => + ( + typeGuard: (config: AnyModelConfig, excludeSubmodels?: boolean) => config is T, + excludeSubmodels?: boolean + ) => () => { const result = useGetModelConfigsQuery(undefined); const modelConfigs = useMemo(() => { @@ -39,7 +44,9 @@ const buildModelsHook = return EMPTY_ARRAY; } - return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard); + return modelConfigsAdapterSelectors + .selectAll(result.data) + .filter((config) => typeGuard(config, excludeSubmodels)); }, [result]); return [modelConfigs, result] as const; @@ -56,13 +63,16 @@ export const useLoRAModels = buildModelsHook(isLoRAModelConfig); export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig); export const useControlNetModels = buildModelsHook(isControlNetModelConfig); export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig); -export const useT5EncoderModels = buildModelsHook(isT5EncoderModelConfig); -export const useCLIPEmbedModels = buildModelsHook(isCLIPEmbedModelConfig); +export const useT5EncoderModels = (args?: ModelHookArgs) => + buildModelsHook(isT5EncoderModelConfig, args?.excludeSubmodels)(); +export const useCLIPEmbedModels = (args?: ModelHookArgs) => + buildModelsHook(isCLIPEmbedModelConfig, args?.excludeSubmodels)(); export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToImageModelConfig); export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig); export const useEmbeddingModels = buildModelsHook(isTIModelConfig); -export const useVAEModels = buildModelsHook(isVAEModelConfig); -export const useFluxVAEModels = buildModelsHook(isFluxVAEModelConfig); +export const useVAEModels = (args?: ModelHookArgs) => buildModelsHook(isVAEModelConfig, args?.excludeSubmodels)(); +export const useFluxVAEModels = (args?: ModelHookArgs) => + buildModelsHook(isFluxVAEModelConfig, args?.excludeSubmodels)(); export const useCLIPVisionModels = buildModelsHook(isCLIPVisionModelConfig); // const buildModelsSelector = diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 3dfb23b982..39f587c849 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -103,19 +103,24 @@ export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelCo return config.type === 'lora'; }; -export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => { - return config.type === 'vae' || (config.type === 'main' && check_submodels(['vae'], config)); +export const isVAEModelConfig = (config: AnyModelConfig, excludeSubmodels?: boolean): config is VAEModelConfig => { + return config.type === 'vae' || (!excludeSubmodels && config.type === 'main' && check_submodels(['vae'], config)); }; -export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => { +export const isNonFluxVAEModelConfig = ( + config: AnyModelConfig, + excludeSubmodels?: boolean +): config is VAEModelConfig => { return ( - (config.type === 'vae' || (config.type === 'main' && check_submodels(['vae'], config))) && config.base !== 'flux' + (config.type === 'vae' || (!excludeSubmodels && config.type === 'main' && check_submodels(['vae'], config))) && + config.base !== 'flux' ); }; -export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => { +export const isFluxVAEModelConfig = (config: AnyModelConfig, excludeSubmodels?: boolean): config is VAEModelConfig => { return ( - (config.type === 'vae' || (config.type === 'main' && check_submodels(['vae'], config))) && config.base === 'flux' + (config.type === 'vae' || (!excludeSubmodels && config.type === 'main' && check_submodels(['vae'], config))) && + config.base === 'flux' ); }; @@ -136,26 +141,42 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd }; export const isT5EncoderModelConfig = ( - config: AnyModelConfig + config: AnyModelConfig, + excludeSubmodels?: boolean ): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig => { - return config.type === 't5_encoder' || (config.type === 'main' && check_submodels(['t5_encoder'], config)); -}; - -export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig => { - return config.type === 'clip_embed' || (config.type === 'main' && check_submodels(['clip_embed'], config)); -}; - -export const isCLIPLEmbedModelConfig = (config: AnyModelConfig): config is CLIPLEmbedModelConfig => { return ( - (config.type === 'clip_embed' && config.variant === 'large') || - (config.type === 'main' && check_submodels(['clip_embed', 'large'], config)) + config.type === 't5_encoder' || + (!excludeSubmodels && config.type === 'main' && check_submodels(['t5_encoder'], config)) ); }; -export const isCLIPGEmbedModelConfig = (config: AnyModelConfig): config is CLIPGEmbedModelConfig => { +export const isCLIPEmbedModelConfig = ( + config: AnyModelConfig, + excludeSubmodels?: boolean +): config is CLIPEmbedModelConfig => { + return ( + config.type === 'clip_embed' || + (!excludeSubmodels && config.type === 'main' && check_submodels(['clip_embed'], config)) + ); +}; + +export const isCLIPLEmbedModelConfig = ( + config: AnyModelConfig, + excludeSubmodels?: boolean +): config is CLIPLEmbedModelConfig => { + return ( + (config.type === 'clip_embed' && config.variant === 'large') || + (!excludeSubmodels && config.type === 'main' && check_submodels(['clip_embed', 'large'], config)) + ); +}; + +export const isCLIPGEmbedModelConfig = ( + config: AnyModelConfig, + excludeSubmodels?: boolean +): config is CLIPGEmbedModelConfig => { return ( (config.type === 'clip_embed' && config.variant === 'gigantic') || - (config.type === 'main' && check_submodels(['clip_embed', 'gigantic'], config)) + (!excludeSubmodels && config.type === 'main' && check_submodels(['clip_embed', 'gigantic'], config)) ); };