Compare commits


17 Commits

Author SHA1 Message Date
Ryan Dick
67afa7e339 Update PartialLayer to work with unquantized / GGML quantized / BnB quantized layers. 2025-01-24 15:52:57 +00:00
Ryan Dick
92c6a7d658 Add FLUX OneTrainer model probing. 2025-01-23 23:45:22 +00:00
Ryan Dick
caa9ecafae Finish switching LoRAModelRaw to use a list of names/layers rather than a dict. This enables multiple entries with the same name. 2025-01-23 23:35:34 +00:00
Ryan Dick
f21464e972 Relax lora_layers_from_flux_diffusers_grouped_state_dict(...) so that it can work with more LoRA variants (e.g. hada) 2025-01-23 23:11:55 +00:00
Ryan Dick
eeff9d3df5 Fix bug in FLUX T5 Kohya-style LoRA key parsing. 2025-01-23 23:11:55 +00:00
Ryan Dick
d9c1c7d63d Update FLUX invocations to support LoRAs that modify the T5 text encoder. 2025-01-23 23:11:55 +00:00
Ryan Dick
420f6feef9 Fix typo in DoRALayer. 2025-01-23 23:11:55 +00:00
Ryan Dick
8d09a36c90 WIP - use the PartialLayer instead of ConcatenatedLoRALayer when loading diffusers LoRAs. 2025-01-23 23:11:55 +00:00
Ryan Dick
6f82be4dc4 Add PartialLayer for applying patches to a sub-range of a target weight. 2025-01-23 23:11:55 +00:00
Ryan Dick
9dfbd6a422 Add utils for loading FLUX OneTrainer DoRA models. 2025-01-23 23:11:55 +00:00
Ryan Dick
a10db807ca Further updates to lora_model_from_flux_diffusers_state_dict() so that it can be re-used for OneTrainer LoRAs. 2025-01-23 23:11:55 +00:00
Ryan Dick
edc0b63612 Add support for LyCoris-style LoRA keys in lora_model_from_flux_diffusers_state_dict(). Previously, it only supported PEFT-style LoRA keys. 2025-01-23 23:11:55 +00:00
Ryan Dick
d44a6b2ca1 Add utils for working with Kohya LoRA keys. 2025-01-23 23:11:55 +00:00
Ryan Dick
0d2c1b9d8f First draft of DoRALayer. Not tested yet. 2025-01-23 23:11:55 +00:00
Ryan Dick
9952b19c5d Expand unit tests to test for confusion between FLUX LoRA formats. 2025-01-23 23:11:55 +00:00
Ryan Dick
2335b70dba Add is_state_dict_likely_in_flux_onetrainer_format() util function. 2025-01-23 23:11:55 +00:00
Ryan Dick
010383faef Add a test state dict for the OneTrainer DoRA format. 2025-01-23 23:11:55 +00:00
27 changed files with 2951 additions and 129 deletions

View File

@@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
@@ -21,6 +21,9 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
t5_encoder: Optional[T5EncoderField] = OutputField(
default=None, description=FieldDescriptions.t5_encoder, title="T5 Encoder"
)
@invocation(
@@ -28,7 +31,7 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
title="FLUX LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
version="1.2.0",
classification=Classification.Prototype,
)
class FluxLoRALoaderInvocation(BaseInvocation):
@@ -50,6 +53,12 @@ class FluxLoRALoaderInvocation(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField | None = InputField(
default=None,
title="T5 Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
lora_key = self.lora.key
@@ -62,6 +71,8 @@ class FluxLoRALoaderInvocation(BaseInvocation):
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
if self.clip and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to CLIP encoder.')
if self.t5_encoder and any(lora.lora.key == lora_key for lora in self.t5_encoder.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to T5 encoder.')
output = FluxLoRALoaderOutput()
@@ -82,6 +93,14 @@ class FluxLoRALoaderInvocation(BaseInvocation):
weight=self.weight,
)
)
if self.t5_encoder is not None:
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
output.t5_encoder.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
return output
@@ -91,7 +110,7 @@ class FluxLoRALoaderInvocation(BaseInvocation):
title="FLUX LoRA Collection Loader",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
version="1.2.0",
classification=Classification.Prototype,
)
class FLUXLoRACollectionLoader(BaseInvocation):
@@ -113,6 +132,12 @@ class FLUXLoRACollectionLoader(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField | None = InputField(
default=None,
title="T5 Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
output = FluxLoRALoaderOutput()
@@ -140,4 +165,9 @@ class FLUXLoRACollectionLoader(BaseInvocation):
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
if self.t5_encoder is not None:
if output.t5_encoder is None:
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
output.t5_encoder.loras.append(lora)
return output
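The new t5_encoder input/output follows the same copy-then-append pattern already used for the transformer and CLIP fields: deep-copy the incoming field, check for a duplicate LoRA key, then append. A minimal sketch of that pattern, using simplified stand-in dataclasses rather than the real pydantic field types:

```python
from dataclasses import dataclass, field, replace

# Simplified stand-ins for T5EncoderField / LoRAField; the real classes are pydantic models.
@dataclass
class LoRARef:
    key: str
    weight: float

@dataclass
class EncoderField:
    name: str
    loras: list[LoRARef] = field(default_factory=list)

def attach_lora(encoder: EncoderField, lora_key: str, weight: float) -> EncoderField:
    # Refuse to apply the same LoRA twice, mirroring the duplicate-key check above.
    if any(lora.key == lora_key for lora in encoder.loras):
        raise ValueError(f'LoRA "{lora_key}" already applied to {encoder.name}.')
    # Copy the incoming field so the upstream node's output is not mutated.
    out = replace(encoder, loras=list(encoder.loras))
    out.loras.append(LoRARef(key=lora_key, weight=weight))
    return out
```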

View File

@@ -40,7 +40,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
version="1.0.5",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
@@ -87,7 +87,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)

View File

@@ -19,7 +19,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@@ -71,13 +71,34 @@ class FluxTextEncoderInvocation(BaseInvocation):
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
prompt = [self.prompt]
t5_encoder_info = context.models.load(self.t5_encoder.text_encoder)
t5_encoder_config = t5_encoder_info.config
assert t5_encoder_config is not None
with (
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
t5_encoder_info.model_on_device() as (cached_weights, t5_text_encoder),
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
# Apply LoRA models to the T5 encoder.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
if t5_encoder_config.format == ModelFormat.T5Encoder:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=t5_text_encoder,
patches=self._t5_lora_iterator(context),
prefix=FLUX_LORA_T5_PREFIX,
dtype=t5_text_encoder.dtype,
cached_weights=cached_weights,
)
)
else:
raise ValueError(f"Unsupported model format: {t5_encoder_config.format}")
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
context.util.signal_progress("Running T5 encoder")
@@ -132,3 +153,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _t5_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.t5_encoder.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -68,6 +68,7 @@ class CLIPField(BaseModel):
class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class VAEField(BaseModel):

View File

@@ -31,6 +31,10 @@ from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
lora_model_from_flux_onetrainer_state_dict,
)
from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
@@ -84,8 +88,12 @@ class LoRALoader(ModelLoader):
elif config.format == ModelFormat.LyCORIS:
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
elif is_state_dict_likely_flux_control(state_dict=state_dict):
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:

View File

@@ -46,6 +46,9 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
)
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -283,7 +286,7 @@ class ModelProbe(object):
return ModelType.Main
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
return ModelType.VAE
elif key.startswith(("lora_te_", "lora_unet_")):
elif key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
return ModelType.LoRA
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
@@ -632,6 +635,7 @@ class LoRACheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
if (
is_state_dict_likely_in_flux_kohya_format(self.checkpoint)
or is_state_dict_likely_in_flux_onetrainer_format(self.checkpoint)
or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint)
or is_state_dict_likely_flux_control(self.checkpoint)
):

View File

@@ -81,11 +81,11 @@ class LayerPatcher:
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
layer_keys_are_flattened = "." not in next(iter(patch.layers))[0]
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
for layer_key, layer in patch.layers:
if not layer_key.startswith(prefix):
continue
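With `patch.layers` now a sequence of `(key, layer)` tuples instead of a dict, the flattened-key check reads the key from the first tuple. A tiny illustration of the check, with a hypothetical flattened key:

```python
# Illustrative only: a flattened key uses '_' in place of '.' separators.
layers = [("double_blocks_0_img_attn_qkv", None)]
layer_keys_are_flattened = "." not in next(iter(layers))[0]
assert layer_keys_are_flattened
```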

View File

@@ -2,7 +2,6 @@ from typing import Optional, Sequence
import torch
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
@@ -14,7 +13,7 @@ class ConcatenatedLoRALayer(LoRALayerBase):
stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
"""
def __init__(self, lora_layers: Sequence[LoRALayer], concat_axis: int = 0):
def __init__(self, lora_layers: Sequence[LoRALayerBase], concat_axis: int = 0):
super().__init__(alpha=None, bias=None)
self.lora_layers = lora_layers

View File

@@ -0,0 +1,92 @@
from typing import Dict, Optional
import torch
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
class DoRALayer(LoRALayerBase):
"""A DoRA layer. As defined in https://arxiv.org/pdf/2402.09353."""
def __init__(
self,
up: torch.Tensor,
down: torch.Tensor,
dora_scale: torch.Tensor,
alpha: float | None,
bias: Optional[torch.Tensor],
):
super().__init__(alpha, bias)
self.up = up
self.down = down
self.dora_scale = dora_scale
@classmethod
def from_state_dict_values(cls, values: Dict[str, torch.Tensor]):
alpha = cls._parse_alpha(values.get("alpha", None))
bias = cls._parse_bias(
values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
)
layer = cls(
up=values["lora_up.weight"],
down=values["lora_down.weight"],
dora_scale=values["dora_scale"],
alpha=alpha,
bias=bias,
)
cls.warn_on_unhandled_keys(
values=values,
handled_keys={
# Default keys.
"alpha",
"bias_indices",
"bias_values",
"bias_size",
# Layer-specific keys.
"lora_up.weight",
"lora_down.weight",
"dora_scale",
},
)
return layer
def _rank(self) -> int:
return self.down.shape[0]
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
# Note: Variable names (e.g. delta_v) are based on the paper.
delta_v = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
delta_v = delta_v.reshape(orig_weight.shape)
# TODO(ryand): Should alpha be applied to delta_v here rather than the final diff?
# TODO(ryand): I expect this to fail if the original weight is BnB Quantized. This class shouldn't have to worry
# about that, but we should add a clear error message further up the stack.
# At this point, out_weight is the unnormalized direction matrix.
out_weight = orig_weight + delta_v
# TODO(ryand): Simplify this logic.
direction_norm = (
out_weight.transpose(0, 1)
.reshape(out_weight.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(out_weight.shape[1], *[1] * (out_weight.dim() - 1))
.transpose(0, 1)
)
out_weight *= self.dora_scale / direction_norm
return out_weight - orig_weight
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
self.dora_scale = self.dora_scale.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return super().calc_size() + calc_tensors_size([self.up, self.down, self.dora_scale])
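For the common 2D (linear) case, the DoRA merge above amounts to: compute the low-rank update, add it to the original weight to form the unnormalized direction matrix, normalize each column, and rescale by the learned `dora_scale` magnitude. A minimal sketch of that computation on plain tensors (2D weights only, no quantization, and with the alpha scaling omitted, matching the TODO above):

```python
import torch

def dora_weight_delta(
    orig_weight: torch.Tensor,  # [out_features, in_features]
    up: torch.Tensor,           # [out_features, rank]
    down: torch.Tensor,         # [rank, in_features]
    dora_scale: torch.Tensor,   # learned magnitude, broadcastable to [1, in_features]
) -> torch.Tensor:
    """Return the additive weight diff produced by a DoRA patch on a 2D weight."""
    # Low-rank update, as in a plain LoRA layer.
    delta_v = up @ down                                    # [out_features, in_features]
    # Unnormalized direction matrix V = W0 + delta_V.
    direction = orig_weight + delta_v
    # Per-column L2 norm (one norm per input feature, taken over the output dimension).
    direction_norm = direction.norm(dim=0, keepdim=True)   # [1, in_features]
    # Rescale each column to the learned magnitude, then express the result as a diff.
    merged = direction * (dora_scale / direction_norm)
    return merged - orig_weight

# Example shapes: a rank-4 DoRA patch on a 16x8 linear weight.
w = torch.randn(16, 8)
delta = dora_weight_delta(w, torch.randn(16, 4), torch.randn(4, 8), torch.ones(1, 8))
assert delta.shape == w.shape
```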

View File

@@ -0,0 +1,83 @@
from dataclasses import dataclass
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
@dataclass
class Range:
start: int
end: int
class PartialLayer(BaseLayerPatch):
"""A layer patch that only modifies a sub-range of the weights in the original layer.
This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices are
stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
"""
def __init__(self, lora_layer: LoRALayerBase, range: tuple[Range, Range]):
super().__init__()
self.lora_layer = lora_layer
# self.range[i] gives the range to be modified in the original layer for the i'th dimension.
self.range = range
def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
# HACK(ryand): If the original parameters are in a quantized format that can't be sliced, we replace them with
# dummy tensors on the 'meta' device. This allows sub-layers to access the shapes of the sliced parameters. But,
# of course, any sub-layers that need to access the actual values of the parameters will fail.
for param_name in orig_parameters.keys():
param = orig_parameters[param_name]
if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
pass
elif type(param) is GGMLTensor:
pass
else:
orig_parameters[param_name] = torch.empty(get_param_shape(param), device="meta")
# Slice the original parameters to the specified range.
sliced_parameters: dict[str, torch.Tensor] = {}
for param_name, param_weight in orig_parameters.items():
if param_name == "weight":
sliced_parameters[param_name] = param_weight[
self.range[0].start : self.range[0].end, self.range[1].start : self.range[1].end
]
elif param_name == "bias":
sliced_parameters[param_name] = param_weight[self.range[0].start : self.range[0].end]
else:
raise ValueError(f"Unexpected parameter name: {param_name}")
# Apply the LoRA layer to the sliced parameters.
params = self.lora_layer.get_parameters(sliced_parameters, weight)
# Expand the parameters diffs to match the original parameter shape.
out_params: dict[str, torch.Tensor] = {}
for param_name, param_weight in params.items():
orig_param = orig_parameters[param_name]
out_params[param_name] = torch.zeros(
get_param_shape(orig_param), dtype=param_weight.dtype, device=param_weight.device
)
if param_name == "weight":
out_params[param_name][
self.range[0].start : self.range[0].end, self.range[1].start : self.range[1].end
] = param_weight
elif param_name == "bias":
out_params[param_name][self.range[0].start : self.range[0].end] = param_weight
else:
raise ValueError(f"Unexpected parameter name: {param_name}")
return out_params
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.lora_layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return self.lora_layer.calc_size()
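For the typical fused-QKV case, the net effect of `PartialLayer` is: slice the target rows out of the fused weight, let the wrapped LoRA layer compute a diff for that slice, then scatter the diff back into a zero tensor of the full shape. A self-contained sketch of that flow, assuming a plain 2D weight and a simple LoRA-style update (the real class delegates to any `LoRALayerBase`):

```python
import torch

def partial_lora_diff(
    orig_weight: torch.Tensor,   # fused weight, e.g. concatenated Q/K/V: [3*d, d]
    up: torch.Tensor,            # [d, rank] update for one sub-block
    down: torch.Tensor,          # [rank, d]
    row_range: tuple[int, int],  # rows of the fused weight owned by this sub-block
    patch_weight: float = 1.0,
) -> torch.Tensor:
    """Return a full-shape additive diff that only touches the given row range."""
    start, end = row_range
    sliced = orig_weight[start:end, :]      # the sub-block being patched
    delta = patch_weight * (up @ down)      # LoRA diff for the sub-block
    assert delta.shape == sliced.shape
    out = torch.zeros_like(orig_weight)     # expand back to the fused shape
    out[start:end, :] = delta
    return out

# Example: patch only the K block of a fused QKV weight with d = 8.
d, rank = 8, 2
qkv = torch.randn(3 * d, d)
diff = partial_lora_diff(qkv, torch.randn(d, rank), torch.randn(rank, d), (d, 2 * d))
assert torch.count_nonzero(diff[:d]) == 0
assert torch.count_nonzero(diff[2 * d :]) == 0
```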

View File

@@ -3,6 +3,7 @@ from typing import Dict
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.dora_layer import DoRALayer
from invokeai.backend.patches.layers.full_layer import FullLayer
from invokeai.backend.patches.layers.ia3_layer import IA3Layer
from invokeai.backend.patches.layers.loha_layer import LoHALayer
@@ -14,8 +15,9 @@ from invokeai.backend.patches.layers.norm_layer import NormLayer
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
if "lora_up.weight" in state_dict:
if "dora_scale" in state_dict:
return DoRALayer.from_state_dict_values(state_dict)
elif "lora_up.weight" in state_dict:
# LoRA a.k.a LoCon
return LoRALayer.from_state_dict_values(state_dict)
elif "hada_w1_a" in state_dict:

View File

@@ -56,28 +56,38 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
grouped_state_dict[layer_name][param_name] = value
# Create LoRA layers.
layers: dict[str, BaseLayerPatch] = {}
layers: list[tuple[str, BaseLayerPatch]] = []
for layer_key, layer_state_dict in grouped_state_dict.items():
prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
if layer_key == "img_in":
# img_in is a special case because it changes the shape of the original weight.
layers[prefixed_key] = FluxControlLoRALayer(
layer_state_dict["lora_B.weight"],
None,
layer_state_dict["lora_A.weight"],
None,
layer_state_dict["lora_B.bias"],
layers.append(
(
prefixed_key,
FluxControlLoRALayer(
layer_state_dict["lora_B.weight"],
None,
layer_state_dict["lora_A.weight"],
None,
layer_state_dict["lora_B.bias"],
),
)
)
elif all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
layers[prefixed_key] = LoRALayer(
layer_state_dict["lora_B.weight"],
None,
layer_state_dict["lora_A.weight"],
None,
layer_state_dict["lora_B.bias"],
layers.append(
(
prefixed_key,
LoRALayer(
layer_state_dict["lora_B.weight"],
None,
layer_state_dict["lora_A.weight"],
None,
layer_state_dict["lora_B.bias"],
),
)
)
elif "scale" in layer_state_dict:
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
layers.append((prefixed_key, SetParameterLayer("scale", layer_state_dict["scale"])))
else:
raise ValueError(f"{layer_key} not expected")

View File

@@ -3,8 +3,8 @@ from typing import Dict
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.partial_layer import PartialLayer, Range
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -33,13 +33,21 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
def lora_model_from_flux_diffusers_state_dict(
state_dict: Dict[str, torch.Tensor], alpha: float | None
) -> ModelPatchRaw:
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
layers = lora_layers_from_flux_diffusers_grouped_state_dict(grouped_state_dict, alpha)
return ModelPatchRaw(layers=layers)
def lora_layers_from_flux_diffusers_grouped_state_dict(
grouped_state_dict: Dict[str, Dict[str, torch.Tensor]], alpha: float | None
) -> list[tuple[str, BaseLayerPatch]]:
"""Converts a grouped state dict with Diffusers FLUX LoRA keys to LoRA layers with BFL keys (i.e. the module key
format used by Invoke).
This function is based on:
https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
"""
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
# Remove the "transformer." prefix from all keys.
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
@@ -51,19 +59,28 @@ def lora_model_from_flux_diffusers_state_dict(
mlp_ratio = 4.0
mlp_hidden_dim = int(hidden_size * mlp_ratio)
layers: dict[str, BaseLayerPatch] = {}
layers: list[tuple[str, BaseLayerPatch]] = []
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
if src_key in grouped_state_dict:
src_layer_dict = grouped_state_dict.pop(src_key)
value = {
def get_lora_layer_values(src_layer_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
if "lora_A.weight" in src_layer_dict:
# The LoRA keys are in PEFT format.
values = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
value["alpha"] = torch.tensor(alpha)
layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
values["alpha"] = torch.tensor(alpha)
assert len(src_layer_dict) == 0
return values
else:
# Assume that the LoRA keys are in Kohya format.
return src_layer_dict
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
if src_key in grouped_state_dict:
src_layer_dict = grouped_state_dict.pop(src_key)
values = get_lora_layer_values(src_layer_dict)
layers.append((dst_key, any_lora_layer_from_state_dict(values)))
def add_qkv_lora_layer_if_present(
src_keys: list[str],
@@ -79,29 +96,29 @@ def lora_model_from_flux_diffusers_state_dict(
if not any(keys_present):
return
sub_layers: list[LoRALayer] = []
dim_0_offset = 0
for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
src_layer_dict = grouped_state_dict.pop(src_key, None)
if src_layer_dict is not None:
values = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
assert len(src_layer_dict) == 0
values = get_lora_layer_values(src_layer_dict)
# assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
# assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
layers.append(
(
dst_qkv_key,
PartialLayer(
any_lora_layer_from_state_dict(values),
(
Range(dim_0_offset, dim_0_offset + src_weight_shape[0]),
Range(0, src_weight_shape[1]),
),
),
)
)
else:
if not allow_missing_keys:
raise ValueError(f"Missing LoRA layer: '{src_key}'.")
values = {
"lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
"lora_down.weight": torch.zeros((1, src_weight_shape[1])),
}
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers)
dim_0_offset += src_weight_shape[0]
# time_text_embed.timestep_embedder -> time_in.
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
@@ -215,9 +232,9 @@ def lora_model_from_flux_diffusers_state_dict(
# Assert that all keys were processed.
assert len(grouped_state_dict) == 0
layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
layers_with_prefix = [(f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}", v) for k, v in layers]
return ModelPatchRaw(layers=layers_with_prefix)
return layers_with_prefix
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
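Each Q/K/V (or proj_mlp) sub-layer now becomes a `PartialLayer` whose row range is derived by accumulating `dim_0_offset` over the source weight shapes, instead of being concatenated into a `ConcatenatedLoRALayer` as before. A small sketch of that bookkeeping, with hypothetical shapes:

```python
# Hypothetical shapes for a fused qkv weight assembled from separate Q, K, V LoRAs.
hidden_size = 3072
src_weight_shapes = [(hidden_size, hidden_size)] * 3   # Q, K, V

dim_0_offset = 0
ranges: list[tuple[tuple[int, int], tuple[int, int]]] = []
for out_dim, in_dim in src_weight_shapes:
    # Rows [offset, offset + out_dim) of the fused weight belong to this sub-layer;
    # the full input dimension [0, in_dim) is always covered.
    ranges.append(((dim_0_offset, dim_0_offset + out_dim), (0, in_dim)))
    dim_0_offset += out_dim

assert ranges[1] == ((3072, 6144), (0, 3072))   # the K block
```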

View File

@@ -7,6 +7,7 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
FLUX_LORA_CLIP_PREFIX,
FLUX_LORA_T5_PREFIX,
FLUX_LORA_TRANSFORMER_PREFIX,
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -26,6 +27,14 @@ FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight
FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
# A regex pattern that matches all of the T5 keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.alpha
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.dora_scale
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_down.weight
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_up.weight
FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*"
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
@@ -34,7 +43,9 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
return all(
re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
for k in state_dict.keys()
)
@@ -48,27 +59,34 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Split the grouped state dict into transformer and CLIP state dicts.
# Split the grouped state dict into transformer, CLIP, and T5 state dicts.
transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
for layer_name, layer_state_dict in grouped_state_dict.items():
if layer_name.startswith("lora_unet"):
transformer_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te1"):
clip_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te2"):
t5_grouped_sd[layer_name] = layer_state_dict
else:
raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
# Convert the state dicts to the InvokeAI format.
transformer_grouped_sd = _convert_flux_transformer_kohya_state_dict_to_invoke_format(transformer_grouped_sd)
clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd)
# Create LoRA layers.
layers: dict[str, BaseLayerPatch] = {}
for layer_key, layer_state_dict in transformer_grouped_sd.items():
layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
for layer_key, layer_state_dict in clip_grouped_sd.items():
layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
layers: list[tuple[str, BaseLayerPatch]] = []
for model_prefix, grouped_sd in [
(FLUX_LORA_TRANSFORMER_PREFIX, transformer_grouped_sd),
(FLUX_LORA_CLIP_PREFIX, clip_grouped_sd),
(FLUX_LORA_T5_PREFIX, t5_grouped_sd),
]:
for layer_key, layer_state_dict in grouped_sd.items():
layers.append((model_prefix + layer_key, any_lora_layer_from_state_dict(layer_state_dict)))
# Create and return the LoRAModelRaw.
return ModelPatchRaw(layers=layers)
@@ -123,3 +141,31 @@ def _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict: Dict
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict
def _convert_flux_t5_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a T5 LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by
InvokeAI.
Example key conversions:
"lora_te2_encoder_block_0_layer_0_SelfAttention_k" -> "encoder.block.0.layer.0.SelfAttention.k"
"lora_te2_encoder_block_0_layer_1_DenseReluDense_wi_0" -> "encoder.block.0.layer.1.DenseReluDense.wi.0"
"""
def replace_func(match: re.Match[str]) -> str:
s = f"encoder.block.{match.group(1)}.layer.{match.group(2)}.{match.group(3)}.{match.group(4)}"
if match.group(5):
s += f".{match.group(5)}"
return s
converted_dict: dict[str, T] = {}
for k, v in state_dict.items():
match = re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
if match:
new_key = re.sub(FLUX_KOHYA_T5_KEY_REGEX, replace_func, k)
converted_dict[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict
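The T5 key conversion is driven entirely by `FLUX_KOHYA_T5_KEY_REGEX` and `replace_func`. A standalone sketch of the same substitution applied to one grouped module key (the regex is copied from the code above):

```python
import re

FLUX_KOHYA_T5_KEY_REGEX = (
    r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*"
)

def to_invoke_key(kohya_module_key: str) -> str:
    match = re.match(FLUX_KOHYA_T5_KEY_REGEX, kohya_module_key)
    assert match is not None
    s = f"encoder.block.{match.group(1)}.layer.{match.group(2)}.{match.group(3)}.{match.group(4)}"
    if match.group(5):
        s += f".{match.group(5)}"
    return s

assert to_invoke_key("lora_te2_encoder_block_0_layer_0_SelfAttention_k") == (
    "encoder.block.0.layer.0.SelfAttention.k"
)
```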

View File

@@ -1,3 +1,4 @@
# Prefixes used to distinguish between transformer and CLIP text encoder keys in the FLUX InvokeAI LoRA format.
FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
FLUX_LORA_CLIP_PREFIX = "lora_clip-"
FLUX_LORA_T5_PREFIX = "lora_t5-"

View File

@@ -0,0 +1,163 @@
import re
from typing import Any, Dict
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
lora_layers_from_flux_diffusers_grouped_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
FLUX_KOHYA_CLIP_KEY_REGEX,
FLUX_KOHYA_T5_KEY_REGEX,
_convert_flux_clip_kohya_state_dict_to_invoke_format,
_convert_flux_t5_kohya_state_dict_to_invoke_format,
)
from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
FLUX_LORA_CLIP_PREFIX,
FLUX_LORA_T5_PREFIX,
)
from invokeai.backend.patches.lora_conversions.kohya_key_utils import (
INDEX_PLACEHOLDER,
ParsingTree,
insert_periods_into_kohya_key,
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
# A regex pattern that matches all of the transformer keys in the OneTrainer FLUX LoRA format.
# The OneTrainer format uses a mix of the Kohya and Diffusers formats:
# - The base model keys are in Diffusers format.
# - Periods are replaced with underscores, to match Kohya.
# - The LoRA key suffixes (e.g. .alpha, .lora_down.weight, .lora_up.weight) match Kohya.
# Example keys:
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.alpha"
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.dora_scale"
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_down.weight"
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_up.weight"
FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX = (
r"lora_transformer_(single_transformer_blocks|transformer_blocks)_(\d+)_(\w+)\.(.*)"
)
def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the OneTrainer FLUX LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
Note that OneTrainer matches the Kohya format for the CLIP and T5 models.
"""
return all(
re.match(FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX, k)
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
for k in state_dict.keys()
)
def lora_model_from_flux_onetrainer_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw: # type: ignore
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
layer_name, param_name = key.split(".", 1)
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Split the grouped state dict into transformer, CLIP, and T5 state dicts.
transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
for layer_name, layer_state_dict in grouped_state_dict.items():
if layer_name.startswith("lora_transformer"):
transformer_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te1"):
clip_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te2"):
t5_grouped_sd[layer_name] = layer_state_dict
else:
raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
# Convert the state dicts to the InvokeAI format.
clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd)
# Create LoRA layers.
layers: list[tuple[str, BaseLayerPatch]] = []
for model_prefix, grouped_sd in [
# (FLUX_LORA_TRANSFORMER_PREFIX, transformer_grouped_sd),
(FLUX_LORA_CLIP_PREFIX, clip_grouped_sd),
(FLUX_LORA_T5_PREFIX, t5_grouped_sd),
]:
for layer_key, layer_state_dict in grouped_sd.items():
layers.append((model_prefix + layer_key, any_lora_layer_from_state_dict(layer_state_dict)))
# Handle the transformer.
transformer_layers = _convert_flux_transformer_onetrainer_state_dict_to_invoke_format(transformer_grouped_sd)
layers.extend(transformer_layers)
# Create and return the LoRAModelRaw.
return ModelPatchRaw(layers=layers)
# This parsing tree was generated by calling `generate_kohya_parsing_tree_from_keys()` on the keys in
# flux_lora_diffusers_format.py.
flux_transformer_kohya_parsing_tree: ParsingTree = {
"transformer": {
"single_transformer_blocks": {
INDEX_PLACEHOLDER: {
"attn": {"to_k": {}, "to_q": {}, "to_v": {}},
"norm": {"linear": {}},
"proj_mlp": {},
"proj_out": {},
}
},
"transformer_blocks": {
INDEX_PLACEHOLDER: {
"attn": {
"add_k_proj": {},
"add_q_proj": {},
"add_v_proj": {},
"to_add_out": {},
"to_k": {},
"to_out": {INDEX_PLACEHOLDER: {}},
"to_q": {},
"to_v": {},
},
"ff": {"net": {INDEX_PLACEHOLDER: {"proj": {}}}},
"ff_context": {"net": {INDEX_PLACEHOLDER: {"proj": {}}}},
"norm1": {"linear": {}},
"norm1_context": {"linear": {}},
}
},
}
}
def _convert_flux_transformer_onetrainer_state_dict_to_invoke_format(
state_dict: Dict[str, Dict[str, torch.Tensor]],
) -> list[tuple[str, BaseLayerPatch]]:
"""Converts a FLUX transformer LoRA state dict from the OneTrainer FLUX LoRA format to the LoRA weight format used
internally by InvokeAI.
"""
# Step 1: Convert the Kohya-style keys with underscores to classic keys with periods.
# Example:
# "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_down.weight" -> "transformer.single_transformer_blocks.0.attn.to_k.lora_down.weight"
lora_prefix = "lora_"
lora_prefix_length = len(lora_prefix)
kohya_state_dict: dict[str, Dict[str, torch.Tensor]] = {}
for key in state_dict.keys():
# Remove the "lora_" prefix.
assert key.startswith(lora_prefix)
new_key = key[lora_prefix_length:]
# Add periods to the Kohya-style module keys.
new_key = insert_periods_into_kohya_key(new_key, flux_transformer_kohya_parsing_tree)
# Replace the old key with the new key.
kohya_state_dict[new_key] = state_dict[key]
# Step 2: Convert diffusers module names to the BFL module names.
return lora_layers_from_flux_diffusers_grouped_state_dict(kohya_state_dict, alpha=None)

View File

@@ -0,0 +1,102 @@
from typing import Iterable
INDEX_PLACEHOLDER = "index_placeholder"
# Type alias for a 'ParsingTree', which is a recursive dict with string keys.
ParsingTree = dict[str, "ParsingTree"]
def insert_periods_into_kohya_key(key: str, parsing_tree: ParsingTree) -> str:
"""Insert periods into a Kohya key based on a parsing tree.
Kohya format keys are produced by replacing periods with underscores in the original key.
Example:
```
key = "module_a_module_b_0_attn_to_k"
parsing_tree = {
"module_a": {
"module_b": {
INDEX_PLACEHOLDER: {
"attn": {},
},
},
},
}
result = insert_periods_into_kohya_key(key, parsing_tree)
> "module_a.module_b.0.attn.to_k"
```
"""
# Split key into parts by underscore.
parts = key.split("_")
# Build up result by walking through parsing tree and parts.
result_parts: list[str] = []
current_part = ""
current_tree = parsing_tree
for part in parts:
if len(current_part) > 0:
current_part = current_part + "_"
current_part += part
if current_part in current_tree:
# Match found.
current_tree = current_tree[current_part]
result_parts.append(current_part)
current_part = ""
elif current_part.isnumeric() and INDEX_PLACEHOLDER in current_tree:
# Match found with index placeholder.
current_tree = current_tree[INDEX_PLACEHOLDER]
result_parts.append(current_part)
current_part = ""
if len(current_part) > 0:
raise ValueError(f"Key {key} does not match parsing tree {parsing_tree}.")
return ".".join(result_parts)
def generate_kohya_parsing_tree_from_keys(keys: Iterable[str]) -> ParsingTree:
"""Generate a parsing tree from a list of keys.
Example:
```
keys = [
"module_a.module_b.0.attn.to_k",
"module_a.module_b.1.attn.to_k",
"module_a.module_c.proj",
]
tree = generate_kohya_parsing_tree_from_keys(keys)
> {
> "module_a": {
> "module_b": {
> INDEX_PLACEHOLDER: {
> "attn": {
> "to_k": {},
> "to_q": {},
> },
> }
> },
> "module_c": {
> "proj": {},
> }
> }
> }
```
"""
tree: ParsingTree = {}
for key in keys:
subtree: ParsingTree = tree
for module_name in key.split("."):
key = module_name
if module_name.isnumeric():
key = INDEX_PLACEHOLDER
if key not in subtree:
subtree[key] = {}
subtree = subtree[key]
return tree
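A short usage sketch of the two utilities above (assumed to be in scope), using module keys of the same shape as the OneTrainer FLUX transformer keys with the `lora_` prefix already stripped:

```python
diffusers_keys = [
    "transformer.single_transformer_blocks.0.attn.to_k",
    "transformer.transformer_blocks.0.ff.net.0.proj",
]
tree = generate_kohya_parsing_tree_from_keys(diffusers_keys)

flattened = "transformer_single_transformer_blocks_0_attn_to_k"
assert insert_periods_into_kohya_key(flattened, tree) == (
    "transformer.single_transformer_blocks.0.attn.to_k"
)
```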

View File

@@ -10,9 +10,9 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
layers: dict[str, BaseLayerPatch] = {}
layers: list[tuple[str, BaseLayerPatch]] = []
for layer_key, values in grouped_state_dict.items():
layers[layer_key] = any_lora_layer_from_state_dict(values)
layers.append((layer_key, any_lora_layer_from_state_dict(values)))
return ModelPatchRaw(layers=layers)

View File

@@ -1,5 +1,5 @@
# Copyright (c) 2024 The InvokeAI Development team
from typing import Mapping, Optional
from typing import Optional, Sequence
import torch
@@ -8,12 +8,12 @@ from invokeai.backend.raw_model import RawModel
class ModelPatchRaw(RawModel):
def __init__(self, layers: Mapping[str, BaseLayerPatch]):
def __init__(self, layers: Sequence[tuple[str, BaseLayerPatch]]):
self.layers = layers
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
for layer in self.layers.values():
for _, layer in self.layers:
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return sum(layer.calc_size() for layer in self.layers.values())
return sum(layer.calc_size() for _, layer in self.layers)
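`ModelPatchRaw` now takes a sequence of `(key, layer)` tuples rather than a mapping, which allows repeated keys (e.g. two patches targeting the same fused module). A minimal construction sketch, assuming the import paths used in this branch and an illustrative module key:

```python
import torch

from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

def make_layer(rank: int = 4, in_features: int = 8, out_features: int = 16) -> LoRALayer:
    return LoRALayer.from_state_dict_values(
        values={
            "lora_down.weight": torch.zeros((rank, in_features)),
            "lora_up.weight": torch.zeros((out_features, rank)),
        }
    )

# Duplicate keys are now allowed: both entries target the same (illustrative) module.
patch = ModelPatchRaw(
    layers=[
        ("double_blocks.0.img_attn.qkv", make_layer()),
        ("double_blocks.0.img_attn.qkv", make_layer()),
    ]
)
patch.to(dtype=torch.float16)
```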

View File

@@ -54,7 +54,9 @@ GGML_TENSOR_OP_TABLE = {
torch.ops.aten.addmm.default: dequantize_and_run, # pyright: ignore
torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore
torch.ops.aten.add.Tensor: dequantize_and_run, # pyright: ignore
torch.ops.aten.sub.Tensor: dequantize_and_run, # pyright: ignore
torch.ops.aten.allclose.default: dequantize_and_run, # pyright: ignore
torch.ops.aten.slice.Tensor: dequantize_and_run, # pyright: ignore
}
if torch.backends.mps.is_available():

View File

@@ -51,10 +51,8 @@ def test_lora_model_from_flux_control_state_dict(sd_keys: dict[str, list[int]]):
k = k.replace("lora_B.bias", "")
k = k.replace(".scale", "")
expected_lora_layers.add(k)
# Drop the K/V/proj_mlp weights because these are all concatenated into a single layer in the BFL format (we keep
# the Q weights so that we count these layers once).
assert len(model.layers) == len(expected_lora_layers)
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k, _ in model.layers)
def test_lora_model_from_flux_control_state_dict_extra_keys_error():

View File

@@ -6,6 +6,9 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
state_dict_keys as flux_onetrainer_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
@@ -27,12 +30,13 @@ def test_is_state_dict_likely_in_flux_diffusers_format_true(sd_keys: dict[str, l
assert is_state_dict_likely_in_flux_diffusers_format(state_dict)
def test_is_state_dict_likely_in_flux_diffusers_format_false():
@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_onetrainer_state_dict_keys])
def test_is_state_dict_likely_in_flux_diffusers_format_false(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_in_flux_diffusers_format() returns False for a state dict that is in the Kohya
FLUX LoRA format.
"""
# Construct a state dict that is not in the Kohya FLUX LoRA format.
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
state_dict = keys_to_mock_state_dict(sd_keys)
assert not is_state_dict_likely_in_flux_diffusers_format(state_dict)
@@ -51,12 +55,8 @@ def test_lora_model_from_flux_diffusers_state_dict(sd_keys: dict[str, list[int]]
k = k.replace("lora_A.weight", "")
k = k.replace("lora_B.weight", "")
expected_lora_layers.add(k)
# Drop the K/V/proj_mlp weights because these are all concatenated into a single layer in the BFL format (we keep
# the Q weights so that we count these layers once).
concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
assert len(model.layers) == len(expected_lora_layers)
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k, _ in model.layers)
def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():

View File

@@ -13,6 +13,9 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
FLUX_LORA_CLIP_PREFIX,
FLUX_LORA_TRANSFORMER_PREFIX,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
state_dict_keys as flux_onetrainer_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
@@ -34,11 +37,12 @@ def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: dict[str, list[
assert is_state_dict_likely_in_flux_kohya_format(state_dict)
def test_is_state_dict_likely_in_flux_kohya_format_false():
@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_onetrainer_state_dict_keys])
def test_is_state_dict_likely_in_flux_kohya_format_false(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_in_flux_kohya_format() returns False for a state dict that is in the Diffusers
FLUX LoRA format.
"""
state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
state_dict = keys_to_mock_state_dict(sd_keys)
assert not is_state_dict_likely_in_flux_kohya_format(state_dict)
@@ -106,6 +110,6 @@ def test_lora_model_from_flux_kohya_state_dict(sd_keys: dict[str, list[int]]):
expected_layer_keys.add(k)
# Assert that the lora_model has the expected layers.
lora_model_keys = set(lora_model.layers.keys())
lora_model_keys = {k for k, _ in lora_model.layers}
lora_model_keys = {k.replace(".", "_") for k in lora_model_keys}
assert lora_model_keys == expected_layer_keys

View File

@@ -0,0 +1,73 @@
import pytest
from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
FLUX_LORA_CLIP_PREFIX,
FLUX_LORA_T5_PREFIX,
FLUX_LORA_TRANSFORMER_PREFIX,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
lora_model_from_flux_onetrainer_state_dict,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
state_dict_keys as flux_onetrainer_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_kohya_format import (
state_dict_keys as flux_kohya_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_kohya_with_te1_format import (
state_dict_keys as flux_kohya_te1_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.utils import keys_to_mock_state_dict
def test_is_state_dict_likely_in_flux_onetrainer_format_true():
"""Test that is_state_dict_likely_in_flux_onetrainer_format() can identify a state dict in the OneTrainer
FLUX LoRA format.
"""
# Construct a state dict that is in the OneTrainer FLUX LoRA format.
state_dict = keys_to_mock_state_dict(flux_onetrainer_state_dict_keys)
assert is_state_dict_likely_in_flux_onetrainer_format(state_dict)
@pytest.mark.parametrize(
"sd_keys",
[
flux_kohya_state_dict_keys,
flux_kohya_te1_state_dict_keys,
flux_diffusers_state_dict_keys,
],
)
def test_is_state_dict_likely_in_flux_onetrainer_format_false(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_in_flux_onetrainer_format() returns False for a state dict that is in the Diffusers
FLUX LoRA format.
"""
state_dict = keys_to_mock_state_dict(sd_keys)
assert not is_state_dict_likely_in_flux_onetrainer_format(state_dict)
def test_lora_model_from_flux_onetrainer_state_dict():
state_dict = keys_to_mock_state_dict(flux_onetrainer_state_dict_keys)
lora_model = lora_model_from_flux_onetrainer_state_dict(state_dict)
# Check that the model has the correct number of LoRA layers.
expected_lora_layers: set[str] = set()
for k in flux_onetrainer_state_dict_keys:
k = k.replace(".lora_up.weight", "")
k = k.replace(".lora_down.weight", "")
k = k.replace(".alpha", "")
k = k.replace(".dora_scale", "")
expected_lora_layers.add(k)
assert len(lora_model.layers) == len(expected_lora_layers)
# Check that all of the layers have the expected prefix.
assert all(
k.startswith((FLUX_LORA_TRANSFORMER_PREFIX, FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX))
for k, _ in lora_model.layers
)

View File

@@ -0,0 +1,96 @@
import pytest
from invokeai.backend.patches.lora_conversions.kohya_key_utils import (
INDEX_PLACEHOLDER,
ParsingTree,
generate_kohya_parsing_tree_from_keys,
insert_periods_into_kohya_key,
)
def test_insert_periods_into_kohya_key():
"""Test that insert_periods_into_kohya_key() correctly inserts periods into a Kohya key."""
key = "module_a_module_b_0_attn_to_k"
parsing_tree: ParsingTree = {
"module_a": {
"module_b": {
INDEX_PLACEHOLDER: {
"attn": {
"to_k": {},
},
},
},
},
}
result = insert_periods_into_kohya_key(key, parsing_tree)
assert result == "module_a.module_b.0.attn.to_k"
def test_insert_periods_into_kohya_key_invalid_key():
"""Test that insert_periods_into_kohya_key() raises ValueError for a key that is invalid."""
key = "invalid_key_format"
parsing_tree: ParsingTree = {
"module_a": {
"module_b": {
INDEX_PLACEHOLDER: {
"attn": {
"to_k": {},
},
},
},
},
}
with pytest.raises(ValueError):
insert_periods_into_kohya_key(key, parsing_tree)
def test_insert_periods_into_kohya_key_too_long():
"""Test that insert_periods_into_kohya_key() raises ValueError for a key that has a valid prefix, but is too long."""
key = "module_a.module_b.0.attn.to_k.invalid_suffix"
parsing_tree: ParsingTree = {
"module_a": {
"module_b": {
INDEX_PLACEHOLDER: {
"attn": {
"to_k": {},
},
},
},
},
}
with pytest.raises(ValueError):
insert_periods_into_kohya_key(key, parsing_tree)
def test_generate_kohya_parsing_tree_from_keys():
"""Test that generate_kohya_parsing_tree_from_keys() correctly generates a parsing tree."""
keys = [
"module_a.module_b.0.attn.to_k",
"module_a.module_b.1.attn.to_k",
"module_a.module_c.proj",
]
expected_tree: ParsingTree = {
"module_a": {
"module_b": {
INDEX_PLACEHOLDER: {
"attn": {
"to_k": {},
},
}
},
"module_c": {
"proj": {},
},
}
}
tree = generate_kohya_parsing_tree_from_keys(keys)
assert tree == expected_tree
def test_generate_kohya_parsing_tree_from_empty_keys():
"""Test that generate_kohya_parsing_tree_from_keys() handles empty input."""
keys: list[str] = []
tree = generate_kohya_parsing_tree_from_keys(keys)
assert tree == {}

View File

@@ -58,14 +58,21 @@ def test_apply_smart_model_patches(
lora_weight = 0.5
lora_models: list[tuple[ModelPatchRaw, float]] = []
for _ in range(num_loras):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
lora_layers = [
(
"linear_layer_1",
LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones(
(lora_rank, linear_in_features), device="cpu", dtype=torch.float16
),
"lora_up.weight": torch.ones(
(linear_out_features, lora_rank), device="cpu", dtype=torch.float16
),
},
),
)
}
]
lora = ModelPatchRaw(lora_layers)
lora_models.append((lora, lora_weight))
@@ -104,8 +111,8 @@ def test_apply_smart_model_patches(
# After patching, all LoRA layer weights should have been moved back to the cpu.
for lora, _ in lora_models:
assert lora.layers["linear_layer_1"].up.device.type == "cpu"
assert lora.layers["linear_layer_1"].down.device.type == "cpu"
assert lora.layers[0][1].up.device.type == "cpu"
assert lora.layers[0][1].down.device.type == "cpu"
output_during_patch = model(input)
@@ -150,20 +157,34 @@ def test_apply_smart_lora_patches_to_partially_loaded_model(num_loras: int):
lora_weight = 0.5
lora_models: list[tuple[ModelPatchRaw, float]] = []
for _ in range(num_loras):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
lora_layers = [
(
"linear_layer_1",
LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones(
(lora_rank, linear_in_features), device="cpu", dtype=torch.float16
),
"lora_up.weight": torch.ones(
(linear_out_features, lora_rank), device="cpu", dtype=torch.float16
),
},
),
),
"linear_layer_2": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_out_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
(
"linear_layer_2",
LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones(
(lora_rank, linear_out_features), device="cpu", dtype=torch.float16
),
"lora_up.weight": torch.ones(
(linear_out_features, lora_rank), device="cpu", dtype=torch.float16
),
},
),
),
}
]
lora = ModelPatchRaw(lora_layers)
lora_models.append((lora, lora_weight))
@@ -204,14 +225,21 @@ def test_all_patching_methods_produce_same_output(num_loras: int):
lora_weight = 0.5
lora_models: list[tuple[ModelPatchRaw, float]] = []
for _ in range(num_loras):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
lora_layers = [
(
"linear_layer_1",
LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones(
(lora_rank, linear_in_features), device="cpu", dtype=torch.float16
),
"lora_up.weight": torch.ones(
(linear_out_features, lora_rank), device="cpu", dtype=torch.float16
),
},
),
)
}
]
lora = ModelPatchRaw(lora_layers)
lora_models.append((lora, lora_weight))
@@ -249,14 +277,17 @@ def test_apply_smart_model_patches_change_device():
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
apply_custom_layers_to_model(model)
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
},
lora_layers = [
(
"linear_layer_1",
LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
},
),
)
}
]
lora = ModelPatchRaw(lora_layers)
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
@@ -265,8 +296,8 @@ def test_apply_smart_model_patches_change_device():
model=model, patches=[(lora, 0.5)], prefix="", dtype=torch.float16, force_direct_patching=True
):
# After patching, all LoRA layer weights should have been moved back to the cpu.
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
assert lora_layers[0][1].up.device.type == "cpu"
assert lora_layers[0][1].down.device.type == "cpu"
# After patching, the patched model should still be on the CPU.
assert model.linear_layer_1.weight.data.device.type == "cpu"
@@ -292,14 +323,17 @@ def test_apply_smart_model_patches_force_sidecar_and_direct_patching():
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
apply_custom_layers_to_model(model)
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
lora_layers = [
(
"linear_layer_1",
LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
),
)
}
]
lora = ModelPatchRaw(lora_layers)
with pytest.raises(ValueError, match="Cannot force both direct and sidecar patching."):
with LayerPatcher.apply_smart_model_patches(