Update convert_flux_kohya_state_dict_to_invoke_format() to raise an exception if an unexpected key is encountered, and add a corresponding unit test.

This commit is contained in:
Ryan Dick
2024-09-04 15:34:31 +00:00
committed by Kent Keirsey
parent 04b37e64ea
commit 7b5befad0d
2 changed files with 23 additions and 1 deletions

View File

@@ -19,4 +19,13 @@ def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, torch.T
pattern = r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
replacement = r"\1.\2.\3.\4"
return {re.sub(pattern, replacement, k): v for k, v in state_dict.items()}
converted_dict: dict[str, torch.Tensor] = {}
for k, v in state_dict.items():
match = re.match(pattern, k)
if match:
new_key = re.sub(pattern, replacement, k)
converted_dict[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict