Update convert_flux_kohya_state_dict_to_invoke_format() to raise an exception if an unexpected key is encountered, and add a corresponding unit test.

This commit is contained in:
Ryan Dick
2024-09-04 15:34:31 +00:00
committed by Kent Keirsey
parent 04b37e64ea
commit 7b5befad0d
2 changed files with 23 additions and 1 deletions

View File

@@ -19,4 +19,13 @@ def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, torch.T
pattern = r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
replacement = r"\1.\2.\3.\4"
return {re.sub(pattern, replacement, k): v for k, v in state_dict.items()}
converted_dict: dict[str, torch.Tensor] = {}
for k, v in state_dict.items():
match = re.match(pattern, k)
if match:
new_key = re.sub(pattern, replacement, k)
converted_dict[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict

View File

@@ -1,3 +1,4 @@
import pytest
import torch
from invokeai.backend.flux.model import Flux
@@ -37,3 +38,15 @@ def test_convert_flux_kohya_state_dict_to_invoke_format():
break
if not found_match:
raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")
def test_convert_flux_kohya_state_dict_to_invoke_format_error():
"""Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict contains
unexpected keys.
"""
state_dict = {
"unexpected_key.lora_up.weight": torch.empty(1),
}
with pytest.raises(ValueError):
convert_flux_kohya_state_dict_to_invoke_format(state_dict)