Support FLUX LoRA models in kohya format with lora_te1 layers (i.e. CLIP LoRA layers) (#6967)

## Summary

This PR adds support for FLUX LoRA models in kohya format with `lora_te1`
layers (i.e. CLIP LoRA layers). Previously, only transformer LoRA layers
were supported.

Example LoRA model in this format:
https://huggingface.co/cocktailpeanut/optimus
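
For reference, a kohya-format FLUX LoRA state dict mixes transformer keys
(prefix `lora_unet_`) with CLIP text encoder keys (prefix `lora_te1_`). A
minimal mock of such a state dict, using the key patterns handled in this PR
(tensor values are placeholders, not real weights):

```python
import torch

# Each logical LoRA layer contributes three entries: lora_up.weight,
# lora_down.weight, and alpha.
state_dict = {
    # FLUX transformer layers (previously the only supported kind):
    "lora_unet_double_blocks_0_img_attn_proj.lora_up.weight": torch.empty(1),
    "lora_unet_double_blocks_0_img_attn_proj.lora_down.weight": torch.empty(1),
    "lora_unet_double_blocks_0_img_attn_proj.alpha": torch.empty(1),
    # CLIP text encoder layers (newly supported by this PR):
    "lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight": torch.empty(1),
    "lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight": torch.empty(1),
    "lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha": torch.empty(1),
}
```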

### Example

Prompt: `optimus is playing tennis in a tennis court`
Seed: 0

Without LoRA:

![image](https://github.com/user-attachments/assets/e9b20278-75e1-4940-a0c9-d8630f00b1e5)

With LoRA:

![image](https://github.com/user-attachments/assets/9f2e72e6-0118-4f98-a6d8-30ed59524f69)

## QA Instructions

I tested the following:

- [x] The optimus LoRA (with CLIP layers) can be applied.
- [x] FLUX LoRAs without CLIP layers still work.
- [x] Loading the optimus LoRA but applying it to the transformer
_only_ produces a different result, i.e. I verified that patching the
CLIP layers is doing _something_. Ironically, the results seem better
without applying the CLIP layers; they seem to pull in more background
concepts. Regardless, it works.
- [x] The optimus LoRA can be applied via the Linear UI, and the output
matches results from manually constructing the workflow graph.
- [x] FLUX LoRAs without CLIP layers still work via the Linear UI.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
Authored by Ryan Dick on 2024-09-30 08:28:49 -04:00; committed by GitHub.
9 changed files with 1332 additions and 56 deletions

View File

@@ -30,6 +30,7 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
@@ -208,7 +209,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
LoRAPatcher.apply_lora_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix="",
prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
cached_weights=cached_weights,
)
)
@@ -219,7 +220,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
LoRAPatcher.apply_lora_sidecar_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix="",
prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
dtype=inference_dtype,
)
)
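
For context on the `prefix` change: during conversion (see
`flux_kohya_lora_conversion_utils.py` below), every layer key gets either
`lora_transformer-` or `lora_clip-` prepended, and the patcher uses the prefix
to select only the layers that belong to the model being patched. A rough
sketch of that selection, assuming the patcher strips a matching prefix and
skips everything else (the `LoRAPatcher` internals are not part of this diff):

```python
from typing import Dict

def select_layers_for_model(layers: Dict[str, object], prefix: str) -> Dict[str, object]:
    # Hypothetical helper illustrating the assumed prefix-filtering behavior.
    # With prefix="lora_transformer-", CLIP keys like
    # "lora_clip-text_model.encoder.layers.0.mlp.fc1" are skipped, and
    # "lora_transformer-double_blocks.0.img_attn.proj" becomes
    # "double_blocks.0.img_attn.proj" for lookup in the FLUX transformer.
    return {k.removeprefix(prefix): v for k, v in layers.items() if k.startswith(prefix)}
```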

View File

@@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
@@ -20,6 +20,7 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
transformer: Optional[TransformerField] = OutputField(
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation(
@@ -27,21 +28,28 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
title="FLUX LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.0.0",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer."""
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField = InputField(
transformer: TransformerField | None = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
clip: CLIPField | None = InputField(
default=None,
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
lora_key = self.lora.key
@@ -49,18 +57,33 @@ class FluxLoRALoaderInvocation(BaseInvocation):
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
if any(lora.lora.key == lora_key for lora in self.transformer.loras):
# Check for existing LoRAs with the same key.
if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
if self.clip and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to CLIP encoder.')
transformer = self.transformer.model_copy(deep=True)
transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
output = FluxLoRALoaderOutput()
# Attach LoRA layers to the models.
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
)
return FluxLoRALoaderOutput(transformer=transformer)
return output
@invocation(
@@ -68,7 +91,7 @@ class FluxLoRALoaderInvocation(BaseInvocation):
title="FLUX LoRA Collection Loader",
tags=["lora", "model", "flux"],
category="model",
version="1.0.0",
version="1.1.0",
classification=Classification.Prototype,
)
class FLUXLoRACollectionLoader(BaseInvocation):
@@ -84,6 +107,12 @@ class FLUXLoRACollectionLoader(BaseInvocation):
input=Input.Connection,
title="Transformer",
)
clip: CLIPField | None = InputField(
default=None,
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
output = FluxLoRALoaderOutput()
@@ -106,4 +135,9 @@ class FLUXLoRACollectionLoader(BaseInvocation):
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.loras.append(lora)
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
return output

View File

@@ -1,4 +1,5 @@
from typing import Literal
from contextlib import ExitStack
from typing import Iterator, Literal, Tuple
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
@@ -9,6 +10,10 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@@ -17,7 +22,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
title="FLUX Text Encoding",
tags=["prompt", "conditioning", "flux"],
category="conditioning",
version="1.0.0",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
@@ -78,15 +83,42 @@ class FluxTextEncoderInvocation(BaseInvocation):
prompt = [self.prompt]
with (
clip_text_encoder_info as clip_text_encoder,
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None
# Apply LoRA models to the CLIP encoder.
# Note: We apply the LoRA after the text encoder has been moved to its target device for faster patching.
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context),
prefix=FLUX_KOHYA_CLIP_PREFIX,
cached_weights=cached_weights,
)
)
else:
# There are currently no supported CLIP quantized models. Add support here if needed.
raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return pooled_prompt_embeds
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
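
Note the shape of this block: `apply_lora_patches(...)` is entered via an
`ExitStack`, so the CLIP weights are patched for the duration of encoding and
restored afterwards, with `cached_weights` from `model_on_device()` making
restoration cheap. A minimal sketch of that assumed contract (the patcher
implementation is not shown in this diff):

```python
from contextlib import contextmanager

@contextmanager
def apply_lora_patches_sketch(model, patches, prefix, cached_weights):
    # Hypothetical illustration of the context-manager contract assumed above.
    try:
        for lora_model, weight in patches:  # a lazy iterator; models load on demand
            ...  # merge the prefix-matching LoRA layers into the model's weights
        yield model
    finally:
        # Restore the original (unpatched) weights on exit.
        model.load_state_dict(cached_weights, strict=False)
```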

View File

@@ -7,14 +7,25 @@ from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
# A regex pattern that matches all of the transformer keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_unet_double_blocks_0_img_attn_proj.alpha
# lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
# lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
FLUX_KOHYA_KEY_REGEX = (
FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)
# A regex pattern that matches all of the CLIP keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha
# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight
# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight
FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
# Prefixes used to distinguish between transformer and CLIP text encoder keys in the InvokeAI LoRA format.
FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-"
FLUX_KOHYA_CLIP_PREFIX = "lora_clip-"
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
@@ -23,7 +34,10 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())
return all(
re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
for k in state_dict.keys()
)
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
@@ -35,13 +49,27 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Convert the state dict to the InvokeAI format.
grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
# Split the grouped state dict into transformer and CLIP state dicts.
transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
for layer_name, layer_state_dict in grouped_state_dict.items():
if layer_name.startswith("lora_unet"):
transformer_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te1"):
clip_grouped_sd[layer_name] = layer_state_dict
else:
raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
# Convert the state dicts to the InvokeAI format.
transformer_grouped_sd = _convert_flux_transformer_kohya_state_dict_to_invoke_format(transformer_grouped_sd)
clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
# Create LoRA layers.
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
for layer_key, layer_state_dict in transformer_grouped_sd.items():
layers[FLUX_KOHYA_TRANFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
for layer_key, layer_state_dict in clip_grouped_sd.items():
layers[FLUX_KOHYA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)
@@ -50,16 +78,34 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
T = TypeVar("T")
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
def _convert_flux_clip_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a CLIP LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by
InvokeAI.
Example key conversions:
"lora_te1_text_model_encoder_layers_0_mlp_fc1" -> "text_model.encoder.layers.0.mlp.fc1",
"lora_te1_text_model_encoder_layers_0_self_attn_k_proj" -> "text_model.encoder.layers.0.self_attn.k_proj"
"""
converted_sd: dict[str, T] = {}
for k, v in state_dict.items():
match = re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
if match:
new_key = f"text_model.encoder.layers.{match.group(1)}.{match.group(2)}.{match.group(3)}"
converted_sd[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_sd
def _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a FLUX tranformer LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally
by InvokeAI.
Example key conversions:
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"""
def replace_func(match: re.Match[str]) -> str:
@@ -70,9 +116,9 @@ def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) ->
converted_dict: dict[str, T] = {}
for k, v in state_dict.items():
match = re.match(FLUX_KOHYA_KEY_REGEX, k)
match = re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
if match:
new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
new_key = re.sub(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, replace_func, k)
converted_dict[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
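
A quick sanity check of the two key regexes above, runnable as-is (the diff
elides `replace_func`, so the transformer conversion is reproduced here by
joining the match groups, which agrees with the docstring examples):

```python
import re

FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
    r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)
FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"

# Transformer key: "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
m = re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, "lora_unet_double_blocks_0_img_attn_proj")
assert m is not None
assert ".".join(g for g in m.groups() if g) == "double_blocks.0.img_attn.proj"

# CLIP key: "lora_te1_text_model_encoder_layers_0_mlp_fc1" -> "text_model.encoder.layers.0.mlp.fc1"
m = re.match(FLUX_KOHYA_CLIP_KEY_REGEX, "lora_te1_text_model_encoder_layers_0_mlp_fc1")
assert m is not None
assert f"text_model.encoder.layers.{m.group(1)}.{m.group(2)}.{m.group(3)}" == "text_model.encoder.layers.0.mlp.fc1"
```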

View File

@@ -8,7 +8,8 @@ export const addFLUXLoRAs = (
state: RootState,
g: Graph,
denoise: Invocation<'flux_denoise'>,
modelLoader: Invocation<'flux_model_loader'>
modelLoader: Invocation<'flux_model_loader'>,
fluxTextEncoder: Invocation<'flux_text_encoder'>
): void => {
const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'flux');
const loraCount = enabledLoRAs.length;
@@ -20,7 +21,7 @@ export const addFLUXLoRAs = (
const loraMetadata: S['LoRAMetadataField'][] = [];
// We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
// each LoRA to the UNet and CLIP.
// each LoRA to the transformer and text encoders.
const loraCollector = g.addNode({
id: getPrefixedId('lora_collector'),
type: 'collect',
@@ -33,10 +34,12 @@ export const addFLUXLoRAs = (
g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
// Use model loader as transformer input
g.addEdge(modelLoader, 'transformer', loraCollectionLoader, 'transformer');
// Reroute transformer connections through the LoRA collection loader
g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip');
// Reroute model connections through the LoRA collection loader
g.deleteEdgesTo(denoise, ['transformer']);
g.deleteEdgesTo(fluxTextEncoder, ['clip']);
g.addEdge(loraCollectionLoader, 'transformer', denoise, 'transformer');
g.addEdge(loraCollectionLoader, 'clip', fluxTextEncoder, 'clip');
for (const lora of enabledLoRAs) {
const { weight } = lora;

View File

@@ -6,6 +6,7 @@ import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSe
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
@@ -18,8 +19,6 @@ import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { addFLUXLoRAs } from './addFLUXLoRAs';
const log = logger('system');
export const buildFLUXGraph = async (
@@ -96,12 +95,12 @@ export const buildFLUXGraph = async (
g.addEdge(modelLoader, 'transformer', noise, 'transformer');
g.addEdge(modelLoader, 'vae', l2i, 'vae');
addFLUXLoRAs(state, g, noise, modelLoader);
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
addFLUXLoRAs(state, g, noise, modelLoader, posCond);
g.addEdge(posCond, 'conditioning', noise, 'positive_text_conditioning');
g.addEdge(noise, 'latents', l2i, 'latents');

View File

@@ -5707,6 +5707,12 @@ export type components = {
* @default null
*/
transformer?: components["schemas"]["TransformerField"] | null;
/**
* CLIP
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
* @default null
*/
clip?: components["schemas"]["CLIPField"] | null;
/**
* type
* @default flux_lora_collection_loader
@@ -6391,7 +6397,7 @@ export type components = {
};
/**
* FLUX LoRA
* @description Apply a LoRA model to a FLUX transformer.
* @description Apply a LoRA model to a FLUX transformer and/or text encoder.
*/
FluxLoRALoaderInvocation: {
/**
@@ -6428,7 +6434,13 @@ export type components = {
* @description Transformer
* @default null
*/
transformer?: components["schemas"]["TransformerField"];
transformer?: components["schemas"]["TransformerField"] | null;
/**
* CLIP
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
* @default null
*/
clip?: components["schemas"]["CLIPField"] | null;
/**
* type
* @default flux_lora_loader
@@ -6448,6 +6460,12 @@ export type components = {
* @default null
*/
transformer: components["schemas"]["TransformerField"] | null;
/**
* CLIP
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
* @default null
*/
clip: components["schemas"]["CLIPField"] | null;
/**
* type
* @default flux_lora_loader_output

View File

@@ -5,7 +5,9 @@ import torch
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
convert_flux_kohya_state_dict_to_invoke_format,
FLUX_KOHYA_CLIP_PREFIX,
FLUX_KOHYA_TRANFORMER_PREFIX,
_convert_flux_transformer_kohya_state_dict_to_invoke_format,
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
@@ -15,13 +17,17 @@ from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
state_dict_keys as flux_kohya_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_with_te1_format import (
state_dict_keys as flux_kohya_te1_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict
def test_is_state_dict_likely_in_flux_kohya_format_true():
@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: list[str]):
"""Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
# Construct a state dict that is in the Kohya FLUX LoRA format.
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
state_dict = keys_to_mock_state_dict(sd_keys)
assert is_state_dict_likely_in_flux_kohya_format(state_dict)
@@ -34,11 +40,11 @@ def test_is_state_dict_likely_in_flux_kohya_format_false():
assert not is_state_dict_likely_in_flux_kohya_format(state_dict)
def test_convert_flux_kohya_state_dict_to_invoke_format():
def test_convert_flux_transformer_kohya_state_dict_to_invoke_format():
# Construct state_dict from state_dict_keys.
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)
converted_state_dict = _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict)
# Extract the prefixes from the converted state dict (i.e. without the .lora_up.weight, .lora_down.weight, and
# .alpha suffixes).
@@ -65,29 +71,33 @@ def test_convert_flux_kohya_state_dict_to_invoke_format():
raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")
def test_convert_flux_kohya_state_dict_to_invoke_format_error():
"""Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict contains
unexpected keys.
def test_convert_flux_transformer_kohya_state_dict_to_invoke_format_error():
"""Test that an error is raised by _convert_flux_transformer_kohya_state_dict_to_invoke_format() if the input
state_dict contains unexpected keys.
"""
state_dict = {
"unexpected_key.lora_up.weight": torch.empty(1),
}
with pytest.raises(ValueError):
convert_flux_kohya_state_dict_to_invoke_format(state_dict)
_convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict)
def test_lora_model_from_flux_kohya_state_dict():
@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
"""Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
# Construct a state dict that is in the Kohya FLUX LoRA format.
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
state_dict = keys_to_mock_state_dict(sd_keys)
lora_model = lora_model_from_flux_kohya_state_dict(state_dict)
# Prepare expected layer keys.
expected_layer_keys: set[str] = set()
for k in flux_kohya_state_dict_keys:
k = k.replace("lora_unet_", "")
for k in sd_keys:
# Replace prefixes.
k = k.replace("lora_unet_", FLUX_KOHYA_TRANFORMER_PREFIX)
k = k.replace("lora_te1_", FLUX_KOHYA_CLIP_PREFIX)
# Remove suffixes.
k = k.replace(".lora_up.weight", "")
k = k.replace(".lora_down.weight", "")
k = k.replace(".alpha", "")