tidy(mm): consistent class names

This commit is contained in:
psychedelicious
2025-10-02 13:38:12 +10:00
parent e48e354bf1
commit edfd90f2a4
14 changed files with 137 additions and 139 deletions

View File

@@ -21,7 +21,7 @@ from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase
from invokeai.backend.model_manager.config import Main_Config_Base
from invokeai.backend.model_manager.taxonomy import ModelVariantType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@@ -182,7 +182,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
assert isinstance(main_model_config, MainConfigBase)
assert isinstance(main_model_config, Main_Config_Base)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = dilated_mask_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)

View File

@@ -15,7 +15,7 @@ from invokeai.app.util.t5_model_identifier import (
)
from invokeai.backend.flux.util import get_flux_max_seq_length
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
Checkpoint_Config_Base,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
@@ -87,7 +87,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
assert isinstance(transformer_config, Checkpoint_Config_Base)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),

View File

@@ -37,7 +37,7 @@ from invokeai.app.services.model_records import DuplicateModelException, ModelRe
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
Checkpoint_Config_Base,
InvalidModelConfigException,
ModelConfigFactory,
)
@@ -625,7 +625,7 @@ class ModelInstallService(ModelInstallServiceBase):
info.path = model_path.as_posix()
if isinstance(info, CheckpointConfigBase) and info.config_path is not None:
if isinstance(info, Checkpoint_Config_Base) and info.config_path is not None:
# Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the
# invoke-managed legacy config dir, we use a relative path.
legacy_config_path = self.app_config.legacy_conf_path / info.config_path

View File

