Add ability to load FLUX kohya LoRA models that include patches for both the transformer and T5 models.

This commit is contained in:
Ryan Dick
2024-09-26 19:41:56 +00:00
committed by Kent Keirsey
parent 21017edcde
commit d332d81866
3 changed files with 1205 additions and 26 deletions

View File

@@ -1,3 +1,4 @@
import itertools
import re
from typing import Any, Dict, TypeVar
@@ -7,14 +8,20 @@ from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
# A regex pattern that matches all of the transformer keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_unet_double_blocks_0_img_attn_proj.alpha
# lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
# lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
FLUX_KOHYA_KEY_REGEX = (
FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)
# A regex pattern that matches all of the T5 keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha
# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight
# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight
FLUX_KOHYA_T5_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
@@ -23,7 +30,9 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())
return all(
re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_T5_KEY_REGEX, k) for k in state_dict.keys()
)
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
@@ -35,12 +44,24 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Convert the state dict to the InvokeAI format.
grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
# Split the grouped state dict into transformer and T5 state dicts.
transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
for layer_name, layer_state_dict in grouped_state_dict.items():
if layer_name.startswith("lora_unet"):
transformer_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te1"):
t5_grouped_sd[layer_name] = layer_state_dict
else:
raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
# Convert the state dicts to the InvokeAI format.
transformer_grouped_sd = _convert_flux_transformer_kohya_state_dict_to_invoke_format(transformer_grouped_sd)
t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd)
# Create LoRA layers.
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
for layer_key, layer_state_dict in itertools.chain(transformer_grouped_sd.items(), t5_grouped_sd.items()):
layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
# Create and return the LoRAModelRaw.
@@ -50,16 +71,33 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
T = TypeVar("T")
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
def _convert_flux_t5_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a T5 LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
Example key conversions:
"lora_te1_text_model_encoder_layers_0_mlp_fc1" -> "text_model.encoder.layers.0.mlp.fc1",
"lora_te1_text_model_encoder_layers_0_self_attn_k_proj" -> "text_model.encoder.layers.0.self_attn.k_proj"
"""
converted_sd: dict[str, T] = {}
for k, v in state_dict.items():
match = re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
if match:
new_key = f"text_model.encoder.layers.{match.group(1)}.{match.group(2)}.{match.group(3)}"
converted_sd[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_sd
def _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a FLUX tranformer LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally
by InvokeAI.
Example key conversions:
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"""
def replace_func(match: re.Match[str]) -> str:
@@ -70,9 +108,9 @@ def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) ->
converted_dict: dict[str, T] = {}
for k, v in state_dict.items():
match = re.match(FLUX_KOHYA_KEY_REGEX, k)
match = re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
if match:
new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
new_key = re.sub(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, replace_func, k)
converted_dict[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")

View File

@@ -5,7 +5,7 @@ import torch
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
convert_flux_kohya_state_dict_to_invoke_format,
_convert_flux_transformer_kohya_state_dict_to_invoke_format,
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
@@ -15,13 +15,17 @@ from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
state_dict_keys as flux_kohya_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_with_te1_format import (
state_dict_keys as flux_kohya_te1_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict
def test_is_state_dict_likely_in_flux_kohya_format_true():
@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: list[str]):
"""Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
# Construct a state dict that is in the Kohya FLUX LoRA format.
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
state_dict = keys_to_mock_state_dict(sd_keys)
assert is_state_dict_likely_in_flux_kohya_format(state_dict)
@@ -34,11 +38,11 @@ def test_is_state_dict_likely_in_flux_kohya_format_false():
assert not is_state_dict_likely_in_flux_kohya_format(state_dict)
def test_convert_flux_kohya_state_dict_to_invoke_format():
def test_convert_flux_transformer_kohya_state_dict_to_invoke_format():
# Construct state_dict from state_dict_keys.
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)
converted_state_dict = _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict)
# Extract the prefixes from the converted state dict (i.e. without the .lora_up.weight, .lora_down.weight, and
# .alpha suffixes).
@@ -65,29 +69,33 @@ def test_convert_flux_kohya_state_dict_to_invoke_format():
raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")
def test_convert_flux_kohya_state_dict_to_invoke_format_error():
"""Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict contains
unexpected keys.
def test_convert_flux_transformer_kohya_state_dict_to_invoke_format_error():
"""Test that an error is raised by _convert_flux_transformer_kohya_state_dict_to_invoke_format() if the input
state_dict contains unexpected keys.
"""
state_dict = {
"unexpected_key.lora_up.weight": torch.empty(1),
}
with pytest.raises(ValueError):
convert_flux_kohya_state_dict_to_invoke_format(state_dict)
_convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict)
def test_lora_model_from_flux_kohya_state_dict():
@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
"""Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
# Construct a state dict that is in the Kohya FLUX LoRA format.
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
state_dict = keys_to_mock_state_dict(sd_keys)
lora_model = lora_model_from_flux_kohya_state_dict(state_dict)
# Prepare expected layer keys.
expected_layer_keys: set[str] = set()
for k in flux_kohya_state_dict_keys:
for k in sd_keys:
# Remove prefixes.
k = k.replace("lora_unet_", "")
k = k.replace("lora_te1_", "")
# Remove suffixes.
k = k.replace(".lora_up.weight", "")
k = k.replace(".lora_down.weight", "")
k = k.replace(".alpha", "")