From ee5808355d2195a85054e7aa51b9bad399fdf92d Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Wed, 1 Oct 2025 18:00:50 +1000
Subject: [PATCH] refactor(mm): diffusers loras w

---
 invokeai/app/api/routers/model_manager.py    |  18 +++-
 invokeai/app/invocations/flux_ip_adapter.py  |   5 +-
 invokeai/app/invocations/ip_adapter.py       |   4 +-
 invokeai/backend/model_manager/config.py     | 101 ++++++++++++------
 .../flux_aitoolkit_lora_conversion_utils.py  |   4 +-
 .../flux_control_lora_utils.py               |   4 +-
 .../flux_diffusers_lora_conversion_utils.py  |   6 +-
 .../flux_kohya_lora_conversion_utils.py      |   3 +-
 .../flux_onetrainer_lora_conversion_utils.py |   3 +-
 9 files changed, 102 insertions(+), 46 deletions(-)

diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py
index c91d2ed722..6dfc58df49 100644
--- a/invokeai/app/api/routers/model_manager.py
+++ b/invokeai/app/api/routers/model_manager.py
@@ -29,7 +29,13 @@ from invokeai.app.services.model_records import (
 )
 from invokeai.app.util.suppress_output import SuppressOutput
 from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
-from invokeai.backend.model_manager.config import AnyModelConfig, SD_1_2_XL_XLRefiner_CheckpointConfig
+from invokeai.backend.model_manager.config import (
+    AnyModelConfig,
+    Main_SD1_Checkpoint_Config,
+    Main_SD2_Checkpoint_Config,
+    Main_SDXL_Checkpoint_Config,
+    Main_SDXLRefiner_Checkpoint_Config,
+)
 from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
 from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
 from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
@@ -738,7 +744,15 @@ async def convert_model(
         logger.error(str(e))
         raise HTTPException(status_code=424, detail=str(e))

-    if not isinstance(model_config, SD_1_2_XL_XLRefiner_CheckpointConfig):
+    if not isinstance(
+        model_config,
+        (
+            Main_SD1_Checkpoint_Config,
+            Main_SD2_Checkpoint_Config,
+            Main_SDXL_Checkpoint_Config,
+            Main_SDXLRefiner_Checkpoint_Config,
+        ),
+    ):
         msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model."
         logger.error(msg)
         raise HTTPException(400, msg)

diff --git a/invokeai/app/invocations/flux_ip_adapter.py b/invokeai/app/invocations/flux_ip_adapter.py
index cfd166815d..970f330a58 100644
--- a/invokeai/app/invocations/flux_ip_adapter.py
+++ b/invokeai/app/invocations/flux_ip_adapter.py
@@ -17,8 +17,7 @@ from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.config import (
-    IPAdapter_InvokeAI_Config_Base,
-    IPAdapterCheckpointConfig,
+    IPAdapter_FLUX_Checkpoint_Config,
 )
 from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@@ -68,7 +67,7 @@ class FluxIPAdapterInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
         ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
-        assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
+        assert isinstance(ip_adapter_info, IPAdapter_FLUX_Checkpoint_Config)
         # Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
         image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
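Note: the convert_model guard now spells out all four main-checkpoint config classes inline. If the same set is ever needed at another call site, a module-level tuple keeps the check readable. A minimal sketch, using a hypothetical MAIN_SD_CHECKPOINT_CONFIGS alias that is not part of this patch:

    # Hypothetical alias; the four config class names come from this patch.
    MAIN_SD_CHECKPOINT_CONFIGS = (
        Main_SD1_Checkpoint_Config,
        Main_SD2_Checkpoint_Config,
        Main_SDXL_Checkpoint_Config,
        Main_SDXLRefiner_Checkpoint_Config,
    )

    if not isinstance(model_config, MAIN_SD_CHECKPOINT_CONFIGS):
        msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model."
        logger.error(msg)
        raise HTTPException(400, msg)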
diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py
index 5b99f72369..7c3234bdc7 100644
--- a/invokeai/app/invocations/ip_adapter.py
+++ b/invokeai/app/invocations/ip_adapter.py
@@ -13,8 +13,8 @@ from invokeai.app.services.model_records.model_records_base import ModelRecordCh
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
+    IPAdapter_Checkpoint_Config_Base,
     IPAdapter_InvokeAI_Config_Base,
-    IPAdapterCheckpointConfig,
 )
 from invokeai.backend.model_manager.starter_models import (
     StarterModel,
@@ -123,7 +123,7 @@ class IPAdapterInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
         ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
-        assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
+        assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapter_Checkpoint_Config_Base))

         if isinstance(ip_adapter_info, IPAdapter_InvokeAI_Config_Base):
             image_encoder_model_id = ip_adapter_info.image_encoder_model_id
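Note: replacing the concrete IPAdapterCheckpointConfig with the IPAdapter_Checkpoint_Config_Base parent lets this single assert cover every checkpoint-format IP-Adapter variant. A minimal sketch of the assumed hierarchy; only IPAdapter_FLUX_Checkpoint_Config is named in this patch, and the SD variant below is hypothetical:

    class IPAdapter_Checkpoint_Config_Base: ...  # common parent (from this patch)
    class IPAdapter_FLUX_Checkpoint_Config(IPAdapter_Checkpoint_Config_Base): ...
    class IPAdapter_SD1_Checkpoint_Config(IPAdapter_Checkpoint_Config_Base): ...  # hypothetical

    # One isinstance check accepts any checkpoint-format variant, present or future.
    for cfg in (IPAdapter_FLUX_Checkpoint_Config(), IPAdapter_SD1_Checkpoint_Config()):
        assert isinstance(cfg, IPAdapter_Checkpoint_Config_Base)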
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 0b6e5fd83c..8d432b316a 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -748,44 +748,78 @@ class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapterConfigBase, ModelConfigBase)
         raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA")


-# LoRADiffusers_SupportedBases: TypeAlias = Literal[
-#     BaseModelType.StableDiffusion1,
-#     BaseModelType.StableDiffusion2,
-#     BaseModelType.StableDiffusionXL,
-#     BaseModelType.Flux,
-# ]
+class LoRA_Diffusers_Config_Base(LoRAConfigBase):
+    """Model config for LoRA/Diffusers models."""
+
+    # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
+    # the weights format. FLUX Diffusers LoRAs are single files.
+
+    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
+        _validate_is_dir(cls, mod)
+
+        _validate_override_fields(cls, fields)
+
+        cls._validate_base(mod)
+
+        return cls(**fields)
+
+    @classmethod
+    def _validate_base(cls, mod: ModelOnDisk) -> None:
+        """Raise `NotAMatch` if the model base does not match this config class."""
+        expected_base = cls.model_fields["base"].default
+        recognized_base = cls._get_base_or_raise(mod)
+        if expected_base is not recognized_base:
+            raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
+
+    @classmethod
+    def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
+        if _get_flux_lora_format(mod):
+            return BaseModelType.Flux
+
+        # If we've gotten here, we assume that the LoRA is a Stable Diffusion LoRA
+        path_to_weight_file = cls._get_weight_file_or_raise(mod)
+        state_dict = mod.load_state_dict(path_to_weight_file)
+        token_vector_length = lora_token_vector_length(state_dict)
+
+        match token_vector_length:
+            case 768:
+                return BaseModelType.StableDiffusion1
+            case 1024:
+                return BaseModelType.StableDiffusion2
+            case 1280:
+                return BaseModelType.StableDiffusionXL  # recognizes format at https://civitai.com/models/224641
+            case 2048:
+                return BaseModelType.StableDiffusionXL
+            case _:
+                raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}")
+
+    @classmethod
+    def _get_weight_file_or_raise(cls, mod: ModelOnDisk) -> Path:
+        suffixes = ["bin", "safetensors"]
+        weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
+        for wf in weight_files:
+            if wf.exists():
+                return wf
+        raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
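Note: the token-vector-length probe keys off the text-encoder embedding width baked into SD LoRA weights: 768 for SD1, 1024 for SD2, and 1280 or 2048 for SDXL. A usage sketch, assuming ModelOnDisk can be constructed directly from a folder path (its constructor is not shown in this patch):

    from pathlib import Path

    # Hypothetical folder containing pytorch_lora_weights.safetensors.
    mod = ModelOnDisk(Path("~/models/my-sdxl-lora").expanduser())

    # The matching class returns a populated config; probing the same folder
    # through LoRA_SD1_Diffusers_Config would raise NotAMatch (base mismatch).
    config = LoRA_SDXL_Diffusers_Config.from_model_on_disk(mod, fields={})
    assert config.base is BaseModelType.StableDiffusionXL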
-# class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
-#     """Model config for LoRA/Diffusers models."""
+class LoRA_SD1_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
+    base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)

-#     # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
-#     # the weights format. FLUX Diffusers LoRAs are single files.
-#     base: LoRADiffusers_SupportedBases = Field()
-#     format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)

+class LoRA_SD2_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
+    base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)

-#     @classmethod
-#     def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
-#         _validate_is_dir(cls, mod)
-#         _validate_override_fields(cls, fields)

+class LoRA_SDXL_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
+    base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)

-#         cls._validate_looks_like_diffusers_lora(mod)
-#         return cls(**fields)
-
-#     @classmethod
-#     def _validate_looks_like_diffusers_lora(cls, mod: ModelOnDisk) -> None:
-#         suffixes = ["bin", "safetensors"]
-#         weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
-#         has_lora_weight_file = any(wf.exists() for wf in weight_files)
-#         if not has_lora_weight_file:
-#             raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
-
-#         flux_lora_format = _get_flux_lora_format(mod)
-#         if flux_lora_format is not FluxLoRAFormat.Diffusers:
-#             raise NotAMatch(cls, "model does not look like a FLUX Diffusers LoRA")
+class LoRA_FLUX_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
+    base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)


 class VAE_Checkpoint_Config_Base(CheckpointConfigBase):
@@ -2332,8 +2366,11 @@ AnyModelConfig = Annotated[
     # LoRA - OMI format
     Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()],
     Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()],
-    # LoRA - diffusers format (TODO)
-    # Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
+    # LoRA - diffusers format
+    Annotated[LoRA_SD1_Diffusers_Config, LoRA_SD1_Diffusers_Config.get_tag()],
+    Annotated[LoRA_SD2_Diffusers_Config, LoRA_SD2_Diffusers_Config.get_tag()],
+    Annotated[LoRA_SDXL_Diffusers_Config, LoRA_SDXL_Diffusers_Config.get_tag()],
+    Annotated[LoRA_FLUX_Diffusers_Config, LoRA_FLUX_Diffusers_Config.get_tag()],
     # ControlLoRA - diffusers format
     Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()],
     Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()],

diff --git a/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py
index db218d14bb..f3c202268a 100644
--- a/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py
+++ b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py
@@ -13,7 +13,7 @@ from invokeai.backend.util import InvokeAILogger


 def is_state_dict_likely_in_flux_aitoolkit_format(
-    state_dict: dict[str, Any],
+    state_dict: dict[str | int, Any],
     metadata: dict[str, Any] | None = None,
 ) -> bool:
     if metadata:
@@ -23,7 +23,7 @@ def is_state_dict_likely_in_flux_aitoolkit_format(
             return False
         return software.get("name") == "ai-toolkit"
     # metadata got lost somewhere
-    return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
+    return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys() if isinstance(k, str))


 @dataclass
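Note: the signature change from dict[str, Any] to dict[str | int, Any] reflects that torch state dicts can carry non-string keys (the patch does not say where these originate), and unguarded k.split(...) or re.match(...) calls would blow up on them. A minimal illustration of the failure mode and the isinstance guard:

    # Illustrative state dict with a stray integer key.
    state_dict: dict[str | int, object] = {
        "diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight": object(),
        0: object(),
    }

    # Without the filter, (0).split(".", 1) raises AttributeError.
    has_prefix = any(
        k.split(".", 1)[0] == "diffusion_model" for k in state_dict if isinstance(k, str)
    )
    assert has_prefix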
--- a/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py
+++ b/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py
@@ -25,7 +25,9 @@ def is_state_dict_likely_flux_control(state_dict: dict[str | int, Any]) -> bool:
     perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
     """
-    all_keys_match = all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, str(k)) for k in state_dict.keys())
+    all_keys_match = all(
+        re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) for k in state_dict.keys() if isinstance(k, str)
+    )

     # Check the shape of the img_in weight, because this layer shape is modified by FLUX control LoRAs.
     lora_a_weight = state_dict.get("img_in.lora_A.weight", None)

diff --git a/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py
index 188d118cc4..f5b4bc6684 100644
--- a/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py
+++ b/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py
@@ -9,14 +9,16 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_L
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw


-def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
+def is_state_dict_likely_in_flux_diffusers_format(state_dict: dict[str | int, torch.Tensor]) -> bool:
     """Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.

     This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
     (A perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
     """
     # First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
-    all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
+    all_keys_in_peft_format = all(
+        k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys() if isinstance(k, str)
+    )

     # Check if keys use transformer prefix
     transformer_prefix_keys = [
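Note: filtering inside all() changes the semantics slightly: skipped keys count as passing, because all() over a filtered (or empty) generator is vacuously true. A minimal illustration; if non-string keys should disqualify a match, they need an explicit check instead:

    # Vacuously True: both keys are filtered out before the predicate runs.
    assert all(k.endswith("lora_A.weight") for k in [0, 1] if isinstance(k, str))

    # Folding the type check into the predicate makes non-string keys fail the match.
    keys: list[str | int] = [0, "x.lora_A.weight"]
    assert not all(isinstance(k, str) and k.endswith("lora_A.weight") for k in keys)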
diff --git a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py
index 7b5f346896..f5a6830c4f 100644
--- a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py
+++ b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py
@@ -44,7 +44,7 @@ FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self
 FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*"


-def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
+def is_state_dict_likely_in_flux_kohya_format(state_dict: dict[str | int, Any]) -> bool:
     """Checks if the provided state dict is likely in the Kohya FLUX LoRA format.

     This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
@@ -56,6 +56,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
         or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
         or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
         for k in state_dict.keys()
+        if isinstance(k, str)
     )

diff --git a/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py
index 0413f0ef49..88aeee95e4 100644
--- a/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py
+++ b/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py
@@ -40,7 +40,7 @@ FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX = (
 )


-def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -> bool:
+def is_state_dict_likely_in_flux_onetrainer_format(state_dict: dict[str | int, Any]) -> bool:
     """Checks if the provided state dict is likely in the OneTrainer FLUX LoRA format.

     This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
@@ -53,6 +53,7 @@ def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -
         or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
         or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
         for k in state_dict.keys()
+        if isinstance(k, str)
     )
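Note: these detectors are the building blocks behind the _get_flux_lora_format call used by LoRA_Diffusers_Config_Base._get_base_or_raise in config.py. A sketch of how they might be consulted in turn; the real dispatch lives elsewhere in the model manager, and its ordering and return type may differ:

    def guess_flux_lora_format(state_dict, metadata=None):
        # Order is illustrative; each detector aims for high precision on its own format.
        if is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
            return "aitoolkit"
        if is_state_dict_likely_in_flux_onetrainer_format(state_dict):
            return "onetrainer"
        if is_state_dict_likely_in_flux_kohya_format(state_dict):
            return "kohya"
        if is_state_dict_likely_in_flux_diffusers_format(state_dict):
            return "diffusers"
        if is_state_dict_likely_flux_control(state_dict):
            return "control"
        return None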