mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Update convert_flux_kohya_state_dict_to_invoke_format() to raise an exception if an unexpected key is encountered, and add a corresponding unit test.
This commit is contained in:
@@ -19,4 +19,13 @@ def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, torch.T
|
||||
pattern = r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
|
||||
replacement = r"\1.\2.\3.\4"
|
||||
|
||||
return {re.sub(pattern, replacement, k): v for k, v in state_dict.items()}
|
||||
converted_dict: dict[str, torch.Tensor] = {}
|
||||
for k, v in state_dict.items():
|
||||
match = re.match(pattern, k)
|
||||
if match:
|
||||
new_key = re.sub(pattern, replacement, k)
|
||||
converted_dict[new_key] = v
|
||||
else:
|
||||
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
|
||||
|
||||
return converted_dict
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
@@ -37,3 +38,15 @@ def test_convert_flux_kohya_state_dict_to_invoke_format():
|
||||
break
|
||||
if not found_match:
|
||||
raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")
|
||||
|
||||
|
||||
def test_convert_flux_kohya_state_dict_to_invoke_format_error():
|
||||
"""Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict contains
|
||||
unexpected keys.
|
||||
"""
|
||||
state_dict = {
|
||||
"unexpected_key.lora_up.weight": torch.empty(1),
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
convert_flux_kohya_state_dict_to_invoke_format(state_dict)
|
||||
|
||||
Reference in New Issue
Block a user