Add a check that all keys are handled in the FLUX Diffusers LoRA loading code.

This commit is contained in:
Ryan Dick
2024-09-09 21:13:04 +00:00
parent 742f6781d5
commit da780c2243
2 changed files with 20 additions and 0 deletions

View File

@@ -192,6 +192,9 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
# Final layer.
add_lora_layer_if_present("proj_out", "final_layer.linear")
# Assert that all keys were processed.
assert len(grouped_state_dict) == 0
return LoRAModelRaw(layers=layers)

View File

@@ -1,3 +1,6 @@
import pytest
import torch
from invokeai.backend.peft.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
lora_model_from_flux_diffusers_state_dict,
@@ -47,3 +50,17 @@ def test_lora_model_from_flux_diffusers_state_dict():
concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
assert len(model.layers) == len(expected_lora_layers)
def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
"""Test that lora_model_from_flux_diffusers_state_dict() raises an error if the input state_dict contains unexpected
keys that we don't handle.
"""
# Construct a state dict that is in the Diffusers FLUX LoRA format.
state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
# Add an unexpected key.
state_dict["transformer.single_transformer_blocks.0.unexpected_key.lora_A.weight"] = torch.empty(1)
# Check that an error is raised.
with pytest.raises(AssertionError):
lora_model_from_flux_diffusers_state_dict(state_dict)