mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add a check that all keys are handled in the FLUX Diffusers LoRA loading code.
This commit is contained in:
@@ -192,6 +192,9 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor

     # Final layer.
     add_lora_layer_if_present("proj_out", "final_layer.linear")

+    # Assert that all keys were processed.
+    assert len(grouped_state_dict) == 0
+
     return LoRAModelRaw(layers=layers)
@@ -1,3 +1,6 @@
+import pytest
 import torch

 from invokeai.backend.peft.conversions.flux_diffusers_lora_conversion_utils import (
     is_state_dict_likely_in_flux_diffusers_format,
     lora_model_from_flux_diffusers_state_dict,
@@ -47,3 +50,17 @@ def test_lora_model_from_flux_diffusers_state_dict():
     concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
     expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
     assert len(model.layers) == len(expected_lora_layers)
+
+
||||
def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
    """lora_model_from_flux_diffusers_state_dict() must raise if the input state_dict contains keys that the
    conversion code does not know how to handle.
    """
    # Start from a valid state dict in the Diffusers FLUX LoRA format...
    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
    # ...then inject a key that the conversion code has no handler for.
    state_dict["transformer.single_transformer_blocks.0.unexpected_key.lora_A.weight"] = torch.empty(1)

    # The leftover key should trip the all-keys-processed assertion in the loader.
    with pytest.raises(AssertionError):
        lora_model_from_flux_diffusers_state_dict(state_dict)
Reference in New Issue
Block a user