feat(mm): port flux "control lora" and t2i adapter to new api

This commit is contained in:
psychedelicious
2025-09-25 20:41:58 +10:00
parent eaddd6f533
commit a118700cc8
2 changed files with 38 additions and 22 deletions

View File

@@ -66,6 +66,7 @@ from invokeai.backend.model_manager.taxonomy import (
variant_type_adapter,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -222,13 +223,7 @@ class ControlAdapterDefaultSettings(BaseModel):
class LegacyProbeMixin:
"""Mixin for classes using the legacy probe for model classification."""
@classmethod
def matches(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")
@classmethod
def parse(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")
pass
class ModelConfigBase(ABC, BaseModel):
@@ -581,7 +576,7 @@ class ControlAdapterConfigBase(ABC, BaseModel):
ControlLoRALyCORIS_SupportedBases: TypeAlias = Literal[BaseModelType.Flux]
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, ModelConfigBase):
"""Model config for Control LoRA models."""
base: ControlLoRALyCORIS_SupportedBases = Field()
@@ -590,20 +585,21 @@ class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, Model
trigger_phrases: set[str] | None = Field(None)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_file(cls, mod)
state_dict = mod.load_state_dict()
if not is_state_dict_likely_flux_control(state_dict):
raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA")
return cls(**fields)
ControlLoRADiffusers_SupportedBases: TypeAlias = Literal[BaseModelType.Flux]
class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for Control LoRA models."""
base: ControlLoRADiffusers_SupportedBases = Field()
type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
trigger_phrases: set[str] | None = Field(None)
LoRADiffusers_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
@@ -950,7 +946,7 @@ class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, L
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b)
prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
upcast_attention: bool = False
upcast_attention: bool = Field(False)
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
@@ -1236,13 +1232,34 @@ T2IAdapterCheckpoint_SupportedBases: TypeAlias = Literal[
]
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase):
"""Model config for T2I."""
base: T2IAdapterCheckpoint_SupportedBases = Field()
type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.T2IAdapter,
"format": ModelFormat.Diffusers,
}
VALID_CLASS_NAMES: ClassVar = {
"T2IAdapter",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
config_path = mod.path / "config.json"
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
return cls(**fields)
class SpandrelImageToImageConfig(ModelConfigBase):
"""Model config for Spandrel Image to Image models."""
@@ -1454,7 +1471,6 @@ AnyModelConfig = Annotated[
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRAOmiConfig, LoRAOmiConfig.get_tag()],
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],

View File

@@ -18,7 +18,7 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(\w+\.)+(lora_A\.weight|lora_B\.weight|lora_B\.bias|scale)"
def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
def is_state_dict_likely_flux_control(state_dict: dict[str | int, Any]) -> bool:
"""Checks if the provided state dict is likely in the FLUX Control LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A