mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
tidy(mm): flux lora format util
This commit is contained in:
@@ -540,16 +540,12 @@ class LoRAConfigBase(ABC, BaseModel):
|
||||
)
|
||||
|
||||
|
||||
def get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
|
||||
key = "FLUX_LORA_FORMAT"
|
||||
if key in mod.cache:
|
||||
return mod.cache[key]
|
||||
|
||||
def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
|
||||
# TODO(psyche): Moving this import to the function to avoid circular imports. Refactor later.
|
||||
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
|
||||
|
||||
sd = mod.load_state_dict(mod.path)
|
||||
value = flux_format_from_state_dict(sd, mod.metadata())
|
||||
mod.cache[key] = value
|
||||
state_dict = mod.load_state_dict(mod.path)
|
||||
value = flux_format_from_state_dict(state_dict, mod.metadata())
|
||||
return value
|
||||
|
||||
|
||||
@@ -574,7 +570,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
||||
@classmethod
|
||||
def _validate_is_not_controllora_or_diffusers(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model is a ControlLoRA or Diffusers LoRA."""
|
||||
flux_format = get_flux_lora_format(mod)
|
||||
flux_format = _get_flux_lora_format(mod)
|
||||
if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA")
|
||||
|
||||
@@ -663,13 +659,13 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
@classmethod
|
||||
def _validate_is_not_controllora_or_diffusers(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model is a ControlLoRA or Diffusers LoRA."""
|
||||
flux_format = get_flux_lora_format(mod)
|
||||
flux_format = _get_flux_lora_format(mod)
|
||||
if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA")
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> LoRALyCORIS_SupportedBases:
|
||||
if get_flux_lora_format(mod):
|
||||
if _get_flux_lora_format(mod):
|
||||
return BaseModelType.Flux
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
@@ -752,7 +748,7 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def _validate_looks_like_diffusers_lora(cls, mod: ModelOnDisk) -> None:
|
||||
flux_lora_format = get_flux_lora_format(mod)
|
||||
flux_lora_format = _get_flux_lora_format(mod)
|
||||
if flux_lora_format is not FluxLoRAFormat.Diffusers:
|
||||
raise NotAMatch(cls, "model does not look like a FLUX Diffusers LoRA")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user