mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): port flux "control lora" and t2i adapter to new api
This commit is contained in:
@@ -66,6 +66,7 @@ from invokeai.backend.model_manager.taxonomy import (
|
||||
variant_type_adapter,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
|
||||
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
|
||||
@@ -222,13 +223,7 @@ class ControlAdapterDefaultSettings(BaseModel):
|
||||
class LegacyProbeMixin:
|
||||
"""Mixin for classes using the legacy probe for model classification."""
|
||||
|
||||
@classmethod
|
||||
def matches(cls, *args, **kwargs):
|
||||
raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")
|
||||
|
||||
@classmethod
|
||||
def parse(cls, *args, **kwargs):
|
||||
raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")
|
||||
pass
|
||||
|
||||
|
||||
class ModelConfigBase(ABC, BaseModel):
|
||||
@@ -581,7 +576,7 @@ class ControlAdapterConfigBase(ABC, BaseModel):
|
||||
ControlLoRALyCORIS_SupportedBases: TypeAlias = Literal[BaseModelType.Flux]
|
||||
|
||||
|
||||
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, ModelConfigBase):
|
||||
"""Model config for Control LoRA models."""
|
||||
|
||||
base: ControlLoRALyCORIS_SupportedBases = Field()
|
||||
@@ -590,20 +585,21 @@ class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, Model
|
||||
|
||||
trigger_phrases: set[str] | None = Field(None)
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
|
||||
if not is_state_dict_likely_flux_control(state_dict):
|
||||
raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA")
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
ControlLoRADiffusers_SupportedBases: TypeAlias = Literal[BaseModelType.Flux]
|
||||
|
||||
|
||||
class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for Control LoRA models."""
|
||||
|
||||
base: ControlLoRADiffusers_SupportedBases = Field()
|
||||
type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa)
|
||||
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
||||
|
||||
trigger_phrases: set[str] | None = Field(None)
|
||||
|
||||
|
||||
LoRADiffusers_SupportedBases: TypeAlias = Literal[
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
@@ -950,7 +946,7 @@ class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, L
|
||||
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
||||
format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b)
|
||||
prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
|
||||
upcast_attention: bool = False
|
||||
upcast_attention: bool = Field(False)
|
||||
|
||||
|
||||
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
@@ -1236,13 +1232,34 @@ T2IAdapterCheckpoint_SupportedBases: TypeAlias = Literal[
|
||||
]
|
||||
|
||||
|
||||
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase):
|
||||
"""Model config for T2I."""
|
||||
|
||||
base: T2IAdapterCheckpoint_SupportedBases = Field()
|
||||
type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
|
||||
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.T2IAdapter,
|
||||
"format": ModelFormat.Diffusers,
|
||||
}
|
||||
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"T2IAdapter",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
class SpandrelImageToImageConfig(ModelConfigBase):
|
||||
"""Model config for Spandrel Image to Image models."""
|
||||
@@ -1454,7 +1471,6 @@ AnyModelConfig = Annotated[
|
||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||
Annotated[LoRAOmiConfig, LoRAOmiConfig.get_tag()],
|
||||
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
|
||||
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
|
||||
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
|
||||
|
||||
@@ -18,7 +18,7 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(\w+\.)+(lora_A\.weight|lora_B\.weight|lora_B\.bias|scale)"
|
||||
|
||||
|
||||
def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
|
||||
def is_state_dict_likely_flux_control(state_dict: dict[str | int, Any]) -> bool:
|
||||
"""Checks if the provided state dict is likely in the FLUX Control LoRA format.
|
||||
|
||||
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
|
||||
|
||||
Reference in New Issue
Block a user