From 7b5befad0d206ef58db192e6ecd3d277b7e8475e Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 4 Sep 2024 15:34:31 +0000 Subject: [PATCH] Update convert_flux_kohya_state_dict_to_invoke_format() to raise an exception if an unexpected key is encountered, and add a corresponding unit test. --- .../lora/conversions/flux_lora_conversion_utils.py | 11 ++++++++++- .../conversions/test_flux_lora_conversion_utils.py | 13 +++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/invokeai/backend/lora/conversions/flux_lora_conversion_utils.py b/invokeai/backend/lora/conversions/flux_lora_conversion_utils.py index 18370b0289..dd58dd52eb 100644 --- a/invokeai/backend/lora/conversions/flux_lora_conversion_utils.py +++ b/invokeai/backend/lora/conversions/flux_lora_conversion_utils.py @@ -19,4 +19,13 @@ def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, torch.T pattern = r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)" replacement = r"\1.\2.\3.\4" - return {re.sub(pattern, replacement, k): v for k, v in state_dict.items()} + converted_dict: dict[str, torch.Tensor] = {} + for k, v in state_dict.items(): + match = re.match(pattern, k) + if match: + new_key = re.sub(pattern, replacement, k) + converted_dict[new_key] = v + else: + raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.") + + return converted_dict diff --git a/tests/backend/lora/conversions/test_flux_lora_conversion_utils.py b/tests/backend/lora/conversions/test_flux_lora_conversion_utils.py index b72f1276f2..04fd205e4c 100644 --- a/tests/backend/lora/conversions/test_flux_lora_conversion_utils.py +++ b/tests/backend/lora/conversions/test_flux_lora_conversion_utils.py @@ -1,3 +1,4 @@ +import pytest import torch from invokeai.backend.flux.model import Flux @@ -37,3 +38,15 @@ def test_convert_flux_kohya_state_dict_to_invoke_format(): break if not found_match: raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}") + + +def test_convert_flux_kohya_state_dict_to_invoke_format_error(): + """Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict contains + unexpected keys. + """ + state_dict = { + "unexpected_key.lora_up.weight": torch.empty(1), + } + + with pytest.raises(ValueError): + convert_flux_kohya_state_dict_to_invoke_format(state_dict)