mirror of https://github.com/invoke-ai/InvokeAI.git
Add ability to load FLUX kohya LoRA models that include patches for both the transformer and T5 models.
@@ -1,3 +1,4 @@
+import itertools
 import re
 from typing import Any, Dict, TypeVar
@@ -7,14 +8,20 @@ from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw

-# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
+# A regex pattern that matches all of the transformer keys in the Kohya FLUX LoRA format.
 # Example keys:
 # lora_unet_double_blocks_0_img_attn_proj.alpha
 # lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
 # lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
-FLUX_KOHYA_KEY_REGEX = (
+FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
     r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
 )
+# A regex pattern that matches all of the T5 keys in the Kohya FLUX LoRA format.
+# Example keys:
+# lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha
+# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight
+# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight
+FLUX_KOHYA_T5_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"


 def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
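For reference, a quick illustration (not part of the commit) of what the two patterns capture:

    import re

    m = re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, "lora_unet_double_blocks_0_img_attn_proj.alpha")
    assert m is not None and m.groups() == ("double_blocks", "0", "img_attn", "proj.alpha")

    m = re.match(FLUX_KOHYA_T5_KEY_REGEX, "lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight")
    assert m is not None and m.groups() == ("0", "mlp", "fc1")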
@@ -23,7 +30,9 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
     This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
     perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
     """
-    return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())
+    return all(
+        re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_T5_KEY_REGEX, k) for k in state_dict.keys()
+    )


 def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
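As a quick sanity check (the keys below are illustrative, not from the commit), the updated detector accepts a state dict that mixes transformer and T5 keys while still rejecting unknown keys:

    mixed_sd = {
        "lora_unet_double_blocks_0_img_attn_proj.lora_up.weight": None,
        "lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight": None,
    }
    assert is_state_dict_likely_in_flux_kohya_format(mixed_sd)
    assert not is_state_dict_likely_in_flux_kohya_format({"unexpected_key.lora_up.weight": None})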
@@ -35,12 +44,24 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
             grouped_state_dict[layer_name] = {}
         grouped_state_dict[layer_name][param_name] = value

-    # Convert the state dict to the InvokeAI format.
-    grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
+    # Split the grouped state dict into transformer and T5 state dicts.
+    transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    for layer_name, layer_state_dict in grouped_state_dict.items():
+        if layer_name.startswith("lora_unet"):
+            transformer_grouped_sd[layer_name] = layer_state_dict
+        elif layer_name.startswith("lora_te1"):
+            t5_grouped_sd[layer_name] = layer_state_dict
+        else:
+            raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
+
+    # Convert the state dicts to the InvokeAI format.
+    transformer_grouped_sd = _convert_flux_transformer_kohya_state_dict_to_invoke_format(transformer_grouped_sd)
+    t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd)

     # Create LoRA layers.
     layers: dict[str, AnyLoRALayer] = {}
-    for layer_key, layer_state_dict in grouped_state_dict.items():
+    for layer_key, layer_state_dict in itertools.chain(transformer_grouped_sd.items(), t5_grouped_sd.items()):
         layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)

     # Create and return the LoRAModelRaw.
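A hypothetical end-to-end use of the two public functions (the file name and the safetensors loading step are assumptions, not part of this commit):

    from safetensors.torch import load_file

    sd = load_file("some_flux_kohya_lora.safetensors")  # hypothetical file
    if is_state_dict_likely_in_flux_kohya_format(sd):
        # The resulting model holds both transformer ("lora_unet_*") and T5 ("lora_te1_*") layers.
        lora_model = lora_model_from_flux_kohya_state_dict(sd)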
@@ -50,16 +71,33 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
 T = TypeVar("T")


-def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
-    """Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
+def _convert_flux_t5_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
+    """Converts a T5 LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
+
+    Example key conversions:
+
+    "lora_te1_text_model_encoder_layers_0_mlp_fc1" -> "text_model.encoder.layers.0.mlp.fc1",
+    "lora_te1_text_model_encoder_layers_0_self_attn_k_proj" -> "text_model.encoder.layers.0.self_attn.k_proj"
+    """
+    converted_sd: dict[str, T] = {}
+    for k, v in state_dict.items():
+        match = re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
+        if match:
+            new_key = f"text_model.encoder.layers.{match.group(1)}.{match.group(2)}.{match.group(3)}"
+            converted_sd[new_key] = v
+        else:
+            raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
+
+    return converted_sd
+
+
+def _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
+    """Converts a FLUX transformer LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally
+    by InvokeAI.

     Example key conversions:
     "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
     "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
     """

     def replace_func(match: re.Match[str]) -> str:
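A tiny worked example (placeholder values, not from the commit) of the T5 conversion above:

    t5_sd = {"lora_te1_text_model_encoder_layers_0_self_attn_k_proj": {"alpha": 8.0}}
    converted = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_sd)
    assert list(converted.keys()) == ["text_model.encoder.layers.0.self_attn.k_proj"]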
@@ -70,9 +108,9 @@ def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) ->

     converted_dict: dict[str, T] = {}
     for k, v in state_dict.items():
-        match = re.match(FLUX_KOHYA_KEY_REGEX, k)
+        match = re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
         if match:
-            new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
+            new_key = re.sub(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, replace_func, k)
             converted_dict[new_key] = v
         else:
             raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
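The body of replace_func() is unchanged by this commit and not shown above. Purely as an illustration (an assumption, not the actual implementation), a substitution function that reproduces the documented conversions could look like this:

    import re

    def example_replace(match: re.Match[str]) -> str:
        # Join the captured pieces with dots, dropping the empty trailing group that
        # occurs for keys such as "lora_unet_single_blocks_0_linear1".
        return ".".join(g for g in match.groups() if g)

    assert (
        re.sub(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, example_replace, "lora_unet_double_blocks_0_img_attn_proj")
        == "double_blocks.0.img_attn.proj"
    )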
File diff suppressed because it is too large
@@ -5,7 +5,7 @@ import torch
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.util import params
 from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
-    convert_flux_kohya_state_dict_to_invoke_format,
+    _convert_flux_transformer_kohya_state_dict_to_invoke_format,
     is_state_dict_likely_in_flux_kohya_format,
     lora_model_from_flux_kohya_state_dict,
 )
@@ -15,13 +15,17 @@ from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
     state_dict_keys as flux_kohya_state_dict_keys,
 )
+from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_with_te1_format import (
+    state_dict_keys as flux_kohya_te1_state_dict_keys,
+)
 from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict


-def test_is_state_dict_likely_in_flux_kohya_format_true():
+@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
+def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: list[str]):
     """Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
     # Construct a state dict that is in the Kohya FLUX LoRA format.
-    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
+    state_dict = keys_to_mock_state_dict(sd_keys)

     assert is_state_dict_likely_in_flux_kohya_format(state_dict)
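keys_to_mock_state_dict() comes from the shared test fixtures and its implementation is not part of this diff. A plausible minimal stand-in (an assumption for illustration only) would map every key to a small placeholder tensor:

    import torch

    def keys_to_mock_state_dict_sketch(keys: list[str]) -> dict[str, torch.Tensor]:
        # Shapes are irrelevant for the format-detection test above.
        return {k: torch.empty(1) for k in keys}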
@@ -34,11 +38,11 @@ def test_is_state_dict_likely_in_flux_kohya_format_false():
     assert not is_state_dict_likely_in_flux_kohya_format(state_dict)


-def test_convert_flux_kohya_state_dict_to_invoke_format():
+def test_convert_flux_transformer_kohya_state_dict_to_invoke_format():
     # Construct state_dict from state_dict_keys.
     state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)

-    converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)
+    converted_state_dict = _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict)

     # Extract the prefixes from the converted state dict (i.e. without the .lora_up.weight, .lora_down.weight, and
     # .alpha suffixes).
@@ -65,29 +69,33 @@ def test_convert_flux_kohya_state_dict_to_invoke_format():
         raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")


-def test_convert_flux_kohya_state_dict_to_invoke_format_error():
-    """Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict contains
-    unexpected keys.
+def test_convert_flux_transformer_kohya_state_dict_to_invoke_format_error():
+    """Test that an error is raised by _convert_flux_transformer_kohya_state_dict_to_invoke_format() if the input
+    state_dict contains unexpected keys.
     """
     state_dict = {
         "unexpected_key.lora_up.weight": torch.empty(1),
     }

     with pytest.raises(ValueError):
-        convert_flux_kohya_state_dict_to_invoke_format(state_dict)
+        _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict)


-def test_lora_model_from_flux_kohya_state_dict():
+@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
+def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
     """Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
     # Construct a state dict that is in the Kohya FLUX LoRA format.
-    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
+    state_dict = keys_to_mock_state_dict(sd_keys)

     lora_model = lora_model_from_flux_kohya_state_dict(state_dict)

     # Prepare expected layer keys.
     expected_layer_keys: set[str] = set()
-    for k in flux_kohya_state_dict_keys:
+    for k in sd_keys:
+        # Remove prefixes.
         k = k.replace("lora_unet_", "")
+        k = k.replace("lora_te1_", "")
+        # Remove suffixes.
         k = k.replace(".lora_up.weight", "")
         k = k.replace(".lora_down.weight", "")
         k = k.replace(".alpha", "")