Add is_state_dict_likely_in_flux_diffusers_format(...) function with unit test.

This commit is contained in:
Ryan Dick
2024-09-09 19:02:15 +00:00
parent 1b406e6d6a
commit 2e8effe83f
2 changed files with 52 additions and 3 deletions

View File

@@ -9,6 +9,27 @@ from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
    """Heuristically determine whether `state_dict` is a FLUX LoRA in the Diffusers format.

    This check is intended to be high-precision, but it is not perfect: a perfect detector
    would validate every key against a whitelist and verify tensor shapes.
    """
    # Spot-check a handful of keys that a Diffusers-format FLUX LoRA must contain.
    required_keys = (
        "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
        "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
        "transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
        "transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
    )
    if not all(key in state_dict for key in required_keys):
        return False
    # Every key must be in PEFT format, i.e. end with "lora_A.weight" or "lora_B.weight".
    peft_suffixes = ("lora_A.weight", "lora_B.weight")
    return all(key.endswith(peft_suffixes) for key in state_dict)
# TODO(ryand): What alpha should we use? 1.0? Rank of the LoRA?
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float = 1.0) -> LoRAModelRaw: # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.

View File

@@ -1,16 +1,44 @@
import torch
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
lora_model_from_flux_diffusers_state_dict,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import state_dict_keys
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
state_dict_keys as flux_kohya_state_dict_keys,
)
def test_is_state_dict_likely_in_flux_diffusers_format_true():
    """Test that is_state_dict_likely_in_flux_diffusers_format() can identify a state dict in the Diffusers FLUX LoRA format."""
    # Build a Diffusers-format FLUX LoRA state dict from the known key list.
    state_dict: dict[str, torch.Tensor] = {key: torch.empty(1) for key in flux_diffusers_state_dict_keys}
    assert is_state_dict_likely_in_flux_diffusers_format(state_dict)
def test_is_state_dict_likely_in_flux_diffusers_format_false():
    """Test that is_state_dict_likely_in_flux_diffusers_format() returns False for a state dict that is in the Kohya
    FLUX LoRA format (i.e. not in the Diffusers format).
    """
    # Construct a state dict in the Kohya FLUX LoRA format; the Diffusers-format detector should reject it.
    state_dict: dict[str, torch.Tensor] = {}
    for k in flux_kohya_state_dict_keys:
        state_dict[k] = torch.empty(1)
    assert not is_state_dict_likely_in_flux_diffusers_format(state_dict)
def test_lora_model_from_flux_diffusers_state_dict():
"""Test that lora_model_from_flux_diffusers_state_dict() can load a state dict in the Diffusers FLUX LoRA format."""
# Construct a state dict that is in the Diffusers FLUX LoRA format.
state_dict: dict[str, torch.Tensor] = {}
for k in state_dict_keys:
for k in flux_diffusers_state_dict_keys:
state_dict[k] = torch.empty(1)
# Load the state dict into a LoRAModelRaw object.
@@ -18,7 +46,7 @@ def test_lora_model_from_flux_diffusers_state_dict():
# Check that the model has the correct number of LoRA layers.
expected_lora_layers: set[str] = set()
for k in state_dict_keys:
for k in flux_diffusers_state_dict_keys:
k = k.replace("lora_A.weight", "")
k = k.replace("lora_B.weight", "")
expected_lora_layers.add(k)