@@ -21,7 +21,7 @@ from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.step_callback import diffusion_step_callback
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ModelConfigBase,
Config_Base,
)
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
@@ -558,7 +558,7 @@ class ModelsInterface(InvocationContextInterface):
The absolute path to the model.
"""
model_path = Path(config_or_path.path) if isinstance(config_or_path, ModelConfigBase) else Path(config_or_path)
model_path = Path(config_or_path.path) if isinstance(config_or_path, Config_Base) else Path(config_or_path)
if model_path.is_absolute():
return model_path.resolve()

View File

@@ -3,7 +3,7 @@
from invokeai.backend.model_manager.config import (
AnyModelConfig,
InvalidModelConfigException,
ModelConfigBase,
Config_Base,
ModelConfigFactory,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
@@ -30,7 +30,7 @@ __all__ = [
"ModelConfigFactory",
"ModelProbe",
"ModelSearch",
"ModelConfigBase",
"Config_Base",
"AnyModel",
"AnyVariant",
"BaseModelType",

View File

@@ -35,7 +35,6 @@ from typing import (
Optional,
Self,
Type,
TypeAlias,
Union,
)
@@ -321,7 +320,7 @@ class LegacyProbeMixin:
pass
class ModelConfigBase(ABC, BaseModel):
class Config_Base(ABC, BaseModel):
"""
Abstract base class for model configurations. A model config describes a specific combination of model base, type and
format, along with other metadata about the model. For example, a Stable Diffusion 1.x main model in checkpoint format
@@ -427,7 +426,7 @@ class ModelConfigBase(ABC, BaseModel):
Computes the discriminator value for a model config.
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
"""
if isinstance(v, ModelConfigBase):
if isinstance(v, Config_Base):
# We have an instance of a ModelConfigBase subclass - use its tag directly.
return v.get_tag().tag
if isinstance(v, dict):
@@ -473,7 +472,7 @@ class ModelConfigBase(ABC, BaseModel):
raise NotImplementedError(f"from_model_on_disk not implemented for {cls.__name__}")
class Unknown_Config(ModelConfigBase):
class Unknown_Config(Config_Base):
"""Model config for unknown models, used as a fallback when we cannot identify a model."""
base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown)
@@ -485,7 +484,7 @@ class Unknown_Config(ModelConfigBase):
raise NotAMatch(cls, "unknown model config cannot match any model")
class CheckpointConfigBase(ABC, BaseModel):
class Checkpoint_Config_Base(ABC, BaseModel):
"""Base class for checkpoint-style models."""
config_path: str | None = Field(
@@ -498,7 +497,7 @@ class CheckpointConfigBase(ABC, BaseModel):
)
class DiffusersConfigBase(ABC, BaseModel):
class Diffusers_Config_Base(ABC, BaseModel):
"""Base class for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
@@ -521,7 +520,7 @@ class DiffusersConfigBase(ABC, BaseModel):
return ModelRepoVariant.Default
class T5Encoder_T5Encoder_Config(ModelConfigBase):
class T5Encoder_T5Encoder_Config(Config_Base):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder)
@@ -552,7 +551,7 @@ class T5Encoder_T5Encoder_Config(ModelConfigBase):
raise NotAMatch(cls, "missing text_encoder_2/model.safetensors.index.json")
class T5Encoder_BnBLLMint8_Config(ModelConfigBase):
class T5Encoder_BnBLLMint8_Config(Config_Base):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b)
@@ -590,7 +589,7 @@ class T5Encoder_BnBLLMint8_Config(ModelConfigBase):
raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8")
class LoRAConfigBase(ABC, BaseModel):
class LoRA_Config_Base(ABC, BaseModel):
"""Base class for LoRA models."""
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
@@ -609,7 +608,7 @@ def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
return value
class LoRA_OMI_Config_Base(LoRAConfigBase):
class LoRA_OMI_Config_Base(LoRA_Config_Base):
format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI)
@classmethod
@@ -663,15 +662,15 @@ class LoRA_OMI_Config_Base(LoRAConfigBase):
raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}")
class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, ModelConfigBase):
class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, ModelConfigBase):
class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class LoRA_LyCORIS_Config_Base(LoRAConfigBase):
class LoRA_LyCORIS_Config_Base(LoRA_Config_Base):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
@@ -750,27 +749,27 @@ class LoRA_LyCORIS_Config_Base(LoRAConfigBase):
raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}")
class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase):
class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase):
class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase):
class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase):
class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class ControlAdapterConfigBase(ABC, BaseModel):
class ControlAdapter_Config_Base(ABC, BaseModel):
default_settings: ControlAdapterDefaultSettings | None = Field(None)
class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapterConfigBase, ModelConfigBase):
class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapter_Config_Base, Config_Base):
"""Model config for Control LoRA models."""
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
@@ -797,7 +796,7 @@ class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapterConfigBase, ModelConfigBase)
raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA")
class LoRA_Diffusers_Config_Base(LoRAConfigBase):
class LoRA_Diffusers_Config_Base(LoRA_Config_Base):
"""Model config for LoRA/Diffusers models."""
# TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
@@ -855,23 +854,23 @@ class LoRA_Diffusers_Config_Base(LoRAConfigBase):
raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class VAE_Checkpoint_Config_Base(CheckpointConfigBase):
class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
"""Model config for standalone VAE models."""
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
@@ -925,23 +924,23 @@ class VAE_Checkpoint_Config_Base(CheckpointConfigBase):
raise NotAMatch(cls, "cannot determine base type")
class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class VAE_Diffusers_Config_Base(DiffusersConfigBase):
class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
@@ -1005,15 +1004,15 @@ class VAE_Diffusers_Config_Base(DiffusersConfigBase):
return BaseModelType.StableDiffusion1
class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, ModelConfigBase):
class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, ModelConfigBase):
class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class ControlNet_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfigBase):
class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base, ControlAdapter_Config_Base):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
@@ -1068,23 +1067,23 @@ class ControlNet_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig
raise NotAMatch(cls, f"unrecognized cross_attention_dim {dimension}")
class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class ControlNet_Checkpoint_Config_Base(CheckpointConfigBase, ControlAdapterConfigBase):
class ControlNet_Checkpoint_Config_Base(Checkpoint_Config_Base, ControlAdapter_Config_Base):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
@@ -1161,19 +1160,19 @@ class ControlNet_Checkpoint_Config_Base(CheckpointConfigBase, ControlAdapterConf
raise NotAMatch(cls, "unable to determine base type from state dict")
class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
@@ -1267,15 +1266,15 @@ class TI_File_Config_Base(TI_Config_Base):
return cls(**fields)
class TI_File_SD1_Config(TI_File_Config_Base, ModelConfigBase):
class TI_File_SD1_Config(TI_File_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class TI_File_SD2_Config(TI_File_Config_Base, ModelConfigBase):
class TI_File_SD2_Config(TI_File_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class TI_File_SDXL_Config(TI_File_Config_Base, ModelConfigBase):
class TI_File_SDXL_Config(TI_File_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
@@ -1298,19 +1297,19 @@ class TI_Folder_Config_Base(TI_Config_Base):
raise NotAMatch(cls, "model does not look like a textual inversion embedding folder")
class TI_Folder_SD1_Config(TI_Folder_Config_Base, ModelConfigBase):
class TI_Folder_SD1_Config(TI_Folder_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class TI_Folder_SD2_Config(TI_Folder_Config_Base, ModelConfigBase):
class TI_Folder_SD2_Config(TI_Folder_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class TI_Folder_SDXL_Config(TI_Folder_Config_Base, ModelConfigBase):
class TI_Folder_SDXL_Config(TI_Folder_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class MainConfigBase(ABC, BaseModel):
class Main_Config_Base(ABC, BaseModel):
type: Literal[ModelType.Main] = Field(default=ModelType.Main)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings] = Field(
@@ -1352,7 +1351,7 @@ def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
return False
class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase):
class Main_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
@@ -1450,19 +1449,19 @@ class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase):
raise NotAMatch(cls, "state dict does not look like a main model")
class Main_Checkpoint_SD1_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
class Main_Checkpoint_SD1_Config(Main_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class Main_Checkpoint_SD2_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
class Main_Checkpoint_SD2_Config(Main_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class Main_Checkpoint_SDXL_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
class Main_Checkpoint_SDXL_Config(Main_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class Main_Checkpoint_SDXLRefiner_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
class Main_Checkpoint_SDXLRefiner_Config(Main_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner)
@@ -1511,7 +1510,7 @@ def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | Non
return FluxVariantType.Schnell
class Main_Checkpoint_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
@@ -1581,7 +1580,7 @@ class Main_Checkpoint_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelCon
raise NotAMatch(cls, "state dict looks like GGUF quantized")
class Main_BnBNF4_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for main checkpoint models."""
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
@@ -1630,7 +1629,7 @@ class Main_BnBNF4_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigB
raise NotAMatch(cls, "state dict does not look like bnb quantized nf4")
class Main_GGUF_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for main checkpoint models."""
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
@@ -1679,7 +1678,7 @@ class Main_GGUF_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBas
raise NotAMatch(cls, "state dict does not look like GGUF quantized")
class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase):
class Main_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base):
prediction_type: SchedulerPredictionType = Field()
variant: ModelVariantType = Field()
@@ -1785,23 +1784,23 @@ class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase):
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
class Main_Diffusers_SD1_Config(Main_Diffusers_Config_Base, ModelConfigBase):
class Main_Diffusers_SD1_Config(Main_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(BaseModelType.StableDiffusion1)
class Main_Diffusers_SD2_Config(Main_Diffusers_Config_Base, ModelConfigBase):
class Main_Diffusers_SD2_Config(Main_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(BaseModelType.StableDiffusion2)
class Main_Diffusers_SDXL_Config(Main_Diffusers_Config_Base, ModelConfigBase):
class Main_Diffusers_SDXL_Config(Main_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(BaseModelType.StableDiffusionXL)
class Main_Diffusers_SDXLRefiner_Config(Main_Diffusers_Config_Base, ModelConfigBase):
class Main_Diffusers_SDXLRefiner_Config(Main_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(BaseModelType.StableDiffusionXLRefiner)
class Main_Diffusers_SD3_Config(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
class Main_Diffusers_SD3_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3)
@classmethod
@@ -1875,7 +1874,7 @@ class Main_Diffusers_SD3_Config(DiffusersConfigBase, MainConfigBase, ModelConfig
return submodels
class Main_Diffusers_CogView4_Config(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4)
@classmethod
@@ -1901,11 +1900,11 @@ class Main_Diffusers_CogView4_Config(DiffusersConfigBase, MainConfigBase, ModelC
)
class IPAdapterConfigBase(ABC, BaseModel):
class IPAdapter_Config_Base(ABC, BaseModel):
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
class IPAdapter_InvokeAI_Config_Base(IPAdapterConfigBase):
class IPAdapter_InvokeAI_Config_Base(IPAdapter_Config_Base):
"""Model config for IP Adapter diffusers format models."""
format: Literal[ModelFormat.InvokeAI] = Field(default=ModelFormat.InvokeAI)
@@ -1968,19 +1967,19 @@ class IPAdapter_InvokeAI_Config_Base(IPAdapterConfigBase):
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase):
class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase):
class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase):
class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class IPAdapter_Checkpoint_Config_Base(IPAdapterConfigBase):
class IPAdapter_Checkpoint_Config_Base(IPAdapter_Config_Base):
"""Model config for IP Adapter checkpoint format models."""
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
@@ -2041,19 +2040,19 @@ class IPAdapter_Checkpoint_Config_Base(IPAdapterConfigBase):
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
@@ -2071,7 +2070,7 @@ def _get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantTyp
return None
class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase):
class CLIPEmbed_Diffusers_Config_Base(Diffusers_Config_Base):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
@@ -2119,15 +2118,15 @@ class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase):
raise NotAMatch(cls, f"variant is {recognized_variant}, not {expected_variant}")
class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, ModelConfigBase):
class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base):
variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G)
class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, ModelConfigBase):
class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base):
variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L)
class CLIPVision_Diffusers_Config(DiffusersConfigBase, ModelConfigBase):
class CLIPVision_Diffusers_Config(Diffusers_Config_Base, Config_Base):
"""Model config for CLIPVision."""
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
@@ -2151,7 +2150,7 @@ class CLIPVision_Diffusers_Config(DiffusersConfigBase, ModelConfigBase):
return cls(**fields)
class T2IAdapter_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfigBase):
class T2IAdapter_Diffusers_Config_Base(Diffusers_Config_Base, ControlAdapter_Config_Base):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
@@ -2198,15 +2197,15 @@ class T2IAdapter_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig
raise NotAMatch(cls, f"unrecognized adapter_type '{adapter_type}'")
class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, ModelConfigBase):
class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, ModelConfigBase):
class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class Spandrel_Checkpoint_Config(ModelConfigBase):
class Spandrel_Checkpoint_Config(Config_Base):
"""Model config for Spandrel Image to Image models."""
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
@@ -2239,7 +2238,7 @@ class Spandrel_Checkpoint_Config(ModelConfigBase):
raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e
class SigLIP_Diffusers_Config(DiffusersConfigBase, ModelConfigBase):
class SigLIP_Diffusers_Config(Diffusers_Config_Base, Config_Base):
"""Model config for SigLIP."""
type: Literal[ModelType.SigLIP] = Field(default=ModelType.SigLIP)
@@ -2263,7 +2262,7 @@ class SigLIP_Diffusers_Config(DiffusersConfigBase, ModelConfigBase):
return cls(**fields)
class FLUXRedux_Checkpoint_Config(ModelConfigBase):
class FLUXRedux_Checkpoint_Config(Config_Base):
"""Model config for FLUX Tools Redux model."""
type: Literal[ModelType.FluxRedux] = Field(default=ModelType.FluxRedux)
@@ -2282,7 +2281,7 @@ class FLUXRedux_Checkpoint_Config(ModelConfigBase):
return cls(**fields)
class LlavaOnevision_Diffusers_Config(DiffusersConfigBase, ModelConfigBase):
class LlavaOnevision_Diffusers_Config(Diffusers_Config_Base, Config_Base):
"""Model config for Llava Onevision models."""
type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision)
@@ -2316,23 +2315,23 @@ class ExternalAPI_Config_Base(ABC, BaseModel):
raise NotAMatch(cls, "External API models cannot be built from disk")
class ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase):
class ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.ChatGPT4o] = Field(default=BaseModelType.ChatGPT4o)
class ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase):
class ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.Gemini2_5] = Field(default=BaseModelType.Gemini2_5)
class ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase):
class ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.Imagen3] = Field(default=BaseModelType.Imagen3)
class ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase):
class ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.Imagen4] = Field(default=BaseModelType.Imagen4)
class ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase):
class ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
@@ -2344,11 +2343,11 @@ class VideoConfigBase(ABC, BaseModel):
)
class ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, VideoConfigBase, ModelConfigBase):
class ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, VideoConfigBase, Config_Base):
base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
class ExternalAPI_Runway_Config(ExternalAPI_Config_Base, VideoConfigBase, ModelConfigBase):
class ExternalAPI_Runway_Config(ExternalAPI_Config_Base, VideoConfigBase, Config_Base):
base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
@@ -2447,11 +2446,10 @@ AnyModelConfig = Annotated[
# Unknown model (fallback)
Annotated[Unknown_Config, Unknown_Config.get_tag()],
],
Discriminator(ModelConfigBase.get_model_discriminator_value),
Discriminator(Config_Base.get_model_discriminator_value),
]
AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings]
class ModelConfigFactory:
@@ -2459,7 +2457,7 @@ class ModelConfigFactory:
def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig:
"""Return the appropriate config object from raw dict values."""
model = AnyModelConfigValidator.validate_python(model_data)
if isinstance(model, CheckpointConfigBase) and timestamp:
if isinstance(model, Checkpoint_Config_Base) and timestamp:
model.converted_at = timestamp
validate_hash(model.hash)
return model
@@ -2533,7 +2531,7 @@ class ModelConfigFactory:
# Try to build an instance of each model config class that uses the classify API.
# Each class will either return an instance of itself or raise NotAMatch if it doesn't match.
# Other exceptions may be raised if something unexpected happens during matching or building.
for config_class in ModelConfigBase.CONFIG_CLASSES:
for config_class in Config_Base.CONFIG_CLASSES:
class_name = config_class.__name__
try:
instance = config_class.from_model_on_disk(mod, fields)
@@ -2550,7 +2548,7 @@ class ModelConfigFactory:
results[class_name] = e
logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
matches = [r for r in results.values() if isinstance(r, ModelConfigBase)]
matches = [r for r in results.values() if isinstance(r, Config_Base)]
if not matches and app_config.allow_unknown_models:
logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config")

