tidy(mm): flux lora format util

This commit is contained in:
psychedelicious
2025-10-01 14:46:38 +10:00
parent a9b88d46e2
commit a7f1cf4c17

View File

@@ -540,16 +540,12 @@ class LoRAConfigBase(ABC, BaseModel):
)
def get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
key = "FLUX_LORA_FORMAT"
if key in mod.cache:
return mod.cache[key]
def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
# TODO(psyche): Moving this import to the function to avoid circular imports. Refactor later.
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
sd = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(sd, mod.metadata())
mod.cache[key] = value
state_dict = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(state_dict, mod.metadata())
return value
@@ -574,7 +570,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
@classmethod
def _validate_is_not_controllora_or_diffusers(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model is a ControlLoRA or Diffusers LoRA."""
flux_format = get_flux_lora_format(mod)
flux_format = _get_flux_lora_format(mod)
if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA")
@@ -663,13 +659,13 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
@classmethod
def _validate_is_not_controllora_or_diffusers(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model is a ControlLoRA or Diffusers LoRA."""
flux_format = get_flux_lora_format(mod)
flux_format = _get_flux_lora_format(mod)
if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> LoRALyCORIS_SupportedBases:
if get_flux_lora_format(mod):
if _get_flux_lora_format(mod):
return BaseModelType.Flux
state_dict = mod.load_state_dict()
@@ -752,7 +748,7 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
@classmethod
def _validate_looks_like_diffusers_lora(cls, mod: ModelOnDisk) -> None:
flux_lora_format = get_flux_lora_format(mod)
flux_lora_format = _get_flux_lora_format(mod)
if flux_lora_format is not FluxLoRAFormat.Diffusers:
raise NotAMatch(cls, "model does not look like a FLUX Diffusers LoRA")