mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-02 17:45:07 -05:00
Update convert_flux_kohya_state_dict_to_invoke_format() to raise an exception if an unexpected key is encountered, and add a corresponding unit test.
This commit is contained in:
@@ -19,4 +19,13 @@ def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, torch.T
|
||||
pattern = r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
|
||||
replacement = r"\1.\2.\3.\4"
|
||||
|
||||
return {re.sub(pattern, replacement, k): v for k, v in state_dict.items()}
|
||||
converted_dict: dict[str, torch.Tensor] = {}
|
||||
for k, v in state_dict.items():
|
||||
match = re.match(pattern, k)
|
||||
if match:
|
||||
new_key = re.sub(pattern, replacement, k)
|
||||
converted_dict[new_key] = v
|
||||
else:
|
||||
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
|
||||
|
||||
return converted_dict
|
||||
|
||||
Reference in New Issue
Block a user