From edfd90f2a4ae2044b04a58d7ab49567ae2a01b73 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 2 Oct 2025 13:38:12 +1000 Subject: [PATCH] tidy(mm): consistent class names --- .../app/invocations/create_gradient_mask.py | 4 +- invokeai/app/invocations/flux_model_loader.py | 4 +- .../model_install/model_install_default.py | 4 +- .../app/services/shared/invocation_context.py | 4 +- invokeai/backend/model_manager/__init__.py | 4 +- invokeai/backend/model_manager/config.py | 192 +++++++++--------- .../model_manager/load/load_default.py | 4 +- .../load/model_loader_registry.py | 6 +- .../load/model_loaders/clip_vision.py | 4 +- .../load/model_loaders/cogview4.py | 8 +- .../model_manager/load/model_loaders/flux.py | 8 +- .../load/model_loaders/generic_diffusers.py | 4 +- .../load/model_loaders/stable_diffusion.py | 8 +- tests/test_model_probe.py | 22 +- 14 files changed, 137 insertions(+), 139 deletions(-) diff --git a/invokeai/app/invocations/create_gradient_mask.py b/invokeai/app/invocations/create_gradient_mask.py index b232fbbc93..f6e046d096 100644 --- a/invokeai/app/invocations/create_gradient_mask.py +++ b/invokeai/app/invocations/create_gradient_mask.py @@ -21,7 +21,7 @@ from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation from invokeai.app.invocations.model import UNetField, VAEField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager import LoadedModel -from invokeai.backend.model_manager.config import MainConfigBase +from invokeai.backend.model_manager.config import Main_Config_Base from invokeai.backend.model_manager.taxonomy import ModelVariantType from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor @@ -182,7 +182,7 @@ class CreateGradientMaskInvocation(BaseInvocation): if self.unet is not None and self.vae is not None and self.image is not None: # all three fields must be present at the same time main_model_config = context.models.get_config(self.unet.unet.key) - assert isinstance(main_model_config, MainConfigBase) + assert isinstance(main_model_config, Main_Config_Base) if main_model_config.variant is ModelVariantType.Inpaint: mask = dilated_mask_tensor vae_info: LoadedModel = context.models.load(self.vae.vae) diff --git a/invokeai/app/invocations/flux_model_loader.py b/invokeai/app/invocations/flux_model_loader.py index 4ed3b91bc6..2803db48e0 100644 --- a/invokeai/app/invocations/flux_model_loader.py +++ b/invokeai/app/invocations/flux_model_loader.py @@ -15,7 +15,7 @@ from invokeai.app.util.t5_model_identifier import ( ) from invokeai.backend.flux.util import get_flux_max_seq_length from invokeai.backend.model_manager.config import ( - CheckpointConfigBase, + Checkpoint_Config_Base, ) from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType @@ -87,7 +87,7 @@ class FluxModelLoaderInvocation(BaseInvocation): t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model) transformer_config = context.models.get_config(transformer) - assert isinstance(transformer_config, CheckpointConfigBase) + assert isinstance(transformer_config, Checkpoint_Config_Base) return FluxModelLoaderOutput( transformer=TransformerField(transformer=transformer, loras=[]), diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 06608df8e8..10a954a563 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -37,7 +37,7 @@ from invokeai.app.services.model_records import DuplicateModelException, ModelRe from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.backend.model_manager.config import ( AnyModelConfig, - CheckpointConfigBase, + Checkpoint_Config_Base, InvalidModelConfigException, ModelConfigFactory, ) @@ -625,7 +625,7 @@ class ModelInstallService(ModelInstallServiceBase): info.path = model_path.as_posix() - if isinstance(info, CheckpointConfigBase) and info.config_path is not None: + if isinstance(info, Checkpoint_Config_Base) and info.config_path is not None: # Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the # invoke-managed legacy config dir, we use a relative path. legacy_config_path = self.app_config.legacy_conf_path / info.config_path diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 743b6208ea..16aacbb985 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -21,7 +21,7 @@ from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.util.step_callback import diffusion_step_callback from invokeai.backend.model_manager.config import ( AnyModelConfig, - ModelConfigBase, + Config_Base, ) from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType @@ -558,7 +558,7 @@ class ModelsInterface(InvocationContextInterface): The absolute path to the model. """ - model_path = Path(config_or_path.path) if isinstance(config_or_path, ModelConfigBase) else Path(config_or_path) + model_path = Path(config_or_path.path) if isinstance(config_or_path, Config_Base) else Path(config_or_path) if model_path.is_absolute(): return model_path.resolve() diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index dca72f170e..a167687d2e 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -3,7 +3,7 @@ from invokeai.backend.model_manager.config import ( AnyModelConfig, InvalidModelConfigException, - ModelConfigBase, + Config_Base, ModelConfigFactory, ) from invokeai.backend.model_manager.legacy_probe import ModelProbe @@ -30,7 +30,7 @@ __all__ = [ "ModelConfigFactory", "ModelProbe", "ModelSearch", - "ModelConfigBase", + "Config_Base", "AnyModel", "AnyVariant", "BaseModelType", diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index c82b0673bd..188ac9ad11 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -35,7 +35,6 @@ from typing import ( Optional, Self, Type, - TypeAlias, Union, ) @@ -321,7 +320,7 @@ class LegacyProbeMixin: pass -class ModelConfigBase(ABC, BaseModel): +class Config_Base(ABC, BaseModel): """ Abstract base class for model configurations. A model config describes a specific combination of model base, type and format, along with other metadata about the model. For example, a Stable Diffusion 1.x main model in checkpoint format @@ -427,7 +426,7 @@ class ModelConfigBase(ABC, BaseModel): Computes the discriminator value for a model config. https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator """ - if isinstance(v, ModelConfigBase): + if isinstance(v, Config_Base): # We have an instance of a ModelConfigBase subclass - use its tag directly. return v.get_tag().tag if isinstance(v, dict): @@ -473,7 +472,7 @@ class ModelConfigBase(ABC, BaseModel): raise NotImplementedError(f"from_model_on_disk not implemented for {cls.__name__}") -class Unknown_Config(ModelConfigBase): +class Unknown_Config(Config_Base): """Model config for unknown models, used as a fallback when we cannot identify a model.""" base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown) @@ -485,7 +484,7 @@ class Unknown_Config(ModelConfigBase): raise NotAMatch(cls, "unknown model config cannot match any model") -class CheckpointConfigBase(ABC, BaseModel): +class Checkpoint_Config_Base(ABC, BaseModel): """Base class for checkpoint-style models.""" config_path: str | None = Field( @@ -498,7 +497,7 @@ class CheckpointConfigBase(ABC, BaseModel): ) -class DiffusersConfigBase(ABC, BaseModel): +class Diffusers_Config_Base(ABC, BaseModel): """Base class for diffusers-style models.""" format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) @@ -521,7 +520,7 @@ class DiffusersConfigBase(ABC, BaseModel): return ModelRepoVariant.Default -class T5Encoder_T5Encoder_Config(ModelConfigBase): +class T5Encoder_T5Encoder_Config(Config_Base): base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder) format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder) @@ -552,7 +551,7 @@ class T5Encoder_T5Encoder_Config(ModelConfigBase): raise NotAMatch(cls, "missing text_encoder_2/model.safetensors.index.json") -class T5Encoder_BnBLLMint8_Config(ModelConfigBase): +class T5Encoder_BnBLLMint8_Config(Config_Base): base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder) format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b) @@ -590,7 +589,7 @@ class T5Encoder_BnBLLMint8_Config(ModelConfigBase): raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8") -class LoRAConfigBase(ABC, BaseModel): +class LoRA_Config_Base(ABC, BaseModel): """Base class for LoRA models.""" type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA) @@ -609,7 +608,7 @@ def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None: return value -class LoRA_OMI_Config_Base(LoRAConfigBase): +class LoRA_OMI_Config_Base(LoRA_Config_Base): format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI) @classmethod @@ -663,15 +662,15 @@ class LoRA_OMI_Config_Base(LoRAConfigBase): raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}") -class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, ModelConfigBase): +class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, ModelConfigBase): +class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, Config_Base): base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) -class LoRA_LyCORIS_Config_Base(LoRAConfigBase): +class LoRA_LyCORIS_Config_Base(LoRA_Config_Base): """Model config for LoRA/Lycoris models.""" type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA) @@ -750,27 +749,27 @@ class LoRA_LyCORIS_Config_Base(LoRAConfigBase): raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}") -class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase): +class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase): +class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) -class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase): +class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase): +class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, Config_Base): base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) -class ControlAdapterConfigBase(ABC, BaseModel): +class ControlAdapter_Config_Base(ABC, BaseModel): default_settings: ControlAdapterDefaultSettings | None = Field(None) -class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapterConfigBase, ModelConfigBase): +class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapter_Config_Base, Config_Base): """Model config for Control LoRA models.""" base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) @@ -797,7 +796,7 @@ class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapterConfigBase, ModelConfigBase) raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA") -class LoRA_Diffusers_Config_Base(LoRAConfigBase): +class LoRA_Diffusers_Config_Base(LoRA_Config_Base): """Model config for LoRA/Diffusers models.""" # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates @@ -855,23 +854,23 @@ class LoRA_Diffusers_Config_Base(LoRAConfigBase): raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors") -class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, ModelConfigBase): +class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, ModelConfigBase): +class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) -class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, ModelConfigBase): +class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, ModelConfigBase): +class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) -class VAE_Checkpoint_Config_Base(CheckpointConfigBase): +class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base): """Model config for standalone VAE models.""" type: Literal[ModelType.VAE] = Field(default=ModelType.VAE) @@ -925,23 +924,23 @@ class VAE_Checkpoint_Config_Base(CheckpointConfigBase): raise NotAMatch(cls, "cannot determine base type") -class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, ModelConfigBase): +class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, ModelConfigBase): +class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) -class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, ModelConfigBase): +class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, ModelConfigBase): +class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) -class VAE_Diffusers_Config_Base(DiffusersConfigBase): +class VAE_Diffusers_Config_Base(Diffusers_Config_Base): """Model config for standalone VAE models (diffusers version).""" type: Literal[ModelType.VAE] = Field(default=ModelType.VAE) @@ -1005,15 +1004,15 @@ class VAE_Diffusers_Config_Base(DiffusersConfigBase): return BaseModelType.StableDiffusion1 -class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, ModelConfigBase): +class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, ModelConfigBase): +class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class ControlNet_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfigBase): +class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base, ControlAdapter_Config_Base): """Model config for ControlNet models (diffusers version).""" type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) @@ -1068,23 +1067,23 @@ class ControlNet_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig raise NotAMatch(cls, f"unrecognized cross_attention_dim {dimension}") -class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase): +class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase): +class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) -class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase): +class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase): +class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) -class ControlNet_Checkpoint_Config_Base(CheckpointConfigBase, ControlAdapterConfigBase): +class ControlNet_Checkpoint_Config_Base(Checkpoint_Config_Base, ControlAdapter_Config_Base): """Model config for ControlNet models (diffusers version).""" type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) @@ -1161,19 +1160,19 @@ class ControlNet_Checkpoint_Config_Base(CheckpointConfigBase, ControlAdapterConf raise NotAMatch(cls, "unable to determine base type from state dict") -class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase): +class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase): +class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) -class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase): +class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase): +class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) @@ -1267,15 +1266,15 @@ class TI_File_Config_Base(TI_Config_Base): return cls(**fields) -class TI_File_SD1_Config(TI_File_Config_Base, ModelConfigBase): +class TI_File_SD1_Config(TI_File_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class TI_File_SD2_Config(TI_File_Config_Base, ModelConfigBase): +class TI_File_SD2_Config(TI_File_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) -class TI_File_SDXL_Config(TI_File_Config_Base, ModelConfigBase): +class TI_File_SDXL_Config(TI_File_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) @@ -1298,19 +1297,19 @@ class TI_Folder_Config_Base(TI_Config_Base): raise NotAMatch(cls, "model does not look like a textual inversion embedding folder") -class TI_Folder_SD1_Config(TI_Folder_Config_Base, ModelConfigBase): +class TI_Folder_SD1_Config(TI_Folder_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class TI_Folder_SD2_Config(TI_Folder_Config_Base, ModelConfigBase): +class TI_Folder_SD2_Config(TI_Folder_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) -class TI_Folder_SDXL_Config(TI_Folder_Config_Base, ModelConfigBase): +class TI_Folder_SDXL_Config(TI_Folder_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class MainConfigBase(ABC, BaseModel): +class Main_Config_Base(ABC, BaseModel): type: Literal[ModelType.Main] = Field(default=ModelType.Main) trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) default_settings: Optional[MainModelDefaultSettings] = Field( @@ -1352,7 +1351,7 @@ def _has_main_keys(state_dict: dict[str | int, Any]) -> bool: return False -class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase): +class Main_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base): """Model config for main checkpoint models.""" format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) @@ -1450,19 +1449,19 @@ class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase): raise NotAMatch(cls, "state dict does not look like a main model") -class Main_Checkpoint_SD1_Config(Main_Checkpoint_Config_Base, ModelConfigBase): +class Main_Checkpoint_SD1_Config(Main_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class Main_Checkpoint_SD2_Config(Main_Checkpoint_Config_Base, ModelConfigBase): +class Main_Checkpoint_SD2_Config(Main_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) -class Main_Checkpoint_SDXL_Config(Main_Checkpoint_Config_Base, ModelConfigBase): +class Main_Checkpoint_SDXL_Config(Main_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class Main_Checkpoint_SDXLRefiner_Config(Main_Checkpoint_Config_Base, ModelConfigBase): +class Main_Checkpoint_SDXLRefiner_Config(Main_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner) @@ -1511,7 +1510,7 @@ def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | Non return FluxVariantType.Schnell -class Main_Checkpoint_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase): +class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): """Model config for main checkpoint models.""" format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) @@ -1581,7 +1580,7 @@ class Main_Checkpoint_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelCon raise NotAMatch(cls, "state dict looks like GGUF quantized") -class Main_BnBNF4_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase): +class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): """Model config for main checkpoint models.""" base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) @@ -1630,7 +1629,7 @@ class Main_BnBNF4_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigB raise NotAMatch(cls, "state dict does not look like bnb quantized nf4") -class Main_GGUF_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase): +class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): """Model config for main checkpoint models.""" base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) @@ -1679,7 +1678,7 @@ class Main_GGUF_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBas raise NotAMatch(cls, "state dict does not look like GGUF quantized") -class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase): +class Main_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base): prediction_type: SchedulerPredictionType = Field() variant: ModelVariantType = Field() @@ -1785,23 +1784,23 @@ class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase): raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'") -class Main_Diffusers_SD1_Config(Main_Diffusers_Config_Base, ModelConfigBase): +class Main_Diffusers_SD1_Config(Main_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion1] = Field(BaseModelType.StableDiffusion1) -class Main_Diffusers_SD2_Config(Main_Diffusers_Config_Base, ModelConfigBase): +class Main_Diffusers_SD2_Config(Main_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion2] = Field(BaseModelType.StableDiffusion2) -class Main_Diffusers_SDXL_Config(Main_Diffusers_Config_Base, ModelConfigBase): +class Main_Diffusers_SDXL_Config(Main_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(BaseModelType.StableDiffusionXL) -class Main_Diffusers_SDXLRefiner_Config(Main_Diffusers_Config_Base, ModelConfigBase): +class Main_Diffusers_SDXLRefiner_Config(Main_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(BaseModelType.StableDiffusionXLRefiner) -class Main_Diffusers_SD3_Config(DiffusersConfigBase, MainConfigBase, ModelConfigBase): +class Main_Diffusers_SD3_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3) @classmethod @@ -1875,7 +1874,7 @@ class Main_Diffusers_SD3_Config(DiffusersConfigBase, MainConfigBase, ModelConfig return submodels -class Main_Diffusers_CogView4_Config(DiffusersConfigBase, MainConfigBase, ModelConfigBase): +class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base): base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4) @classmethod @@ -1901,11 +1900,11 @@ class Main_Diffusers_CogView4_Config(DiffusersConfigBase, MainConfigBase, ModelC ) -class IPAdapterConfigBase(ABC, BaseModel): +class IPAdapter_Config_Base(ABC, BaseModel): type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter) -class IPAdapter_InvokeAI_Config_Base(IPAdapterConfigBase): +class IPAdapter_InvokeAI_Config_Base(IPAdapter_Config_Base): """Model config for IP Adapter diffusers format models.""" format: Literal[ModelFormat.InvokeAI] = Field(default=ModelFormat.InvokeAI) @@ -1968,19 +1967,19 @@ class IPAdapter_InvokeAI_Config_Base(IPAdapterConfigBase): raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}") -class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase): +class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase): +class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) -class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase): +class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class IPAdapter_Checkpoint_Config_Base(IPAdapterConfigBase): +class IPAdapter_Checkpoint_Config_Base(IPAdapter_Config_Base): """Model config for IP Adapter checkpoint format models.""" format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) @@ -2041,19 +2040,19 @@ class IPAdapter_Checkpoint_Config_Base(IPAdapterConfigBase): raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}") -class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase): +class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase): +class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) -class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase): +class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase): +class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) @@ -2071,7 +2070,7 @@ def _get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantTyp return None -class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase): +class CLIPEmbed_Diffusers_Config_Base(Diffusers_Config_Base): base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed) format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) @@ -2119,15 +2118,15 @@ class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase): raise NotAMatch(cls, f"variant is {recognized_variant}, not {expected_variant}") -class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, ModelConfigBase): +class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base): variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G) -class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, ModelConfigBase): +class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base): variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L) -class CLIPVision_Diffusers_Config(DiffusersConfigBase, ModelConfigBase): +class CLIPVision_Diffusers_Config(Diffusers_Config_Base, Config_Base): """Model config for CLIPVision.""" base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) @@ -2151,7 +2150,7 @@ class CLIPVision_Diffusers_Config(DiffusersConfigBase, ModelConfigBase): return cls(**fields) -class T2IAdapter_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfigBase): +class T2IAdapter_Diffusers_Config_Base(Diffusers_Config_Base, ControlAdapter_Config_Base): """Model config for T2I.""" type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter) @@ -2198,15 +2197,15 @@ class T2IAdapter_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig raise NotAMatch(cls, f"unrecognized adapter_type '{adapter_type}'") -class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, ModelConfigBase): +class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, ModelConfigBase): +class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, Config_Base): base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class Spandrel_Checkpoint_Config(ModelConfigBase): +class Spandrel_Checkpoint_Config(Config_Base): """Model config for Spandrel Image to Image models.""" base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) @@ -2239,7 +2238,7 @@ class Spandrel_Checkpoint_Config(ModelConfigBase): raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e -class SigLIP_Diffusers_Config(DiffusersConfigBase, ModelConfigBase): +class SigLIP_Diffusers_Config(Diffusers_Config_Base, Config_Base): """Model config for SigLIP.""" type: Literal[ModelType.SigLIP] = Field(default=ModelType.SigLIP) @@ -2263,7 +2262,7 @@ class SigLIP_Diffusers_Config(DiffusersConfigBase, ModelConfigBase): return cls(**fields) -class FLUXRedux_Checkpoint_Config(ModelConfigBase): +class FLUXRedux_Checkpoint_Config(Config_Base): """Model config for FLUX Tools Redux model.""" type: Literal[ModelType.FluxRedux] = Field(default=ModelType.FluxRedux) @@ -2282,7 +2281,7 @@ class FLUXRedux_Checkpoint_Config(ModelConfigBase): return cls(**fields) -class LlavaOnevision_Diffusers_Config(DiffusersConfigBase, ModelConfigBase): +class LlavaOnevision_Diffusers_Config(Diffusers_Config_Base, Config_Base): """Model config for Llava Onevision models.""" type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision) @@ -2316,23 +2315,23 @@ class ExternalAPI_Config_Base(ABC, BaseModel): raise NotAMatch(cls, "External API models cannot be built from disk") -class ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase): +class ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): base: Literal[BaseModelType.ChatGPT4o] = Field(default=BaseModelType.ChatGPT4o) -class ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase): +class ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): base: Literal[BaseModelType.Gemini2_5] = Field(default=BaseModelType.Gemini2_5) -class ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase): +class ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): base: Literal[BaseModelType.Imagen3] = Field(default=BaseModelType.Imagen3) -class ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase): +class ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): base: Literal[BaseModelType.Imagen4] = Field(default=BaseModelType.Imagen4) -class ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase): +class ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) @@ -2344,11 +2343,11 @@ class VideoConfigBase(ABC, BaseModel): ) -class ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, VideoConfigBase, ModelConfigBase): +class ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, VideoConfigBase, Config_Base): base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) -class ExternalAPI_Runway_Config(ExternalAPI_Config_Base, VideoConfigBase, ModelConfigBase): +class ExternalAPI_Runway_Config(ExternalAPI_Config_Base, VideoConfigBase, Config_Base): base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) @@ -2447,11 +2446,10 @@ AnyModelConfig = Annotated[ # Unknown model (fallback) Annotated[Unknown_Config, Unknown_Config.get_tag()], ], - Discriminator(ModelConfigBase.get_model_discriminator_value), + Discriminator(Config_Base.get_model_discriminator_value), ] AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig) -AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings] class ModelConfigFactory: @@ -2459,7 +2457,7 @@ class ModelConfigFactory: def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig: """Return the appropriate config object from raw dict values.""" model = AnyModelConfigValidator.validate_python(model_data) - if isinstance(model, CheckpointConfigBase) and timestamp: + if isinstance(model, Checkpoint_Config_Base) and timestamp: model.converted_at = timestamp validate_hash(model.hash) return model @@ -2533,7 +2531,7 @@ class ModelConfigFactory: # Try to build an instance of each model config class that uses the classify API. # Each class will either return an instance of itself or raise NotAMatch if it doesn't match. # Other exceptions may be raised if something unexpected happens during matching or building. - for config_class in ModelConfigBase.CONFIG_CLASSES: + for config_class in Config_Base.CONFIG_CLASSES: class_name = config_class.__name__ try: instance = config_class.from_model_on_disk(mod, fields) @@ -2550,7 +2548,7 @@ class ModelConfigFactory: results[class_name] = e logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}") - matches = [r for r in results.values() if isinstance(r, ModelConfigBase)] + matches = [r for r in results.values() if isinstance(r, Config_Base)] if not matches and app_config.allow_unknown_models: logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config") diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 3c26a956b7..139a7d2940 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Optional from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException +from invokeai.backend.model_manager.config import AnyModelConfig, Diffusers_Config_Base, InvalidModelConfigException from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key @@ -90,7 +90,7 @@ class ModelLoader(ModelLoaderBase): return calc_model_size_by_fs( model_path=model_path, subfolder=submodel_type.value if submodel_type else None, - variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None, + variant=config.repo_variant if isinstance(config, Diffusers_Config_Base) else None, ) # This needs to be implemented in the subclass diff --git a/invokeai/backend/model_manager/load/model_loader_registry.py b/invokeai/backend/model_manager/load/model_loader_registry.py index ecc4d1fe93..9b242fe167 100644 --- a/invokeai/backend/model_manager/load/model_loader_registry.py +++ b/invokeai/backend/model_manager/load/model_loader_registry.py @@ -20,7 +20,7 @@ from typing import Callable, Dict, Optional, Tuple, Type, TypeVar from invokeai.backend.model_manager.config import ( AnyModelConfig, - ModelConfigBase, + Config_Base, ) from invokeai.backend.model_manager.load import ModelLoaderBase from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType @@ -40,7 +40,7 @@ class ModelLoaderRegistryBase(ABC): @abstractmethod def get_implementation( cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] - ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: + ) -> Tuple[Type[ModelLoaderBase], Config_Base, Optional[SubModelType]]: """ Get subclass of ModelLoaderBase registered to handle base and type. @@ -84,7 +84,7 @@ class ModelLoaderRegistry(ModelLoaderRegistryBase): @classmethod def get_implementation( cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] - ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: + ) -> Tuple[Type[ModelLoaderBase], Config_Base, Optional[SubModelType]]: """Get subclass of ModelLoaderBase registered to handle base and type.""" key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type diff --git a/invokeai/backend/model_manager/load/model_loaders/clip_vision.py b/invokeai/backend/model_manager/load/model_loaders/clip_vision.py index 29d7bc691c..9065e51fbf 100644 --- a/invokeai/backend/model_manager/load/model_loaders/clip_vision.py +++ b/invokeai/backend/model_manager/load/model_loaders/clip_vision.py @@ -5,7 +5,7 @@ from transformers import CLIPVisionModelWithProjection from invokeai.backend.model_manager.config import ( AnyModelConfig, - DiffusersConfigBase, + Diffusers_Config_Base, ) from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry @@ -21,7 +21,7 @@ class ClipVisionLoader(ModelLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, DiffusersConfigBase): + if not isinstance(config, Diffusers_Config_Base): raise ValueError("Only DiffusersConfigBase models are currently supported here.") if submodel_type is not None: diff --git a/invokeai/backend/model_manager/load/model_loaders/cogview4.py b/invokeai/backend/model_manager/load/model_loaders/cogview4.py index e7669a33c4..a1a9269edb 100644 --- a/invokeai/backend/model_manager/load/model_loaders/cogview4.py +++ b/invokeai/backend/model_manager/load/model_loaders/cogview4.py @@ -5,8 +5,8 @@ import torch from invokeai.backend.model_manager.config import ( AnyModelConfig, - CheckpointConfigBase, - DiffusersConfigBase, + Checkpoint_Config_Base, + Diffusers_Config_Base, ) from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader @@ -28,7 +28,7 @@ class CogView4DiffusersModel(GenericDiffusersLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if isinstance(config, CheckpointConfigBase): + if isinstance(config, Checkpoint_Config_Base): raise NotImplementedError("CheckpointConfigBase is not implemented for CogView4 models.") if submodel_type is None: @@ -36,7 +36,7 @@ class CogView4DiffusersModel(GenericDiffusersLoader): model_path = Path(config.path) load_class = self.get_hf_load_class(model_path, submodel_type) - repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None + repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None variant = repo_variant.value if repo_variant else None model_path = model_path / submodel_type.value diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index 9340cdd21a..07967c7c56 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -36,7 +36,7 @@ from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers_params from invokeai.backend.model_manager.config import ( AnyModelConfig, - CheckpointConfigBase, + Checkpoint_Config_Base, CLIPEmbed_Diffusers_Config_Base, ControlNet_Checkpoint_Config_Base, ControlNet_Diffusers_Config_Base, @@ -211,7 +211,7 @@ class FluxCheckpointModel(ModelLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, CheckpointConfigBase): + if not isinstance(config, Checkpoint_Config_Base): raise ValueError("Only CheckpointConfigBase models are currently supported here.") match submodel_type: @@ -253,7 +253,7 @@ class FluxGGUFCheckpointModel(ModelLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, CheckpointConfigBase): + if not isinstance(config, Checkpoint_Config_Base): raise ValueError("Only CheckpointConfigBase models are currently supported here.") match submodel_type: @@ -299,7 +299,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, CheckpointConfigBase): + if not isinstance(config, Checkpoint_Config_Base): raise ValueError("Only CheckpointConfigBase models are currently supported here.") match submodel_type: diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index 8a690583d5..407a116b68 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -8,7 +8,7 @@ from typing import Any, Optional from diffusers.configuration_utils import ConfigMixin from diffusers.models.modeling_utils import ModelMixin -from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException +from invokeai.backend.model_manager.config import AnyModelConfig, Diffusers_Config_Base, InvalidModelConfigException from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import ( @@ -33,7 +33,7 @@ class GenericDiffusersLoader(ModelLoader): model_class = self.get_hf_load_class(model_path) if submodel_type is not None: raise Exception(f"There are no submodels in models of type {model_class}") - repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None + repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None variant = repo_variant.value if repo_variant else None try: result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index c8c751134c..647ad4dbf4 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -13,8 +13,8 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpain from invokeai.backend.model_manager.config import ( AnyModelConfig, - CheckpointConfigBase, - DiffusersConfigBase, + Checkpoint_Config_Base, + Diffusers_Config_Base, Main_Checkpoint_SD1_Config, Main_Checkpoint_SD2_Config, Main_Checkpoint_SDXL_Config, @@ -65,7 +65,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if isinstance(config, CheckpointConfigBase): + if isinstance(config, Checkpoint_Config_Base): return self._load_from_singlefile(config, submodel_type) if submodel_type is None: @@ -73,7 +73,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader): model_path = Path(config.path) load_class = self.get_hf_load_class(model_path, submodel_type) - repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None + repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None variant = repo_variant.value if repo_variant else None model_path = model_path / submodel_type.value try: diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index e24a3ac8bd..03a7428382 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -15,7 +15,7 @@ from invokeai.backend.model_manager.config import ( AnyModelConfig, InvalidModelConfigException, MainDiffusersConfig, - ModelConfigBase, + Config_Base, ModelConfigFactory, get_model_discriminator_value, ) @@ -109,7 +109,7 @@ def test_probe_sd1_diffusers_inpainting(datadir: Path): assert config.repo_variant is ModelRepoVariant.FP16 -class MinimalConfigExample(ModelConfigBase): +class MinimalConfigExample(Config_Base): type: ModelType = ModelType.Main format: ModelFormat = ModelFormat.Checkpoint fun_quote: str @@ -175,10 +175,10 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading): assert legacy_config.model_dump_json() == new_config.model_dump_json() elif legacy_config: - assert type(legacy_config) in ModelConfigBase.USING_LEGACY_PROBE + assert type(legacy_config) in Config_Base.USING_LEGACY_PROBE elif new_config: - assert type(new_config) in ModelConfigBase.USING_CLASSIFY_API + assert type(new_config) in Config_Base.USING_CLASSIFY_API else: raise ValueError(f"Both probe and classify failed to classify model at path {path}.") @@ -186,7 +186,7 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading): config_type = type(legacy_config or new_config) configs_with_tests.add(config_type) - untested_configs = ModelConfigBase.all_config_classes() - configs_with_tests - {MinimalConfigExample} + untested_configs = Config_Base.all_config_classes() - configs_with_tests - {MinimalConfigExample} logger.warning(f"Function test_regression_against_model_probe missing test case for: {untested_configs}") @@ -206,7 +206,7 @@ def test_serialisation_roundtrip(): We need to ensure they are de-serialised into the original config with all relevant fields restored. """ excluded = {MinimalConfigExample} - for config_cls in ModelConfigBase.all_config_classes() - excluded: + for config_cls in Config_Base.all_config_classes() - excluded: trials_per_class = 50 configs_with_random_data = create_fake_configs(config_cls, trials_per_class) @@ -221,7 +221,7 @@ def test_serialisation_roundtrip(): def test_discriminator_tagging_for_config_instances(): """Verify that each ModelConfig instance is assigned the correct, unique Pydantic discriminator tag.""" excluded = {MinimalConfigExample} - config_classes = ModelConfigBase.all_config_classes() - excluded + config_classes = Config_Base.all_config_classes() - excluded tags = {c.get_tag() for c in config_classes} assert len(tags) == len(config_classes), "Each config should have its own unique tag" @@ -246,10 +246,10 @@ def test_inheritance_order(): It may be worth rethinking our config taxonomy in the future, but in the meantime this test can help prevent debugging effort. """ - for config_cls in ModelConfigBase.all_config_classes(): + for config_cls in Config_Base.all_config_classes(): excluded = {abc.ABC, pydantic.BaseModel, object} inheritance_list = [cls for cls in config_cls.mro() if cls not in excluded] - assert inheritance_list[-1] is ModelConfigBase + assert inheritance_list[-1] is Config_Base def test_any_model_config_includes_all_config_classes(): @@ -262,7 +262,7 @@ def test_any_model_config_includes_all_config_classes(): config_class, _ = get_args(annotated_pair) extracted.add(config_class) - expected = set(ModelConfigBase.all_config_classes()) - {MinimalConfigExample} + expected = set(Config_Base.all_config_classes()) - {MinimalConfigExample} assert extracted == expected @@ -270,7 +270,7 @@ def test_config_uniquely_matches_model(datadir: Path): model_paths = ModelSearch().search(datadir / "stripped_models") for path in model_paths: mod = StrippedModelOnDisk(path) - matches = {cls for cls in ModelConfigBase.USING_CLASSIFY_API if cls.matches(mod)} + matches = {cls for cls in Config_Base.USING_CLASSIFY_API if cls.matches(mod)} assert len(matches) <= 1, f"Model at path {path} matches multiple config classes: {matches}" if not matches: logger.warning(f"Model at path {path} does not match any config classes using classify API.")