mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add utility function for detecting whether a state_dict is in the FLUX kohya LoRA format.
This commit is contained in:
@@ -1,8 +1,29 @@
|
||||
import re
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
|
||||
# Example keys:
|
||||
# lora_unet_double_blocks_0_img_attn_proj.alpha
|
||||
# lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
|
||||
# lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
|
||||
FLUX_KOHYA_KEY_REGEX = (
|
||||
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
|
||||
)
|
||||
|
||||
|
||||
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
|
||||
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
|
||||
|
||||
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
|
||||
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
|
||||
"""
|
||||
for k in state_dict.keys():
|
||||
if not re.match(FLUX_KOHYA_KEY_REGEX, k):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
|
||||
@@ -16,14 +37,13 @@ def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, torch.T
|
||||
"lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight" -> "double_blocks.0.img.attn.qkv.lora_up.weight"
|
||||
|
||||
"""
|
||||
pattern = r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
|
||||
replacement = r"\1.\2.\3.\4"
|
||||
|
||||
converted_dict: dict[str, torch.Tensor] = {}
|
||||
for k, v in state_dict.items():
|
||||
match = re.match(pattern, k)
|
||||
match = re.match(FLUX_KOHYA_KEY_REGEX, k)
|
||||
if match:
|
||||
new_key = re.sub(pattern, replacement, k)
|
||||
new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replacement, k)
|
||||
converted_dict[new_key] = v
|
||||
else:
|
||||
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
|
||||
|
||||
Reference in New Issue
Block a user