diff --git a/invokeai/app/invocations/flux_ip_adapter.py b/invokeai/app/invocations/flux_ip_adapter.py index db5754ee2b..cfd166815d 100644 --- a/invokeai/app/invocations/flux_ip_adapter.py +++ b/invokeai/app/invocations/flux_ip_adapter.py @@ -17,8 +17,8 @@ from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager.config import ( + IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig, - IPAdapterInvokeAIConfig, ) from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType @@ -68,7 +68,7 @@ class FluxIPAdapterInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) - assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig)) + assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig)) # Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy. image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model] diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 35a98ff6ba..5b99f72369 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -13,8 +13,8 @@ from invokeai.app.services.model_records.model_records_base import ModelRecordCh from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager.config import ( AnyModelConfig, + IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig, - IPAdapterInvokeAIConfig, ) from invokeai.backend.model_manager.starter_models import ( StarterModel, @@ -123,9 +123,9 @@ class IPAdapterInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) - assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig)) + assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig)) - if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig): + if isinstance(ip_adapter_info, IPAdapter_InvokeAI_Config_Base): image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() else: diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index c36aca067f..0b6e5fd83c 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -41,7 +41,7 @@ from typing import ( import torch from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter, ValidationError -from pydantic_core import CoreSchema, SchemaValidator +from pydantic_core import CoreSchema, PydanticUndefined, SchemaValidator from typing_extensions import Annotated, Any, Dict from invokeai.app.services.config.config_default import get_config @@ -323,14 +323,17 @@ class LegacyProbeMixin: class ModelConfigBase(ABC, BaseModel): """ - Abstract Base class for model configurations. + Abstract base class for model configurations. A model config describes a specific combination of model base, type and + format, along with other metadata about the model. For example, a Stable Diffusion 1.x main model in checkpoint format + would have base=sd-1, type=main, format=checkpoint. To create a new config type, inherit from this class and implement its interface: - - (mandatory) override methods 'matches' and 'parse' - - (mandatory) define fields 'type' and 'format' as class attributes + - Define method 'from_model_on_disk' that returns an instance of the class or raises NotAMatch. This method will be + called during model installation to determine the correct config class for a model. + - Define fields 'type', 'base' and 'format' as pydantic fields. These should be Literals with a single value. A + default must be provided for each of these fields. - - (optional) override method 'get_tag' - - (optional) override field _MATCH_SPEED + If multiple combinations of base, type and format need to be supported, create a separate subclass for each. See MinimalConfigExample in test_model_probe.py for an example implementation. """ @@ -395,17 +398,23 @@ class ModelConfigBase(ABC, BaseModel): @classmethod def __pydantic_init_subclass__(cls, **kwargs): - # Ensure that subclasses define 'base', 'type' and 'format' fields. These are not in this base class, because - # subclasses may redefine them as different types, causing type-checking issues. - assert "base" in cls.model_fields, f"{cls.__name__} must define a 'base' field" - assert "type" in cls.model_fields, f"{cls.__name__} must define a 'type' field" - assert "format" in cls.model_fields, f"{cls.__name__} must define a 'format' field" + # Ensure that subclasses define 'base', 'type' and 'format' fields and provide defaults for them. Each subclass + # is expected to represent a single combination of base, type and format. + for name in ("type", "base", "format"): + assert name in cls.model_fields, f"{cls.__name__} must define a '{name}' field" + assert cls.model_fields[name].default is not PydanticUndefined, ( + f"{cls.__name__} must define a default for the '{name}' field" + ) @classmethod def get_tag(cls) -> Tag: - type = cls.model_fields["type"].default.value - format = cls.model_fields["format"].default.value - return Tag(f"{type}.{format}") + tag_strings: list[str] = [] + for name in ("type", "base", "format", "variant"): + if field := cls.model_fields.get(name): + if field.default is not PydanticUndefined: + # We assume each of these fields has an Enum for its default + tag_strings.append(str(field.default.value)) + return Tag(".".join(tag_strings)) @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: @@ -415,7 +424,9 @@ class ModelConfigBase(ABC, BaseModel): raise NotImplementedError(f"from_model_on_disk not implemented for {cls.__name__}") -class UnknownModelConfig(ModelConfigBase): +class Unknown_Config(ModelConfigBase): + """Model config for unknown models, used as a fallback when we cannot identify a model.""" + base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown) type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown) format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown) @@ -461,7 +472,7 @@ class DiffusersConfigBase(ABC, BaseModel): return ModelRepoVariant.Default -class T5EncoderConfig(ModelConfigBase): +class T5Encoder_T5Encoder_Config(ModelConfigBase): base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder) format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder) @@ -492,7 +503,7 @@ class T5EncoderConfig(ModelConfigBase): raise NotAMatch(cls, "missing text_encoder_2/model.safetensors.index.json") -class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase): +class T5Encoder_BnBLLMint8_Config(ModelConfigBase): base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder) format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b) @@ -549,8 +560,7 @@ def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None: return value -class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): - base: Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL] = Field() +class LoRA_OMI_Config_Base(LoRAConfigBase): format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI) @classmethod @@ -559,24 +569,27 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): _validate_override_fields(cls, fields) - cls._validate_is_not_controllora_or_diffusers(mod) + cls._validate_looks_like_omi_lora(mod) - cls._validate_metadata_looks_like_omi(mod) + cls._validate_base(mod) - base = fields.get("base") or cls._get_base_or_raise(mod) - - return cls(**fields, base=base) + return cls(**fields) @classmethod - def _validate_is_not_controllora_or_diffusers(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model is a ControlLoRA or Diffusers LoRA.""" + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default.value + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_looks_like_omi_lora(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model metadata does not look like an OMI LoRA.""" flux_format = _get_flux_lora_format(mod) if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA") - @classmethod - def _validate_metadata_looks_like_omi(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model metadata does not look like an OMI LoRA.""" metadata = mod.metadata() metadata_looks_like_omi_lora = ( @@ -601,18 +614,17 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}") -LoRALyCORIS_SupportedBases: TypeAlias = Literal[ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - BaseModelType.Flux, -] +class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): +class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + +class LoRA_LyCORIS_Config_Base(LoRAConfigBase): """Model config for LoRA/Lycoris models.""" - base: LoRALyCORIS_SupportedBases = Field() type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA) format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS) @@ -622,14 +634,27 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): _validate_override_fields(cls, fields) - cls._validate_is_not_controllora_or_diffusers(mod) - cls._validate_looks_like_lora(mod) + cls._validate_base(mod) + return cls(**fields) + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default.value + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + @classmethod def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None: + # First rule out ControlLoRA and Diffusers LoRA + flux_format = _get_flux_lora_format(mod) + if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: + raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA") + # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA. # Some main models have these keys, likely due to the creator merging in a LoRA. has_key_with_lora_prefix = has_keys_starting_with( @@ -657,14 +682,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): raise NotAMatch(cls, "model does not match LyCORIS LoRA heuristics") @classmethod - def _validate_is_not_controllora_or_diffusers(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model is a ControlLoRA or Diffusers LoRA.""" - flux_format = _get_flux_lora_format(mod) - if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: - raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA") - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> LoRALyCORIS_SupportedBases: + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: if _get_flux_lora_format(mod): return BaseModelType.Flux @@ -683,17 +701,30 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}") +class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + class ControlAdapterConfigBase(ABC, BaseModel): default_settings: ControlAdapterDefaultSettings | None = Field(None) -ControlLoRALyCORIS_SupportedBases: TypeAlias = Literal[BaseModelType.Flux] - - -class ControlLoRALyCORISConfig(ControlAdapterConfigBase, ModelConfigBase): +class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapterConfigBase, ModelConfigBase): """Model config for Control LoRA models.""" - base: ControlLoRALyCORIS_SupportedBases = Field() + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa) format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS) @@ -717,66 +748,53 @@ class ControlLoRALyCORISConfig(ControlAdapterConfigBase, ModelConfigBase): raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA") -ControlLoRADiffusers_SupportedBases: TypeAlias = Literal[BaseModelType.Flux] +# LoRADiffusers_SupportedBases: TypeAlias = Literal[ +# BaseModelType.StableDiffusion1, +# BaseModelType.StableDiffusion2, +# BaseModelType.StableDiffusionXL, +# BaseModelType.Flux, +# ] -LoRADiffusers_SupportedBases: TypeAlias = Literal[ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - BaseModelType.Flux, -] +# class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): +# """Model config for LoRA/Diffusers models.""" + +# # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates +# # the weights format. FLUX Diffusers LoRAs are single files. + +# base: LoRADiffusers_SupportedBases = Field() +# format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + +# @classmethod +# def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: +# _validate_is_dir(cls, mod) + +# _validate_override_fields(cls, fields) + +# cls._validate_looks_like_diffusers_lora(mod) + +# return cls(**fields) + +# @classmethod +# def _validate_looks_like_diffusers_lora(cls, mod: ModelOnDisk) -> None: +# suffixes = ["bin", "safetensors"] +# weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes] +# has_lora_weight_file = any(wf.exists() for wf in weight_files) +# if not has_lora_weight_file: +# raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors") + +# flux_lora_format = _get_flux_lora_format(mod) +# if flux_lora_format is not FluxLoRAFormat.Diffusers: +# raise NotAMatch(cls, "model does not look like a FLUX Diffusers LoRA") -class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): - """Model config for LoRA/Diffusers models.""" - - base: LoRADiffusers_SupportedBases = Field() - format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_looks_like_diffusers_lora(mod) - - cls._validate_has_lora_weight_file(mod) - - return cls(**fields) - - @classmethod - def _validate_looks_like_diffusers_lora(cls, mod: ModelOnDisk) -> None: - flux_lora_format = _get_flux_lora_format(mod) - if flux_lora_format is not FluxLoRAFormat.Diffusers: - raise NotAMatch(cls, "model does not look like a FLUX Diffusers LoRA") - - @classmethod - def _validate_has_lora_weight_file(cls, mod: ModelOnDisk) -> None: - suffixes = ["bin", "safetensors"] - weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes] - has_lora_weight_file = any(wf.exists() for wf in weight_files) - if not has_lora_weight_file: - raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors") - - -VAECheckpointConfig_SupportedBases: TypeAlias = Literal[ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - BaseModelType.Flux, -] - - -class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase): +class VAE_Checkpoint_Config_Base(CheckpointConfigBase): """Model config for standalone VAE models.""" - base: VAECheckpointConfig_SupportedBases = Field() type: Literal[ModelType.VAE] = Field(default=ModelType.VAE) format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - REGEX_TO_BASE: ClassVar[dict[str, VAECheckpointConfig_SupportedBases]] = { + REGEX_TO_BASE: ClassVar[dict[str, BaseModelType]] = { r"xl": BaseModelType.StableDiffusionXL, r"sd2": BaseModelType.StableDiffusion2, r"vae": BaseModelType.StableDiffusion1, @@ -791,8 +809,17 @@ class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase): cls._validate_looks_like_vae(mod) - base = fields.get("base") or cls._get_base_or_raise(mod) - return cls(**fields, base=base) + cls._validate_base(mod) + + return cls(**fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default.value + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @classmethod def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None: @@ -806,7 +833,7 @@ class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase): raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics") @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAECheckpointConfig_SupportedBases: + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name for regexp, base in cls.REGEX_TO_BASE.items(): if re.search(regexp, mod.path.name, re.IGNORECASE): @@ -815,16 +842,25 @@ class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase): raise NotAMatch(cls, "cannot determine base type") -VAEDiffusersConfig_SupportedBases: TypeAlias = Literal[ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusionXL, -] +class VAE_SD1_Checkpoint_Config(VAE_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase): +class VAE_SD2_Checkpoint_Config(VAE_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class VAE_SDXL_Checkpoint_Config(VAE_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class VAE_FLUX_Checkpoint_Config(VAE_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + +class VAE_Diffusers_Config_Base(DiffusersConfigBase): """Model config for standalone VAE models (diffusers version).""" - base: VAEDiffusersConfig_SupportedBases = Field() type: Literal[ModelType.VAE] = Field(default=ModelType.VAE) format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) @@ -843,8 +879,17 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase): }, ) - base = fields.get("base") or cls._get_base_or_raise(mod) - return cls(**fields, base=base) + cls._validate_base(mod) + + return cls(**fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default.value + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @classmethod def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool: @@ -866,7 +911,7 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase): return name @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAEDiffusersConfig_SupportedBases: + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: config = _get_config_or_raise(cls, common_config_paths(mod.path)) if cls._config_looks_like_sdxl(config): return BaseModelType.StableDiffusionXL @@ -877,18 +922,17 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase): return BaseModelType.StableDiffusion1 -ControlNetDiffusers_SupportedBases: TypeAlias = Literal[ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - BaseModelType.Flux, -] +class VAE_SD1_Diffusers_Config(VAE_Diffusers_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase): +class VAE_SDXL_Diffusers_Config(VAE_Diffusers_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class ControlNet_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfigBase): """Model config for ControlNet models (diffusers version).""" - base: ControlNetDiffusers_SupportedBases = Field() type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) @@ -907,12 +951,20 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M }, ) - base = fields.get("base") or cls._get_base_or_raise(mod) + cls._validate_base(mod) - return cls(**fields, base=base) + return cls(**fields) @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetDiffusers_SupportedBases: + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default.value + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: config = _get_config_or_raise(cls, common_config_paths(mod.path)) if config.get("_class_name") == "FluxControlNetModel": @@ -933,6 +985,22 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M raise NotAMatch(cls, f"unrecognized cross_attention_dim {dimension}") +class ControlNet_SD1_Diffusers_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class ControlNet_SD2_Diffusers_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class ControlNet_SDXL_Diffusers_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class ControlNet_FLUX_Diffusers_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + ControlNetCheckpoint_SupportedBases: TypeAlias = Literal[ BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, @@ -941,10 +1009,9 @@ ControlNetCheckpoint_SupportedBases: TypeAlias = Literal[ ] -class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, ModelConfigBase): +class ControlNet_Checkpoint_Config_Base(CheckpointConfigBase, ControlAdapterConfigBase): """Model config for ControlNet models (diffusers version).""" - base: ControlNetDiffusers_SupportedBases = Field() type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) @@ -956,9 +1023,17 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, cls._validate_looks_like_controlnet(mod) - base = fields.get("base") or cls._get_base_or_raise(mod) + cls._validate_base(mod) - return cls(**fields, base=base) + return cls(**fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default.value + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @classmethod def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None: @@ -1011,17 +1086,33 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, raise NotAMatch(cls, "unable to determine base type from state dict") -TextualInversion_SupportedBases: TypeAlias = Literal[ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, -] +class ControlNet_SD1_Checkpoint_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class TextualInversionConfigBase(ABC, BaseModel): - base: TextualInversion_SupportedBases = Field() +class ControlNet_SD2_Checkpoint_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class ControlNet_SDXL_Checkpoint_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class ControlNet_FLUX_Checkpoint_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + +class TI_Config_Base(ABC, BaseModel): type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion) + @classmethod + def _validate_base(cls, mod: ModelOnDisk, path: Path | None = None) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default.value + recognized_base = cls._get_base_or_raise(mod, path) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + @classmethod def _file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool: try: @@ -1051,7 +1142,7 @@ class TextualInversionConfigBase(ABC, BaseModel): return False @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> TextualInversion_SupportedBases: + def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType: p = path or mod.path try: @@ -1082,7 +1173,7 @@ class TextualInversionConfigBase(ABC, BaseModel): raise NotAMatch(cls, f"unrecognized token dimension {token_dim}") -class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase): +class TI_File_Config_Base(TI_Config_Base): """Model config for textual inversion embeddings.""" format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile) @@ -1096,11 +1187,24 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase): if not cls._file_looks_like_embedding(mod): raise NotAMatch(cls, "model does not look like a textual inversion embedding file") - base = fields.get("base") or cls._get_base_or_raise(mod) - return cls(**fields, base=base) + cls._validate_base(mod) + + return cls(**fields) -class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase): +class TI_SD1_File_Config(TI_File_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class TI_SD2_File_Config(TI_File_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class TI_SDXL_File_Config(TI_File_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class TI_Folder_Config_Base(TI_Config_Base): """Model config for textual inversion embeddings.""" format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder) @@ -1113,12 +1217,24 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase): for p in mod.weight_files(): if cls._file_looks_like_embedding(mod, p): - base = fields.get("base") or cls._get_base_or_raise(mod, p) - return cls(**fields, base=base) + cls._validate_base(mod, p) + return cls(**fields) raise NotAMatch(cls, "model does not look like a textual inversion embedding folder") +class TI_SD1_Folder_Config(TI_Folder_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class TI_SD2_Folder_Config(TI_Folder_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class TI_SDXL_Folder_Config(TI_Folder_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + class MainConfigBase(ABC, BaseModel): type: Literal[ModelType.Main] = Field(default=ModelType.Main) trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) @@ -1161,20 +1277,11 @@ def _has_main_keys(state_dict: dict[str | int, Any]) -> bool: return False -SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases: TypeAlias = Literal[ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - BaseModelType.StableDiffusionXLRefiner, -] - - -class SD_1_2_XL_XLRefiner_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase): +class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase): """Model config for main checkpoint models.""" format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - base: SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases = Field() prediction_type: SchedulerPredictionType = Field() variant: ModelVariantType = Field() @@ -1186,14 +1293,24 @@ class SD_1_2_XL_XLRefiner_CheckpointConfig(CheckpointConfigBase, MainConfigBase, cls._validate_looks_like_main_model(mod) - base = fields.get("base") or cls._get_base_or_raise(mod) - prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base) - variant = fields.get("variant") or cls._get_variant_or_raise(mod, base) + cls._validate_base(mod) - return cls(**fields, base=base, prediction_type=prediction_type, variant=variant) + prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) + + variant = fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**fields, prediction_type=prediction_type, variant=variant) @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases: + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default.value + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: state_dict = mod.load_state_dict() key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" @@ -1211,9 +1328,9 @@ class SD_1_2_XL_XLRefiner_CheckpointConfig(CheckpointConfigBase, MainConfigBase, raise NotAMatch(cls, "unable to determine base type from state dict") @classmethod - def _get_scheduler_prediction_type_or_raise( - cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases - ) -> SchedulerPredictionType: + def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType: + base = cls.model_fields["base"].default.value + if base is BaseModelType.StableDiffusion2: state_dict = mod.load_state_dict() key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" @@ -1228,9 +1345,9 @@ class SD_1_2_XL_XLRefiner_CheckpointConfig(CheckpointConfigBase, MainConfigBase, return SchedulerPredictionType.Epsilon @classmethod - def _get_variant_or_raise( - cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_CheckpointConfig_SupportedBases - ) -> ModelVariantType: + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType: + base = cls.model_fields["base"].default.value + state_dict = mod.load_state_dict() key_name = "model.diffusion_model.input_blocks.0.0.weight" @@ -1258,6 +1375,22 @@ class SD_1_2_XL_XLRefiner_CheckpointConfig(CheckpointConfigBase, MainConfigBase, raise NotAMatch(cls, "state dict does not look like a main model") +class Main_SD1_Checkpoint_Config(Main_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class Main_SD2_Checkpoint_Config(Main_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class Main_SDXL_Checkpoint_Config(Main_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class Main_SDXLRefiner_Checkpoint_Config(Main_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner) + + def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None: # FLUX Model variant types are distinguished by input channels and the presence of certain keys. @@ -1303,7 +1436,7 @@ def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | Non return FluxVariantType.Schnell -class FLUX_Unquantized_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase): +class Main_FLUX_Checkpoint_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase): """Model config for main checkpoint models.""" format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) @@ -1373,11 +1506,12 @@ class FLUX_Unquantized_CheckpointConfig(CheckpointConfigBase, MainConfigBase, Mo raise NotAMatch(cls, "state dict looks like GGUF quantized") -class FLUX_Quantized_BnB_NF4_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase): +class Main_FLUX_BnBNF4_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase): """Model config for main checkpoint models.""" base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b) + variant: FluxVariantType = Field() @classmethod @@ -1421,11 +1555,12 @@ class FLUX_Quantized_BnB_NF4_CheckpointConfig(CheckpointConfigBase, MainConfigBa raise NotAMatch(cls, "state dict does not look like bnb quantized nf4") -class FLUX_Quantized_GGUF_CheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase): +class Main_FLUX_GGUF_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase): """Model config for main checkpoint models.""" base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized) + variant: FluxVariantType = Field() @classmethod @@ -1469,16 +1604,7 @@ class FLUX_Quantized_GGUF_CheckpointConfig(CheckpointConfigBase, MainConfigBase, raise NotAMatch(cls, "state dict does not look like GGUF quantized") -SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases: TypeAlias = Literal[ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - BaseModelType.StableDiffusionXLRefiner, -] - - -class SD_1_2_XL_XLRefiner_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase): - base: SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases = Field() +class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase): prediction_type: SchedulerPredictionType = Field() variant: ModelVariantType = Field() @@ -1505,24 +1631,31 @@ class SD_1_2_XL_XLRefiner_DiffusersConfig(DiffusersConfigBase, MainConfigBase, M }, ) - base = fields.get("base") or cls._get_base_or_raise(mod) + cls._validate_base(mod) - variant = fields.get("variant") or cls._get_variant_or_raise(mod, base) + variant = fields.get("variant") or cls._get_variant_or_raise(mod) - prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base) + prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) return cls( **fields, - base=base, variant=variant, prediction_type=prediction_type, repo_variant=repo_variant, ) @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases: + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default.value + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: # Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL). unet_config_path = mod.path / "unet" / "config.json" if unet_config_path.exists(): @@ -1544,9 +1677,7 @@ class SD_1_2_XL_XLRefiner_DiffusersConfig(DiffusersConfigBase, MainConfigBase, M raise NotAMatch(cls, "unable to determine base type") @classmethod - def _get_scheduler_prediction_type_or_raise( - cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases - ) -> SchedulerPredictionType: + def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType: scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json") # TODO(psyche): Is epsilon the right default or should we raise if it's not present? @@ -1561,9 +1692,8 @@ class SD_1_2_XL_XLRefiner_DiffusersConfig(DiffusersConfigBase, MainConfigBase, M raise NotAMatch(cls, f"unrecognized scheduler prediction_type {prediction_type}") @classmethod - def _get_variant_or_raise( - cls, mod: ModelOnDisk, base: SD_1_2_XL_XLRefiner_DiffusersConfig_SupportedBases - ) -> ModelVariantType: + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType: + base = cls.model_fields["base"].default.value unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json") in_channels = unet_config.get("in_channels") @@ -1580,7 +1710,23 @@ class SD_1_2_XL_XLRefiner_DiffusersConfig(DiffusersConfigBase, MainConfigBase, M raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'") -class SD_3_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase): +class Main_SD1_Diffusers_Config(Main_Diffusers_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion1] = Field(BaseModelType.StableDiffusion1) + + +class Main_SD2_Diffusers_Config(Main_Diffusers_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion2] = Field(BaseModelType.StableDiffusion2) + + +class Main_SDXL_Diffusers_Config(Main_Diffusers_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXL] = Field(BaseModelType.StableDiffusionXL) + + +class Main_SDXLRefiner_Diffusers_Config(Main_Diffusers_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(BaseModelType.StableDiffusionXLRefiner) + + +class Main_SD3_Diffusers_Config(DiffusersConfigBase, MainConfigBase, ModelConfigBase): base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3) @classmethod @@ -1589,6 +1735,7 @@ class SD_3_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase) _validate_override_fields(cls, fields) + # This check implies the base type - no further validation needed. _validate_class_name( cls, common_config_paths(mod.path), @@ -1604,7 +1751,6 @@ class SD_3_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase) return cls( **fields, - base=BaseModelType.StableDiffusion3, submodels=submodels, repo_variant=repo_variant, ) @@ -1654,7 +1800,7 @@ class SD_3_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase) return submodels -class CogView4_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase): +class Main_CogView4_Diffusers_Config(DiffusersConfigBase, MainConfigBase, ModelConfigBase): base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4) @classmethod @@ -1663,6 +1809,7 @@ class CogView4_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigB _validate_override_fields(cls, fields) + # This check implies the base type - no further validation needed. _validate_class_name( cls, common_config_paths(mod.path), @@ -1683,17 +1830,9 @@ class IPAdapterConfigBase(ABC, BaseModel): type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter) -IPAdapterInvokeAI_SupportedBases: TypeAlias = Literal[ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, -] - - -class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase): +class IPAdapter_InvokeAI_Config_Base(IPAdapterConfigBase): """Model config for IP Adapter diffusers format models.""" - base: IPAdapterInvokeAI_SupportedBases = Field() format: Literal[ModelFormat.InvokeAI] = Field(default=ModelFormat.InvokeAI) # TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long @@ -1710,8 +1849,17 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase): cls._validate_has_image_encoder_metadata_file(mod) - base = fields.get("base") or cls._get_base_or_raise(mod) - return cls(**fields, base=base) + cls._validate_base(mod) + + return cls(**fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default.value + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @classmethod def _validate_has_weights_file(cls, mod: ModelOnDisk) -> None: @@ -1726,7 +1874,7 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase): raise NotAMatch(cls, "missing image_encoder.txt metadata file") @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterInvokeAI_SupportedBases: + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: state_dict = mod.load_state_dict() try: @@ -1745,18 +1893,21 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase): raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}") -IPAdapterCheckpoint_SupportedBases: TypeAlias = Literal[ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - BaseModelType.Flux, -] +class IPAdapter_SD1_InvokeAI_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase): +class IPAdapter_SD2_InvokeAI_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class IPAdapter_SDXL_InvokeAI_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class IPAdapter_Checkpoint_Config_Base(IPAdapterConfigBase): """Model config for IP Adapter checkpoint format models.""" - base: IPAdapterCheckpoint_SupportedBases = Field() format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) @classmethod @@ -1767,8 +1918,17 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase): cls._validate_looks_like_ip_adapter(mod) - base = fields.get("base") or cls._get_base_or_raise(mod) - return cls(**fields, base=base) + cls._validate_base(mod) + + return cls(**fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default.value + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @classmethod def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None: @@ -1784,7 +1944,7 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase): raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics") @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpoint_SupportedBases: + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: state_dict = mod.load_state_dict() if is_state_dict_xlabs_ip_adapter(state_dict): @@ -1806,6 +1966,22 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase): raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}") +class IPAdapter_SD1_Checkpoint_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class IPAdapter_SD2_Checkpoint_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class IPAdapter_SDXL_Checkpoint_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class IPAdapter_FLUX_Checkpoint_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + def _get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantType | None: try: hidden_size = config.get("hidden_size") @@ -1820,91 +1996,54 @@ def _get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantTyp return None -class CLIPEmbedDiffusersConfig(DiffusersConfigBase): - """Model config for Clip Embeddings.""" - +class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase): base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed) format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) -class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): - """Model config for CLIP-G Embeddings.""" + _validate_override_fields(cls, fields) + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "CLIPModel", + "CLIPTextModel", + "CLIPTextModelWithProjection", + }, + ) + + cls._validate_variant(mod) + + return cls(**fields) + + @classmethod + def _validate_variant(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model variant does not match this config class.""" + expected_variant = cls.model_fields["variant"].default.value + config = _get_config_or_raise(cls, common_config_paths(mod.path)) + recognized_variant = _get_clip_variant_type_from_config(config) + + if recognized_variant is None: + raise NotAMatch(cls, "unable to determine CLIP variant from config") + + if expected_variant is not recognized_variant: + raise NotAMatch(cls, f"variant is {recognized_variant}, not {expected_variant}") + + +class CLIPEmbed_G_Diffusers_Config(CLIPEmbed_Diffusers_Config_Base, ModelConfigBase): variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G) - @classmethod - def get_tag(cls) -> Tag: - return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}") - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - _validate_class_name( - cls, - common_config_paths(mod.path), - { - "CLIPModel", - "CLIPTextModel", - "CLIPTextModelWithProjection", - }, - ) - - cls._validate_clip_g_variant(mod) - - return cls(**fields) - - @classmethod - def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None: - config = _get_config_or_raise(cls, common_config_paths(mod.path)) - clip_variant = _get_clip_variant_type_from_config(config) - - if clip_variant is not ClipVariantType.G: - raise NotAMatch(cls, "model does not match CLIP-G heuristics") - - -class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): - """Model config for CLIP-L Embeddings.""" +class CLIPEmbed_L_Diffusers_Config(CLIPEmbed_Diffusers_Config_Base, ModelConfigBase): variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L) - @classmethod - def get_tag(cls) -> Tag: - return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}") - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - _validate_class_name( - cls, - common_config_paths(mod.path), - { - "CLIPModel", - "CLIPTextModel", - "CLIPTextModelWithProjection", - }, - ) - - cls._validate_clip_l_variant(mod) - - return cls(**fields) - - @classmethod - def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None: - config = _get_config_or_raise(cls, common_config_paths(mod.path)) - clip_variant = _get_clip_variant_type_from_config(config) - - if clip_variant is not ClipVariantType.L: - raise NotAMatch(cls, "model does not match CLIP-G heuristics") - - -class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase): +class CLIPVision_Diffusers_Config(DiffusersConfigBase, ModelConfigBase): """Model config for CLIPVision.""" base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) @@ -1928,16 +2067,9 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase): return cls(**fields) -T2IAdapterDiffusers_SupportedBases: TypeAlias = Literal[ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusionXL, -] - - -class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase): +class T2IAdapter_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfigBase): """Model config for T2I.""" - base: T2IAdapterDiffusers_SupportedBases = Field() type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter) format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) @@ -1955,12 +2087,20 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi }, ) - base = fields.get("base") or cls._get_base_or_raise(mod) + cls._validate_base(mod) - return cls(**fields, base=base) + return cls(**fields) @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> T2IAdapterDiffusers_SupportedBases: + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default.value + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: config = _get_config_or_raise(cls, common_config_paths(mod.path)) adapter_type = config.get("adapter_type") @@ -1974,7 +2114,15 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi raise NotAMatch(cls, f"unrecognized adapter_type '{adapter_type}'") -class SpandrelImageToImageConfig(ModelConfigBase): +class T2IAdapter_SD1_Diffusers_Config(T2IAdapter_Diffusers_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class T2IAdapter_SDXL_Diffusers_Config(T2IAdapter_Diffusers_Config_Base, ModelConfigBase): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class Spandrel_Checkpoint_Config(ModelConfigBase): """Model config for Spandrel Image to Image models.""" base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) @@ -2007,7 +2155,7 @@ class SpandrelImageToImageConfig(ModelConfigBase): raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e -class SigLIPConfig(DiffusersConfigBase, ModelConfigBase): +class SigLIP_Diffusers_Config(DiffusersConfigBase, ModelConfigBase): """Model config for SigLIP.""" type: Literal[ModelType.SigLIP] = Field(default=ModelType.SigLIP) @@ -2031,7 +2179,7 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase): return cls(**fields) -class FluxReduxConfig(ModelConfigBase): +class FLUXRedux_Checkpoint_Config(ModelConfigBase): """Model config for FLUX Tools Redux model.""" type: Literal[ModelType.FluxRedux] = Field(default=ModelType.FluxRedux) @@ -2050,7 +2198,7 @@ class FluxReduxConfig(ModelConfigBase): return cls(**fields) -class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): +class LlavaOnevision_Diffusers_Config(DiffusersConfigBase, ModelConfigBase): """Model config for Llava Onevision models.""" type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision) @@ -2074,48 +2222,50 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): return cls(**fields) -ApiModel_SupportedBases: TypeAlias = Literal[ - BaseModelType.ChatGPT4o, - BaseModelType.Gemini2_5, - BaseModelType.Imagen3, - BaseModelType.Imagen4, - BaseModelType.FluxKontext, -] - - -class ApiModelConfig(MainConfigBase, ModelConfigBase): +class ExternalAPI_Config_Base(ABC, BaseModel): """Model config for API-based models.""" - type: Literal[ModelType.Main] = Field(default=ModelType.Main) format: Literal[ModelFormat.Api] = Field(default=ModelFormat.Api) - base: ApiModel_SupportedBases = Field() @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - raise NotAMatch(cls, "API models cannot be built from disk") + raise NotAMatch(cls, "External API models cannot be built from disk") -VideoApiModel_SupportedBases: TypeAlias = Literal[ - BaseModelType.Veo3, - BaseModelType.Runway, -] +class ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase): + base: Literal[BaseModelType.ChatGPT4o] = Field(default=BaseModelType.ChatGPT4o) -class VideoApiModelConfig(ModelConfigBase): - """Model config for API-based video models.""" +class ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase): + base: Literal[BaseModelType.Gemini2_5] = Field(default=BaseModelType.Gemini2_5) + +class ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase): + base: Literal[BaseModelType.Imagen3] = Field(default=BaseModelType.Imagen3) + + +class ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase): + base: Literal[BaseModelType.Imagen4] = Field(default=BaseModelType.Imagen4) + + +class ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, MainConfigBase, ModelConfigBase): + base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) + + +class VideoConfigBase(ABC, BaseModel): type: Literal[ModelType.Video] = Field(default=ModelType.Video) - base: VideoApiModel_SupportedBases = Field() - format: Literal[ModelFormat.Api] = Field(default=ModelFormat.Api) - trigger_phrases: set[str] | None = Field(description="Set of trigger phrases for this model", default=None) default_settings: MainModelDefaultSettings | None = Field( description="Default settings for this model", default=None ) - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - raise NotAMatch(cls, "API models cannot be built from disk") + +class ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, VideoConfigBase, ModelConfigBase): + base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) + + +class ExternalAPI_Runway_Config(ExternalAPI_Config_Base, VideoConfigBase, ModelConfigBase): + base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) def get_model_discriminator_value(v: Any) -> str: @@ -2123,72 +2273,109 @@ def get_model_discriminator_value(v: Any) -> str: Computes the discriminator value for a model config. https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator """ - format_ = type_ = variant_ = None + if isinstance(v, ModelConfigBase): + return v.get_tag().tag - if isinstance(v, dict): - format_ = v.get("format") - if isinstance(format_, Enum): - format_ = format_.value - - type_ = v.get("type") - if isinstance(type_, Enum): - type_ = type_.value - - variant_ = v.get("variant") - if isinstance(variant_, Enum): - variant_ = variant_.value - else: - format_ = v.format.value - type_ = v.type.value - variant_ = getattr(v, "variant", None) - if variant_: - variant_ = variant_.value - - # Ideally, each config would be uniquely identified with a combination of fields - # i.e. (type, format, variant) without any special cases. Alas... - - # Previously, CLIPEmbed did not have any variants, meaning older database entries lack a variant field. - # To maintain compatibility, we default to ClipVariantType.L in this case. - if type_ == ModelType.CLIPEmbed.value: - return f"{type_}.{format_}.{variant_}" - return f"{type_}.{format_}" + tag_strings: list[str] = [] + for name in ("type", "base", "format", "variant"): + field_value = v.get(name) + if isinstance(field_value, Enum): + field_value = field_value.value + if field_value is not None: + tag_strings.append(field_value) + return ".".join(tag_strings) # The types are listed explicitly because IDEs/LSPs can't identify the correct types # when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes AnyModelConfig = Annotated[ Union[ - Annotated[FLUX_Unquantized_CheckpointConfig, FLUX_Unquantized_CheckpointConfig.get_tag()], - Annotated[FLUX_Quantized_BnB_NF4_CheckpointConfig, FLUX_Quantized_BnB_NF4_CheckpointConfig.get_tag()], - Annotated[FLUX_Quantized_GGUF_CheckpointConfig, FLUX_Quantized_GGUF_CheckpointConfig.get_tag()], - Annotated[SD_1_2_XL_XLRefiner_DiffusersConfig, SD_1_2_XL_XLRefiner_DiffusersConfig.get_tag()], - Annotated[SD_3_DiffusersConfig, SD_3_DiffusersConfig.get_tag()], - Annotated[CogView4_DiffusersConfig, CogView4_DiffusersConfig.get_tag()], - Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()], - Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()], - Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()], - Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()], - Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()], - Annotated[LoRAOmiConfig, LoRAOmiConfig.get_tag()], - Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()], - Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()], - Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()], - Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()], - Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()], - Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()], - Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()], - Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()], - Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()], - Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()], - Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()], - Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()], - Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()], - Annotated[SigLIPConfig, SigLIPConfig.get_tag()], - Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()], - Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()], - Annotated[ApiModelConfig, ApiModelConfig.get_tag()], - Annotated[VideoApiModelConfig, VideoApiModelConfig.get_tag()], - Annotated[UnknownModelConfig, UnknownModelConfig.get_tag()], + # Main (Pipeline) - diffusers format + Annotated[Main_SD1_Diffusers_Config, Main_SD1_Diffusers_Config.get_tag()], + Annotated[Main_SD2_Diffusers_Config, Main_SD2_Diffusers_Config.get_tag()], + Annotated[Main_SDXL_Diffusers_Config, Main_SDXL_Diffusers_Config.get_tag()], + Annotated[Main_SDXLRefiner_Diffusers_Config, Main_SDXLRefiner_Diffusers_Config.get_tag()], + Annotated[Main_SD3_Diffusers_Config, Main_SD3_Diffusers_Config.get_tag()], + Annotated[Main_CogView4_Diffusers_Config, Main_CogView4_Diffusers_Config.get_tag()], + # Main (Pipeline) - checkpoint format + Annotated[Main_SD1_Checkpoint_Config, Main_SD1_Checkpoint_Config.get_tag()], + Annotated[Main_SD2_Checkpoint_Config, Main_SD2_Checkpoint_Config.get_tag()], + Annotated[Main_SDXL_Checkpoint_Config, Main_SDXL_Checkpoint_Config.get_tag()], + Annotated[Main_SDXLRefiner_Checkpoint_Config, Main_SDXLRefiner_Checkpoint_Config.get_tag()], + Annotated[Main_FLUX_Checkpoint_Config, Main_FLUX_Checkpoint_Config.get_tag()], + # Main (Pipeline) - quantized formats + Annotated[Main_FLUX_BnBNF4_Config, Main_FLUX_BnBNF4_Config.get_tag()], + Annotated[Main_FLUX_GGUF_Config, Main_FLUX_GGUF_Config.get_tag()], + # VAE - checkpoint format + Annotated[VAE_SD1_Checkpoint_Config, VAE_SD1_Checkpoint_Config.get_tag()], + Annotated[VAE_SD2_Checkpoint_Config, VAE_SD2_Checkpoint_Config.get_tag()], + Annotated[VAE_SDXL_Checkpoint_Config, VAE_SDXL_Checkpoint_Config.get_tag()], + Annotated[VAE_FLUX_Checkpoint_Config, VAE_FLUX_Checkpoint_Config.get_tag()], + # VAE - diffusers format + Annotated[VAE_SD1_Diffusers_Config, VAE_SD1_Diffusers_Config.get_tag()], + Annotated[VAE_SDXL_Diffusers_Config, VAE_SDXL_Diffusers_Config.get_tag()], + # ControlNet - checkpoint format + Annotated[ControlNet_SD1_Checkpoint_Config, ControlNet_SD1_Checkpoint_Config.get_tag()], + Annotated[ControlNet_SD2_Checkpoint_Config, ControlNet_SD2_Checkpoint_Config.get_tag()], + Annotated[ControlNet_SDXL_Checkpoint_Config, ControlNet_SDXL_Checkpoint_Config.get_tag()], + Annotated[ControlNet_FLUX_Checkpoint_Config, ControlNet_FLUX_Checkpoint_Config.get_tag()], + # ControlNet - diffusers format + Annotated[ControlNet_SD1_Diffusers_Config, ControlNet_SD1_Diffusers_Config.get_tag()], + Annotated[ControlNet_SD2_Diffusers_Config, ControlNet_SD2_Diffusers_Config.get_tag()], + Annotated[ControlNet_SDXL_Diffusers_Config, ControlNet_SDXL_Diffusers_Config.get_tag()], + Annotated[ControlNet_FLUX_Diffusers_Config, ControlNet_FLUX_Diffusers_Config.get_tag()], + # LoRA - LyCORIS format + Annotated[LoRA_LyCORIS_SD1_Config, LoRA_LyCORIS_SD1_Config.get_tag()], + Annotated[LoRA_LyCORIS_SD2_Config, LoRA_LyCORIS_SD2_Config.get_tag()], + Annotated[LoRA_LyCORIS_SDXL_Config, LoRA_LyCORIS_SDXL_Config.get_tag()], + Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()], + # LoRA - OMI format + Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()], + Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()], + # LoRA - diffusers format (TODO) + # Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()], + # ControlLoRA - diffusers format + Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()], + Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()], + Annotated[T5Encoder_BnBLLMint8_Config, T5Encoder_BnBLLMint8_Config.get_tag()], + # TI - file format + Annotated[TI_SD1_File_Config, TI_SD1_File_Config.get_tag()], + Annotated[TI_SD2_File_Config, TI_SD2_File_Config.get_tag()], + Annotated[TI_SDXL_File_Config, TI_SDXL_File_Config.get_tag()], + # TI - folder format + Annotated[TI_SD1_Folder_Config, TI_SD1_Folder_Config.get_tag()], + Annotated[TI_SD2_Folder_Config, TI_SD2_Folder_Config.get_tag()], + Annotated[TI_SDXL_Folder_Config, TI_SDXL_Folder_Config.get_tag()], + # IP Adapter - InvokeAI format + Annotated[IPAdapter_SD1_InvokeAI_Config, IPAdapter_SD1_InvokeAI_Config.get_tag()], + Annotated[IPAdapter_SD2_InvokeAI_Config, IPAdapter_SD2_InvokeAI_Config.get_tag()], + Annotated[IPAdapter_SDXL_InvokeAI_Config, IPAdapter_SDXL_InvokeAI_Config.get_tag()], + # IP Adapter - checkpoint format + Annotated[IPAdapter_SD1_Checkpoint_Config, IPAdapter_SD1_Checkpoint_Config.get_tag()], + Annotated[IPAdapter_SD2_Checkpoint_Config, IPAdapter_SD2_Checkpoint_Config.get_tag()], + Annotated[IPAdapter_SDXL_Checkpoint_Config, IPAdapter_SDXL_Checkpoint_Config.get_tag()], + Annotated[IPAdapter_FLUX_Checkpoint_Config, IPAdapter_FLUX_Checkpoint_Config.get_tag()], + # T2I Adapter - diffusers format + Annotated[T2IAdapter_SD1_Diffusers_Config, T2IAdapter_SD1_Diffusers_Config.get_tag()], + Annotated[T2IAdapter_SDXL_Diffusers_Config, T2IAdapter_SDXL_Diffusers_Config.get_tag()], + # Misc models + Annotated[Spandrel_Checkpoint_Config, Spandrel_Checkpoint_Config.get_tag()], + Annotated[CLIPEmbed_G_Diffusers_Config, CLIPEmbed_G_Diffusers_Config.get_tag()], + Annotated[CLIPEmbed_L_Diffusers_Config, CLIPEmbed_L_Diffusers_Config.get_tag()], + Annotated[CLIPVision_Diffusers_Config, CLIPVision_Diffusers_Config.get_tag()], + Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()], + Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()], + Annotated[LlavaOnevision_Diffusers_Config, LlavaOnevision_Diffusers_Config.get_tag()], + # API models + Annotated[ExternalAPI_ChatGPT4o_Config, ExternalAPI_ChatGPT4o_Config.get_tag()], + Annotated[ExternalAPI_Gemini2_5_Config, ExternalAPI_Gemini2_5_Config.get_tag()], + Annotated[ExternalAPI_Imagen3_Config, ExternalAPI_Imagen3_Config.get_tag()], + Annotated[ExternalAPI_Imagen4_Config, ExternalAPI_Imagen4_Config.get_tag()], + Annotated[ExternalAPI_FluxKontext_Config, ExternalAPI_FluxKontext_Config.get_tag()], + Annotated[ExternalAPI_Veo3_Config, ExternalAPI_Veo3_Config.get_tag()], + Annotated[ExternalAPI_Runway_Config, ExternalAPI_Runway_Config.get_tag()], + # Unknown model (fallback) + Annotated[Unknown_Config, Unknown_Config.get_tag()], ], Discriminator(get_model_discriminator_value), ] @@ -2298,7 +2485,7 @@ class ModelConfigFactory: if not matches and app_config.allow_unknown_models: logger.warning(f"Unable to identify model {mod.name}, classifying as UnknownModelConfig") logger.debug(f"Model matching results: {results}") - return UnknownModelConfig(**fields) + return Unknown_Config(**fields) instance = next(iter(matches)) if len(matches) > 1: diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 458fc0cfc0..75191517c7 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -12,9 +12,7 @@ from typing import Any, Dict, Generator, Optional, Tuple import torch from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) +from invokeai.backend.model_manager.config import AnyModelConfig from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py index 5bf93db381..62a8ed4f65 100644 --- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -7,7 +7,7 @@ from diffusers import ControlNetModel from invokeai.backend.model_manager.config import ( AnyModelConfig, - ControlNetCheckpointConfig, + ControlNet_Checkpoint_Config_Base, ) from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader @@ -46,7 +46,7 @@ class ControlNetLoader(GenericDiffusersLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if isinstance(config, ControlNetCheckpointConfig): + if isinstance(config, ControlNet_Checkpoint_Config_Base): return ControlNetModel.from_single_file( config.path, torch_dtype=self._torch_dtype, diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index 570069632a..4bb24b7466 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -37,26 +37,26 @@ from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers from invokeai.backend.model_manager.config import ( AnyModelConfig, CheckpointConfigBase, - CLIPEmbedDiffusersConfig, - ControlNetCheckpointConfig, - ControlNetDiffusersConfig, - FLUX_Quantized_BnB_NF4_CheckpointConfig, - FLUX_Quantized_GGUF_CheckpointConfig, - FLUX_Unquantized_CheckpointConfig, - FluxReduxConfig, - IPAdapterCheckpointConfig, - T5EncoderBnbQuantizedLlmInt8bConfig, - T5EncoderConfig, - VAECheckpointConfig, + CLIPEmbed_Diffusers_Config_Base, + ControlNet_Checkpoint_Config_Base, + ControlNet_Diffusers_Config_Base, + FLUXRedux_Checkpoint_Config, + IPAdapter_Checkpoint_Config_Base, + Main_FLUX_BnBNF4_Config, + Main_FLUX_Checkpoint_Config, + Main_FLUX_GGUF_Config, + T5Encoder_BnBLLMint8_Config, + T5Encoder_T5Encoder_Config, + VAE_Checkpoint_Config_Base, ) from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import ( AnyModel, BaseModelType, + FluxVariantType, ModelFormat, ModelType, - ModelVariantType, SubModelType, ) from invokeai.backend.model_manager.util.model_util import ( @@ -86,7 +86,7 @@ class FluxVAELoader(ModelLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, VAECheckpointConfig): + if not isinstance(config, VAE_Checkpoint_Config_Base): raise ValueError("Only VAECheckpointConfig models are currently supported here.") model_path = Path(config.path) @@ -116,7 +116,7 @@ class CLIPDiffusersLoader(ModelLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, CLIPEmbedDiffusersConfig): + if not isinstance(config, CLIPEmbed_Diffusers_Config_Base): raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.") match submodel_type: @@ -139,7 +139,7 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig): + if not isinstance(config, T5Encoder_BnBLLMint8_Config): raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.") if not bnb_available: raise ImportError( @@ -186,7 +186,7 @@ class T5EncoderCheckpointModel(ModelLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, T5EncoderConfig): + if not isinstance(config, T5Encoder_T5Encoder_Config): raise ValueError("Only T5EncoderConfig models are currently supported here.") match submodel_type: @@ -226,7 +226,7 @@ class FluxCheckpointModel(ModelLoader): self, config: AnyModelConfig, ) -> AnyModel: - assert isinstance(config, FLUX_Unquantized_CheckpointConfig) + assert isinstance(config, Main_FLUX_Checkpoint_Config) model_path = Path(config.path) with accelerate.init_empty_weights(): @@ -268,7 +268,7 @@ class FluxGGUFCheckpointModel(ModelLoader): self, config: AnyModelConfig, ) -> AnyModel: - assert isinstance(config, FLUX_Quantized_GGUF_CheckpointConfig) + assert isinstance(config, Main_FLUX_GGUF_Config) model_path = Path(config.path) with accelerate.init_empty_weights(): @@ -314,7 +314,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader): self, config: AnyModelConfig, ) -> AnyModel: - assert isinstance(config, FLUX_Quantized_BnB_NF4_CheckpointConfig) + assert isinstance(config, Main_FLUX_BnBNF4_Config) if not bnb_available: raise ImportError( "The bnb modules are not available. Please install bitsandbytes if available on your platform." @@ -342,9 +342,9 @@ class FluxControlnetModel(ModelLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if isinstance(config, ControlNetCheckpointConfig): + if isinstance(config, ControlNet_Checkpoint_Config_Base): model_path = Path(config.path) - elif isinstance(config, ControlNetDiffusersConfig): + elif isinstance(config, ControlNet_Diffusers_Config_Base): # If this is a diffusers directory, we simply ignore the config file and load from the weight file. model_path = Path(config.path) / "diffusion_pytorch_model.safetensors" else: @@ -363,7 +363,7 @@ class FluxControlnetModel(ModelLoader): def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel: with accelerate.init_empty_weights(): # HACK(ryand): Is it safe to assume dev here? - model = XLabsControlNetFlux(get_flux_transformers_params(ModelVariantType.FluxDev)) + model = XLabsControlNetFlux(get_flux_transformers_params(FluxVariantType.Dev)) model.load_state_dict(sd, assign=True) return model @@ -389,7 +389,7 @@ class FluxIpAdapterModel(ModelLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, IPAdapterCheckpointConfig): + if not isinstance(config, IPAdapter_Checkpoint_Config_Base): raise ValueError(f"Unexpected model config type: {type(config)}.") sd = load_file(Path(config.path)) @@ -412,7 +412,7 @@ class FluxReduxModelLoader(ModelLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, FluxReduxConfig): + if not isinstance(config, FLUXRedux_Checkpoint_Config): raise ValueError(f"Unexpected model config type: {type(config)}.") sd = load_file(Path(config.path)) diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index 9d771feae7..ab7982394b 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -15,8 +15,14 @@ from invokeai.backend.model_manager.config import ( AnyModelConfig, CheckpointConfigBase, DiffusersConfigBase, - SD_1_2_XL_XLRefiner_CheckpointConfig, - SD_1_2_XL_XLRefiner_DiffusersConfig, + Main_SD1_Checkpoint_Config, + Main_SD1_Diffusers_Config, + Main_SD2_Checkpoint_Config, + Main_SD2_Diffusers_Config, + Main_SDXL_Checkpoint_Config, + Main_SDXL_Diffusers_Config, + Main_SDXLRefiner_Checkpoint_Config, + Main_SDXLRefiner_Diffusers_Config, ) from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry @@ -108,7 +114,19 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader): ModelVariantType.Normal: StableDiffusionXLPipeline, }, } - assert isinstance(config, (SD_1_2_XL_XLRefiner_DiffusersConfig, SD_1_2_XL_XLRefiner_CheckpointConfig)) + assert isinstance( + config, + ( + Main_SD1_Diffusers_Config, + Main_SD2_Diffusers_Config, + Main_SDXL_Diffusers_Config, + Main_SDXLRefiner_Diffusers_Config, + Main_SD1_Checkpoint_Config, + Main_SD2_Checkpoint_Config, + Main_SDXL_Checkpoint_Config, + Main_SDXLRefiner_Checkpoint_Config, + ), + ) try: load_class = load_classes[config.base][config.variant] except KeyError as e: diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index 365fa0a547..12789e58c2 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -3,9 +3,9 @@ from typing import Optional -from diffusers import AutoencoderKL +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from invokeai.backend.model_manager.config import AnyModelConfig, VAECheckpointConfig +from invokeai.backend.model_manager.config import AnyModelConfig, VAE_Checkpoint_Config_Base from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import ( @@ -27,7 +27,7 @@ class VAELoader(GenericDiffusersLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if isinstance(config, VAECheckpointConfig): + if isinstance(config, VAE_Checkpoint_Config_Base): return AutoencoderKL.from_single_file( config.path, torch_dtype=self._torch_dtype, diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index c8b5698dd8..41bbc2c024 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -21,7 +21,7 @@ from invokeai.backend.model_manager.config import ( ControlAdapterDefaultSettings, MainDiffusersConfig, MainModelDefaultSettings, - TextualInversionFileConfig, + TI_File_Config, VAEDiffusersConfig, ) from invokeai.backend.model_manager.taxonomy import ModelSourceType @@ -40,8 +40,8 @@ def store( return ModelRecordServiceSQL(db, logger) -def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig: - config = TextualInversionFileConfig( +def example_ti_config(key: Optional[str] = None) -> TI_File_Config: + config = TI_File_Config( source="test/source/", source_type=ModelSourceType.Path, path="/tmp/pokemon.bin", @@ -61,7 +61,7 @@ def test_type(store: ModelRecordServiceBase): config = example_ti_config("key1") store.add_model(config) config1 = store.get_model("key1") - assert isinstance(config1, TextualInversionFileConfig) + assert isinstance(config1, TI_File_Config) def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):