View File

@@ -6,7 +6,7 @@ from pathlib import Path
from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException
from invokeai.backend.model_manager.config import AnyModelConfig, Diffusers_Config_Base, InvalidModelConfigException
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
@@ -90,7 +90,7 @@ class ModelLoader(ModelLoaderBase):
return calc_model_size_by_fs(
model_path=model_path,
subfolder=submodel_type.value if submodel_type else None,
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
variant=config.repo_variant if isinstance(config, Diffusers_Config_Base) else None,
)
# This needs to be implemented in the subclass

View File

@@ -20,7 +20,7 @@ from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ModelConfigBase,
Config_Base,
)
from invokeai.backend.model_manager.load import ModelLoaderBase
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType
@@ -40,7 +40,7 @@ class ModelLoaderRegistryBase(ABC):
@abstractmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
) -> Tuple[Type[ModelLoaderBase], Config_Base, Optional[SubModelType]]:
"""
Get subclass of ModelLoaderBase registered to handle base and type.
@@ -84,7 +84,7 @@ class ModelLoaderRegistry(ModelLoaderRegistryBase):
@classmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
) -> Tuple[Type[ModelLoaderBase], Config_Base, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type

View File

@@ -5,7 +5,7 @@ from transformers import CLIPVisionModelWithProjection
from invokeai.backend.model_manager.config import (
AnyModelConfig,
DiffusersConfigBase,
Diffusers_Config_Base,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@@ -21,7 +21,7 @@ class ClipVisionLoader(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, DiffusersConfigBase):
if not isinstance(config, Diffusers_Config_Base):
raise ValueError("Only DiffusersConfigBase models are currently supported here.")
if submodel_type is not None:

View File

@@ -5,8 +5,8 @@ import torch
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
DiffusersConfigBase,
Checkpoint_Config_Base,
Diffusers_Config_Base,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
@@ -28,7 +28,7 @@ class CogView4DiffusersModel(GenericDiffusersLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, CheckpointConfigBase):
if isinstance(config, Checkpoint_Config_Base):
raise NotImplementedError("CheckpointConfigBase is not implemented for CogView4 models.")
if submodel_type is None:
@@ -36,7 +36,7 @@ class CogView4DiffusersModel(GenericDiffusersLoader):
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
variant = repo_variant.value if repo_variant else None
model_path = model_path / submodel_type.value

View File

@@ -36,7 +36,7 @@ from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers_params
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
Checkpoint_Config_Base,
CLIPEmbed_Diffusers_Config_Base,
ControlNet_Checkpoint_Config_Base,
ControlNet_Diffusers_Config_Base,
@@ -211,7 +211,7 @@ class FluxCheckpointModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
if not isinstance(config, Checkpoint_Config_Base):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
@@ -253,7 +253,7 @@ class FluxGGUFCheckpointModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
if not isinstance(config, Checkpoint_Config_Base):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
@@ -299,7 +299,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
if not isinstance(config, Checkpoint_Config_Base):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:

View File

@@ -8,7 +8,7 @@ from typing import Any, Optional
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException
from invokeai.backend.model_manager.config import AnyModelConfig, Diffusers_Config_Base, InvalidModelConfigException
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import (
@@ -33,7 +33,7 @@ class GenericDiffusersLoader(ModelLoader):
model_class = self.get_hf_load_class(model_path)
if submodel_type is not None:
raise Exception(f"There are no submodels in models of type {model_class}")
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
variant = repo_variant.value if repo_variant else None
try:
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)

View File

@@ -13,8 +13,8 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpain
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
DiffusersConfigBase,
Checkpoint_Config_Base,
Diffusers_Config_Base,
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
@@ -65,7 +65,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, CheckpointConfigBase):
if isinstance(config, Checkpoint_Config_Base):
return self._load_from_singlefile(config, submodel_type)
if submodel_type is None:
@@ -73,7 +73,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
variant = repo_variant.value if repo_variant else None
model_path = model_path / submodel_type.value
try:

View File

@@ -15,7 +15,7 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
InvalidModelConfigException,
MainDiffusersConfig,
ModelConfigBase,
Config_Base,
ModelConfigFactory,
get_model_discriminator_value,
)
@@ -109,7 +109,7 @@ def test_probe_sd1_diffusers_inpainting(datadir: Path):
assert config.repo_variant is ModelRepoVariant.FP16
class MinimalConfigExample(ModelConfigBase):
class MinimalConfigExample(Config_Base):
type: ModelType = ModelType.Main
format: ModelFormat = ModelFormat.Checkpoint
fun_quote: str
@@ -175,10 +175,10 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
assert legacy_config.model_dump_json() == new_config.model_dump_json()
elif legacy_config:
assert type(legacy_config) in ModelConfigBase.USING_LEGACY_PROBE
assert type(legacy_config) in Config_Base.USING_LEGACY_PROBE
elif new_config:
assert type(new_config) in ModelConfigBase.USING_CLASSIFY_API
assert type(new_config) in Config_Base.USING_CLASSIFY_API
else:
raise ValueError(f"Both probe and classify failed to classify model at path {path}.")
@@ -186,7 +186,7 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
config_type = type(legacy_config or new_config)
configs_with_tests.add(config_type)
untested_configs = ModelConfigBase.all_config_classes() - configs_with_tests - {MinimalConfigExample}
untested_configs = Config_Base.all_config_classes() - configs_with_tests - {MinimalConfigExample}
logger.warning(f"Function test_regression_against_model_probe missing test case for: {untested_configs}")
@@ -206,7 +206,7 @@ def test_serialisation_roundtrip():
We need to ensure they are de-serialised into the original config with all relevant fields restored.
"""
excluded = {MinimalConfigExample}
for config_cls in ModelConfigBase.all_config_classes() - excluded:
for config_cls in Config_Base.all_config_classes() - excluded:
trials_per_class = 50
configs_with_random_data = create_fake_configs(config_cls, trials_per_class)
@@ -221,7 +221,7 @@ def test_serialisation_roundtrip():
def test_discriminator_tagging_for_config_instances():
"""Verify that each ModelConfig instance is assigned the correct, unique Pydantic discriminator tag."""
excluded = {MinimalConfigExample}
config_classes = ModelConfigBase.all_config_classes() - excluded
config_classes = Config_Base.all_config_classes() - excluded
tags = {c.get_tag() for c in config_classes}
assert len(tags) == len(config_classes), "Each config should have its own unique tag"
@@ -246,10 +246,10 @@ def test_inheritance_order():
It may be worth rethinking our config taxonomy in the future, but in the meantime
this test can help prevent debugging effort.
"""
for config_cls in ModelConfigBase.all_config_classes():
for config_cls in Config_Base.all_config_classes():
excluded = {abc.ABC, pydantic.BaseModel, object}
inheritance_list = [cls for cls in config_cls.mro() if cls not in excluded]
assert inheritance_list[-1] is ModelConfigBase
assert inheritance_list[-1] is Config_Base
def test_any_model_config_includes_all_config_classes():
@@ -262,7 +262,7 @@ def test_any_model_config_includes_all_config_classes():
config_class, _ = get_args(annotated_pair)
extracted.add(config_class)
expected = set(ModelConfigBase.all_config_classes()) - {MinimalConfigExample}
expected = set(Config_Base.all_config_classes()) - {MinimalConfigExample}
assert extracted == expected
@@ -270,7 +270,7 @@ def test_config_uniquely_matches_model(datadir: Path):
model_paths = ModelSearch().search(datadir / "stripped_models")
for path in model_paths:
mod = StrippedModelOnDisk(path)
matches = {cls for cls in ModelConfigBase.USING_CLASSIFY_API if cls.matches(mod)}
matches = {cls for cls in Config_Base.USING_CLASSIFY_API if cls.matches(mod)}
assert len(matches) <= 1, f"Model at path {path} matches multiple config classes: {matches}"
if not matches:
logger.warning(f"Model at path {path} does not match any config classes using classify API.")