diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 03fb8c66e1..9260656c1e 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -144,40 +144,26 @@ def _validate_overrides( config_class: type, provided_overrides: dict[str, Any], valid_overrides: dict[str, Any], -) -> bool: +) -> None: """Check if the provided overrides match the valid overrides for this config class. Args: config_class: The config class that is being tested. provided_overrides: The overrides provided by the user. - valid_overrides: The overrides that are valid for this config class. The value can be a specific value or a - callable that takes the provided value and returns True if it is valid. - - Returns: - True if all provided overrides match the valid overrides, False if some valid overrides are missing. + valid_overrides: The overrides that are valid for this config class. Raises: NotAMatch if any override does not match the allowed value. """ - is_perfect_match = True - for override_name, constraint in valid_overrides.items(): - if override_name not in provided_overrides: - is_perfect_match = False + for key, value in valid_overrides.items(): + if key not in provided_overrides: continue - # Handle the typical case where the constraint is a specific value - if provided_overrides[override_name] != constraint: + if provided_overrides[key] != value: raise NotAMatch( config_class, - f"override {override_name}={provided_overrides[override_name]} does not match required value {override_name}={constraint}", - ) - # Handle the less common case where the constraint is a callable - elif callable(constraint) and not constraint(provided_overrides[override_name]): - raise NotAMatch( - config_class, - f"override {override_name}={provided_overrides[override_name]} does not match required value {override_name}=callable", + f"override {key}={provided_overrides[key]} does not match required value {key}={value}", ) - return is_perfect_match def _raise_if_not_file( @@ -422,8 +408,7 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_dir(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) _validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES) @@ -452,8 +437,7 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_dir(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) # Heuristic: Look for the T5EncoderModel class name in the config _validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES) @@ -483,8 +467,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): # OMI LoRAs are always files _raise_if_not_file(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) # Heuristic: differential diagnosis vs ControlLoRA and Diffusers if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: @@ -532,8 +515,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_file(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) # Heuristic: differential diagnosis vs ControlLoRA and Diffusers if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: @@ -603,8 +585,7 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): # Diffusers-style models always directories _raise_if_not_dir(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers @@ -636,8 +617,7 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_file(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}): raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics") @@ -678,8 +658,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_dir(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) _validate_class_names(cls, mod.path / "config.json", cls.VALID_CLASS_NAMES) @@ -813,8 +792,7 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_file(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) if not cls._file_looks_like_embedding(mod): raise NotAMatch(cls, "model does not look like a textual inversion embedding file") @@ -841,8 +819,7 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_dir(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) for p in mod.weight_files(): if cls._file_looks_like_embedding(mod, p): @@ -903,15 +880,6 @@ class IPAdapterConfigBase(ABC, BaseModel): type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter -IPAdapterInvokeAIConfigBaseTypes: TypeAlias = Literal[ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, -] -"""Helper TypeAlias for valid base types for IP Adapter models in the InvokeAI format.""" - -ip_adapter_invoke_ai_base_type_adapter = TypeAdapter[IPAdapterInvokeAIConfigBaseTypes](IPAdapterInvokeAIConfigBaseTypes) -"""Helper TypeAdapter for IP Adapter InvokeAI base types.""" class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase): @@ -921,20 +889,17 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase): # time. Need to go through the history to make sure I'm understanding this fully. image_encoder_model_id: str format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI - base: IPAdapterInvokeAIConfigBaseTypes = Field(...) VALID_OVERRIDES: ClassVar = { "type": ModelType.IPAdapter, "format": ModelFormat.InvokeAI, - "base": ip_adapter_invoke_ai_base_type_adapter.validate_python, } @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_dir(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) weights_file = mod.path / "ip_adapter.bin" if not weights_file.exists(): @@ -948,7 +913,7 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase): return cls(**fields, base=base) @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterInvokeAIConfigBaseTypes: + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: state_dict = mod.load_state_dict() try: @@ -967,18 +932,6 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase): raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}") -IPAdapterCheckpointConfigBaseTypes: TypeAlias = Literal[ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - BaseModelType.Flux, -] -"""Helper TypeAlias for valid base types for IP Adapter models in the Checkpoint format.""" - -ip_adapter_checkpoint_base_type_adapter = TypeAdapter[IPAdapterCheckpointConfigBaseTypes]( - IPAdapterCheckpointConfigBaseTypes -) -"""Helper TypeAdapter for IP Adapter Checkpoint base types.""" class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase): @@ -989,15 +942,13 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase): VALID_OVERRIDES: ClassVar = { "type": ModelType.IPAdapter, "format": ModelFormat.Checkpoint, - "base": ip_adapter_checkpoint_base_type_adapter.validate_python, } @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_file(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) if not mod.has_keys_starting_with( { @@ -1013,7 +964,7 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase): return cls(**fields, base=base) @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpointConfigBaseTypes: + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: state_dict = mod.load_state_dict() if is_state_dict_xlabs_ip_adapter(state_dict): @@ -1083,8 +1034,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_dir(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) config_path = mod.path / "config.json" @@ -1119,8 +1069,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_dir(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) config_path = mod.path / "config.json" @@ -1154,8 +1103,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_dir(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) config_path = mod.path / "config.json" @@ -1188,8 +1136,7 @@ class SpandrelImageToImageConfig(ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_file(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) try: # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were @@ -1226,8 +1173,7 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_dir(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) config_path = mod.path / "config.json" @@ -1251,8 +1197,7 @@ class FluxReduxConfig(ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_file(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) if not is_state_dict_likely_flux_redux(mod.load_state_dict()): raise NotAMatch(cls, "model does not match FLUX Tools Redux heuristics") @@ -1280,8 +1225,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: _raise_if_not_dir(cls, mod) - if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): - return cls(**fields) + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) config_path = mod.path / "config.json"