diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 7bc16099f0..5fc22be845 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -66,6 +66,7 @@ from invokeai.backend.model_manager.taxonomy import ( variant_type_adapter, ) from invokeai.backend.model_manager.util.model_util import lora_token_vector_length +from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES @@ -222,13 +223,7 @@ class ControlAdapterDefaultSettings(BaseModel): class LegacyProbeMixin: """Mixin for classes using the legacy probe for model classification.""" - @classmethod - def matches(cls, *args, **kwargs): - raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}") - - @classmethod - def parse(cls, *args, **kwargs): - raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}") + pass class ModelConfigBase(ABC, BaseModel): @@ -581,7 +576,7 @@ class ControlAdapterConfigBase(ABC, BaseModel): ControlLoRALyCORIS_SupportedBases: TypeAlias = Literal[BaseModelType.Flux] -class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): +class ControlLoRALyCORISConfig(ControlAdapterConfigBase, ModelConfigBase): """Model config for Control LoRA models.""" base: ControlLoRALyCORIS_SupportedBases = Field() @@ -590,20 +585,21 @@ class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, Model trigger_phrases: set[str] | None = Field(None) + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _raise_if_not_file(cls, mod) + + state_dict = mod.load_state_dict() + + if not is_state_dict_likely_flux_control(state_dict): + raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA") + + return cls(**fields) + ControlLoRADiffusers_SupportedBases: TypeAlias = Literal[BaseModelType.Flux] -class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for Control LoRA models.""" - - base: ControlLoRADiffusers_SupportedBases = Field() - type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa) - format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - - trigger_phrases: set[str] | None = Field(None) - - LoRADiffusers_SupportedBases: TypeAlias = Literal[ BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, @@ -950,7 +946,7 @@ class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, L base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b) prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon) - upcast_attention: bool = False + upcast_attention: bool = Field(False) class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): @@ -1236,13 +1232,34 @@ T2IAdapterCheckpoint_SupportedBases: TypeAlias = Literal[ ] -class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): +class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase): """Model config for T2I.""" base: T2IAdapterCheckpoint_SupportedBases = Field() type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter) format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + VALID_OVERRIDES: ClassVar = { + "type": ModelType.T2IAdapter, + "format": ModelFormat.Diffusers, + } + + VALID_CLASS_NAMES: ClassVar = { + "T2IAdapter", + } + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _raise_if_not_dir(cls, mod) + + _validate_overrides(cls, fields, cls.VALID_OVERRIDES) + + config_path = mod.path / "config.json" + + _validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES) + + return cls(**fields) + class SpandrelImageToImageConfig(ModelConfigBase): """Model config for Spandrel Image to Image models.""" @@ -1454,7 +1471,6 @@ AnyModelConfig = Annotated[ Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()], Annotated[LoRAOmiConfig, LoRAOmiConfig.get_tag()], Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()], - Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()], Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()], Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()], Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()], diff --git a/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py b/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py index fa9cc76462..bd2b74e608 100644 --- a/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py @@ -18,7 +18,7 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(\w+\.)+(lora_A\.weight|lora_B\.weight|lora_B\.bias|scale)" -def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool: +def is_state_dict_likely_flux_control(state_dict: dict[str | int, Any]) -> bool: """Checks if the provided state dict is likely in the FLUX Control LoRA format. This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A