mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add is_state_dict_likely_in_flux_diffusers_format(...) function with unit test.
This commit is contained in:
@@ -9,6 +9,27 @@ from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
|
||||
|
||||
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
    """Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.

    This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A
    perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
    """
    # Every key must be a PEFT-style LoRA weight key ("...lora_A.weight" / "...lora_B.weight").
    peft_suffixes = ("lora_A.weight", "lora_B.weight")
    for key in state_dict:
        if not key.endswith(peft_suffixes):
            return False

    # Spot-check a handful of keys that are characteristic of a FLUX transformer in the Diffusers layout.
    spot_check_keys = (
        "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
        "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
        "transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
        "transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
    )
    return all(key in state_dict for key in spot_check_keys)
|
||||
|
||||
|
||||
# TODO(ryand): What alpha should we use? 1.0? Rank of the LoRA?
|
||||
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float = 1.0) -> LoRAModelRaw: # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
|
||||
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
|
||||
|
||||
@@ -1,16 +1,44 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.peft.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
lora_model_from_flux_diffusers_state_dict,
|
||||
)
|
||||
from tests.backend.peft.conversions.lora_state_dicts.flux_lora_diffusers_format import state_dict_keys
|
||||
from tests.backend.peft.conversions.lora_state_dicts.flux_lora_diffusers_format import (
|
||||
state_dict_keys as flux_diffusers_state_dict_keys,
|
||||
)
|
||||
from tests.backend.peft.conversions.lora_state_dicts.flux_lora_kohya_format import (
|
||||
state_dict_keys as flux_kohya_state_dict_keys,
|
||||
)
|
||||
|
||||
|
||||
def test_is_state_dict_likely_in_flux_diffusers_format_true():
    """Test that is_state_dict_likely_in_flux_diffusers_format() can identify a state dict in the Diffusers FLUX LoRA format."""
    # Build a state dict with the full set of Diffusers FLUX LoRA keys; tensor contents are irrelevant to the detector.
    state_dict: dict[str, torch.Tensor] = {key: torch.empty(1) for key in flux_diffusers_state_dict_keys}

    assert is_state_dict_likely_in_flux_diffusers_format(state_dict)
|
||||
|
||||
|
||||
def test_is_state_dict_likely_in_flux_diffusers_format_false():
    """Test that is_state_dict_likely_in_flux_diffusers_format() returns False for a state dict that is in the Kohya
    FLUX LoRA format (and therefore not in the Diffusers FLUX LoRA format).
    """
    # Construct a state dict that is in the Kohya FLUX LoRA format, not the Diffusers format. Tensor contents are
    # irrelevant to the detector, which only inspects keys.
    state_dict: dict[str, torch.Tensor] = {}
    for k in flux_kohya_state_dict_keys:
        state_dict[k] = torch.empty(1)

    assert not is_state_dict_likely_in_flux_diffusers_format(state_dict)
|
||||
|
||||
|
||||
def test_lora_model_from_flux_diffusers_state_dict():
|
||||
"""Test that lora_model_from_flux_diffusers_state_dict() can load a state dict in the Diffusers FLUX LoRA format."""
|
||||
# Construct a state dict that is in the Diffusers FLUX LoRA format.
|
||||
state_dict: dict[str, torch.Tensor] = {}
|
||||
for k in state_dict_keys:
|
||||
for k in flux_diffusers_state_dict_keys:
|
||||
state_dict[k] = torch.empty(1)
|
||||
|
||||
# Load the state dict into a LoRAModelRaw object.
|
||||
@@ -18,7 +46,7 @@ def test_lora_model_from_flux_diffusers_state_dict():
|
||||
|
||||
# Check that the model has the correct number of LoRA layers.
|
||||
expected_lora_layers: set[str] = set()
|
||||
for k in state_dict_keys:
|
||||
for k in flux_diffusers_state_dict_keys:
|
||||
k = k.replace("lora_A.weight", "")
|
||||
k = k.replace("lora_B.weight", "")
|
||||
expected_lora_layers.add(k)
|
||||
|
||||
Reference in New Issue
Block a user