FLUX LoRA Support (#6847)

## Summary

This PR adds support for FLUX LoRA models on both quantized and
non-quantized base models.

Supported formats:
- diffusers
- kohya

Full changelist:
- Consolidated LoRA handling code in `invokeai/backend/lora`
- Added support for loading FLUX kohya and FLUX diffusers LoRA models
- Added the ability to either patch LoRA weights directly into the base
model or run them as a sidecar model (the latter enables LoRAs to be
applied to a wide range of quantized models). A condensed sketch of the
two modes follows this list.
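
For reviewers unfamiliar with the two application modes, here is a rough sketch of how the new FLUX denoise invocation chooses between them (compressed from `invocations/flux_denoise.py` in this diff; `model_is_quantized`, `lora_iterator`, and `run_denoise` are placeholder names, not identifiers from this PR):

```python
# Sketch only: the real call sites thread these context managers through an
# ExitStack and a model-on-device context manager.
from invokeai.backend.lora.lora_patcher import LoRAPatcher

if model_is_quantized:
    # Sidecar mode: wrap the target layers instead of mutating quantized
    # weights. Slower per step, but agnostic to the quantization format.
    patch_ctx = LoRAPatcher.apply_lora_sidecar_patches(
        model=transformer,
        patches=lora_iterator,  # yields (LoRAModelRaw, weight) tuples
        prefix="",
        dtype=inference_dtype,
    )
else:
    # Fused mode: add the LoRA deltas directly into the unquantized weights.
    patch_ctx = LoRAPatcher.apply_lora_patches(
        model=transformer,
        patches=lora_iterator,
        prefix="",
        cached_weights=cached_weights,
    )

with patch_ctx:
    run_denoise(transformer)
```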

## QA Instructions

Note to reviewers: I tested everything in this checklist. Feel free to
re-verify any of this, but also test any LoRAs that you have. There are
many small LoRA format variations, and there's a risk of breaking one of
them with this change.

FLUX LoRA
- [x] Import / probe of kohya FLUX LoRA
(https://civitai.com/models/159333/pokemon-trainer-sprite-pixelart?modelVersionId=779247)
- [x] Import / probe of Diffusers FLUX LoRA
(https://civitai.com/models/200255/hands-xl-sd-15-flux1-dev?modelVersionId=781855)
- [x] kohya with non-quantized base model
- [x] kohya with quantized base model (should roughly match the
non-quantized case)
- [x] diffusers with non-quantized base model
- [x] diffusers with quantized base model (should roughly match the
non-quantized case)
- [x] Sidecar LoRA patching speed (<0.1 sec after the model is loaded)
- [x] Stacking multiple fused LoRA models (i.e. on top of a non-quantized
model)
- [x] Stacking multiple sidecar LoRA models (i.e. on top of quantized
model)

Regression Tests
- [x] SD1.5 LoRA (check output, speed and memory)
- [x] SDXL LoRA (check output, speed and memory)
- [x] `USE_MODULAR_DENOISE=1` smoke test with LoRA

Test for output regression with the following LoRA formats (the key-based
detection that routes these formats is sketched after this list):
  - [x] LoRA
  - [x] LoHA
  - [x] LoKr
  - [x] IA3
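
For reference, all of these formats are routed through key-based detection (see `layers/utils.py` in this diff). A condensed sketch of the dispatch, returning format names instead of layer objects:

```python
# Condensed from any_lora_layer_from_state_dict() in this PR; illustrative only.
def detect_layer_format(state_dict_keys: set[str]) -> str:
    if "lora_up.weight" in state_dict_keys:
        return "LoRA"  # a.k.a. LoCon
    elif "hada_w1_a" in state_dict_keys:
        return "LoHA"
    elif "lokr_w1" in state_dict_keys or "lokr_w1_a" in state_dict_keys:
        return "LoKr"
    elif "diff" in state_dict_keys:
        return "Full"  # a.k.a. Diff
    elif "on_input" in state_dict_keys:
        return "IA3"
    elif "w_norm" in state_dict_keys:
        return "Norm"
    raise ValueError("Unsupported LoRA format")
```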

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
Committed by Ryan Dick on 2024-09-18 13:49:54 -04:00 (committed via GitHub).
46 changed files with 3822 additions and 852 deletions.

View File

@@ -20,6 +20,7 @@ from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -81,9 +82,10 @@ class CompelInvocation(BaseInvocation):
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_text_encoder(
-                text_encoder,
-                loras=_lora_loader(),
+            LoRAPatcher.apply_lora_patches(
+                model=text_encoder,
+                patches=_lora_loader(),
+                prefix="lora_te_",
                 cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -176,9 +178,9 @@ class SDXLPromptInvocationBase:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora(
+            LoRAPatcher.apply_lora_patches(
                 text_encoder,
-                loras=_lora_loader(),
+                patches=_lora_loader(),
                 prefix=lora_prefix,
                 cached_weights=cached_weights,
             ),

View File

@@ -37,6 +37,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState
@@ -979,9 +980,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(
-                unet,
-                loras=_lora_loader(),
+            LoRAPatcher.apply_lora_patches(
+                model=unet,
+                patches=_lora_loader(),
+                prefix="lora_unet_",
                 cached_weights=cached_weights,
             ),
         ):
):

View File

@@ -1,4 +1,5 @@
-from typing import Callable, Optional
+from contextlib import ExitStack
+from typing import Callable, Iterator, Optional, Tuple

 import torch
 import torchvision.transforms as tv_transforms
@@ -29,6 +30,9 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
+from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
@@ -187,9 +191,41 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             noise=noise,
         )

-        with transformer_info as transformer:
+        with (
+            transformer_info.model_on_device() as (cached_weights, transformer),
+            ExitStack() as exit_stack,
+        ):
             assert isinstance(transformer, Flux)
+            config = transformer_info.config
+            assert config is not None
+
+            # Apply LoRA models to the transformer.
+            # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+            if config.format in [ModelFormat.Checkpoint]:
+                # The model is non-quantized, so we can apply the LoRA weights directly into the model.
+                exit_stack.enter_context(
+                    LoRAPatcher.apply_lora_patches(
+                        model=transformer,
+                        patches=self._lora_iterator(context),
+                        prefix="",
+                        cached_weights=cached_weights,
+                    )
+                )
+            elif config.format in [ModelFormat.BnbQuantizedLlmInt8b, ModelFormat.BnbQuantizednf4b]:
+                # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower
+                # inference than directly patching the weights, but is agnostic to the quantization format.
+                exit_stack.enter_context(
+                    LoRAPatcher.apply_lora_sidecar_patches(
+                        model=transformer,
+                        patches=self._lora_iterator(context),
+                        prefix="",
+                        dtype=inference_dtype,
+                    )
+                )
+            else:
+                raise ValueError(f"Unsupported model format: {config.format}")
+
             x = denoise(
                 model=transformer,
                 img=x,
@@ -247,6 +283,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         # `latents`.
         return mask.expand_as(latents)

+    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
+        for lora in self.transformer.loras:
+            lora_info = context.models.load(lora.lora)
+            assert isinstance(lora_info.model, LoRAModelRaw)
+            yield (lora_info.model, lora.weight)
+            del lora_info
+
     def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
         def step_callback(state: PipelineIntermediateState) -> None:
             state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()

View File

@@ -0,0 +1,53 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_lora_loader_output")
class FluxLoRALoaderOutput(BaseInvocationOutput):
"""FLUX LoRA Loader Output"""
transformer: TransformerField = OutputField(
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
@invocation(
"flux_lora_loader",
title="FLUX LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.0.0",
)
class FluxLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField = InputField(
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
if any(lora.lora.key == lora_key for lora in self.transformer.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
transformer = self.transformer.model_copy(deep=True)
transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
return FluxLoRALoaderOutput(transformer=transformer)

View File

@@ -69,6 +69,7 @@ class CLIPField(BaseModel):
 class TransformerField(BaseModel):
     transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
+    loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


 class T5EncoderField(BaseModel):
@@ -202,7 +203,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
         assert isinstance(transformer_config, CheckpointConfigBase)

         return FluxModelLoaderOutput(
-            transformer=TransformerField(transformer=transformer),
+            transformer=TransformerField(transformer=transformer, loras=[]),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
             t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
             vae=VAEField(vae=vae),

View File

@@ -23,7 +23,7 @@ from invokeai.app.invocations.model import UNetField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
 from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
     MultiDiffusionPipeline,
@@ -204,7 +204,11 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         # Load the UNet model.
         unet_info = context.models.load(self.unet.unet)

-        with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
+        with (
+            ExitStack() as exit_stack,
+            unet_info as unet,
+            LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+        ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
             if noise is not None:

View File

@@ -0,0 +1,206 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
"""Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
# First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
# Next, check that this is likely a FLUX model by spot-checking a few keys.
expected_keys = [
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
]
all_expected_keys_present = all(k in state_dict for k in expected_keys)
return all_keys_in_peft_format and all_expected_keys_present
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> LoRAModelRaw:
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
This function is based on:
https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
"""
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
# Remove the "transformer." prefix from all keys.
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
# Constants for FLUX.1
num_double_layers = 19
num_single_layers = 38
# inner_dim = 3072
# mlp_ratio = 4.0
layers: dict[str, AnyLoRALayer] = {}
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
if src_key in grouped_state_dict:
src_layer_dict = grouped_state_dict.pop(src_key)
value = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
value["alpha"] = torch.tensor(alpha)
layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
assert len(src_layer_dict) == 0
def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
"""Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
"""
# We expect that either all src keys are present or none of them are. Verify this.
keys_present = [key in grouped_state_dict for key in src_keys]
assert all(keys_present) or not any(keys_present)
# If none of the keys are present, return early.
if not any(keys_present):
return
src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
sub_layers: list[LoRALayer] = []
for src_layer_dict in src_layer_dicts:
values = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
assert len(src_layer_dict) == 0
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)
# time_text_embed.timestep_embedder -> time_in.
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_2", "time_in.out_layer")
# time_text_embed.text_embedder -> vector_in.
add_lora_layer_if_present("time_text_embed.text_embedder.linear_1", "vector_in.in_layer")
add_lora_layer_if_present("time_text_embed.text_embedder.linear_2", "vector_in.out_layer")
# time_text_embed.guidance_embedder -> guidance_in.
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in")
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in")
# context_embedder -> txt_in.
add_lora_layer_if_present("context_embedder", "txt_in")
# x_embedder -> img_in.
add_lora_layer_if_present("x_embedder", "img_in")
# Double transformer blocks.
for i in range(num_double_layers):
# norms.
add_lora_layer_if_present(f"transformer_blocks.{i}.norm1.linear", f"double_blocks.{i}.img_mod.lin")
add_lora_layer_if_present(f"transformer_blocks.{i}.norm1_context.linear", f"double_blocks.{i}.txt_mod.lin")
# Q, K, V
add_qkv_lora_layer_if_present(
[
f"transformer_blocks.{i}.attn.to_q",
f"transformer_blocks.{i}.attn.to_k",
f"transformer_blocks.{i}.attn.to_v",
],
f"double_blocks.{i}.img_attn.qkv",
)
add_qkv_lora_layer_if_present(
[
f"transformer_blocks.{i}.attn.add_q_proj",
f"transformer_blocks.{i}.attn.add_k_proj",
f"transformer_blocks.{i}.attn.add_v_proj",
],
f"double_blocks.{i}.txt_attn.qkv",
)
# ff img_mlp
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff.net.0.proj",
f"double_blocks.{i}.img_mlp.0",
)
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff.net.2",
f"double_blocks.{i}.img_mlp.2",
)
# ff txt_mlp
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff_context.net.0.proj",
f"double_blocks.{i}.txt_mlp.0",
)
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff_context.net.2",
f"double_blocks.{i}.txt_mlp.2",
)
# output projections.
add_lora_layer_if_present(
f"transformer_blocks.{i}.attn.to_out.0",
f"double_blocks.{i}.img_attn.proj",
)
add_lora_layer_if_present(
f"transformer_blocks.{i}.attn.to_add_out",
f"double_blocks.{i}.txt_attn.proj",
)
# Single transformer blocks.
for i in range(num_single_layers):
# norms
add_lora_layer_if_present(
f"single_transformer_blocks.{i}.norm.linear",
f"single_blocks.{i}.modulation.lin",
)
# Q, K, V, mlp
add_qkv_lora_layer_if_present(
[
f"single_transformer_blocks.{i}.attn.to_q",
f"single_transformer_blocks.{i}.attn.to_k",
f"single_transformer_blocks.{i}.attn.to_v",
f"single_transformer_blocks.{i}.proj_mlp",
],
f"single_blocks.{i}.linear1",
)
# Output projections.
add_lora_layer_if_present(
f"single_transformer_blocks.{i}.proj_out",
f"single_blocks.{i}.linear2",
)
# Final layer.
add_lora_layer_if_present("proj_out", "final_layer.linear")
# Assert that all keys were processed.
assert len(grouped_state_dict) == 0
return LoRAModelRaw(layers=layers)
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
"""Groups the keys in the state dict by layer."""
layer_dict: dict[str, dict[str, torch.Tensor]] = {}
for key in state_dict:
# Split the 'lora_A.weight' or 'lora_B.weight' suffix from the layer name.
parts = key.rsplit(".", maxsplit=2)
layer_name = parts[0]
key_name = ".".join(parts[1:])
if layer_name not in layer_dict:
layer_dict[layer_name] = {}
layer_dict[layer_name][key_name] = state_dict[key]
return layer_dict
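
To illustrate the grouping step above (tensor shapes are invented for a hypothetical rank-4 LoRA; this is not taken from the test suite):

```python
import torch

sd = {
    "transformer.x_embedder.lora_A.weight": torch.randn(4, 64),
    "transformer.x_embedder.lora_B.weight": torch.randn(3072, 4),
}
grouped = _group_by_layer(sd)
# -> {"transformer.x_embedder": {"lora_A.weight": ..., "lora_B.weight": ...}}
# After the "transformer." prefix is stripped, add_lora_layer_if_present()
# maps "x_embedder" to the BFL key "img_in", per the call table above.
```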

View File

@@ -0,0 +1,80 @@
import re
from typing import Any, Dict, TypeVar
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_unet_double_blocks_0_img_attn_proj.alpha
# lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
# lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
FLUX_KOHYA_KEY_REGEX = (
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
layer_name, param_name = key.split(".", 1)
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Convert the state dict to the InvokeAI format.
grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
# Create LoRA layers.
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)
T = TypeVar("T")
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
Example key conversions:
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"""
def replace_func(match: re.Match[str]) -> str:
s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
if match.group(4):
s += f".{match.group(4)}"
return s
converted_dict: dict[str, T] = {}
for k, v in state_dict.items():
match = re.match(FLUX_KOHYA_KEY_REGEX, k)
if match:
new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
converted_dict[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict
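
A quick check of the regex-driven conversion (the dict value is a placeholder; the function is generic over the value type):

```python
converted = convert_flux_kohya_state_dict_to_invoke_format(
    {"lora_unet_double_blocks_0_img_attn_proj": None}
)
assert list(converted) == ["double_blocks.0.img_attn.proj"]
```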

View File

@@ -0,0 +1,29 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
layers: dict[str, AnyLoRALayer] = {}
for layer_key, values in grouped_state_dict.items():
layers[layer_key] = any_lora_layer_from_state_dict(values)
return LoRAModelRaw(layers=layers)
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped

View File

@@ -0,0 +1,154 @@
import bisect
from typing import Dict, List, Tuple, TypeVar
T = TypeVar("T")
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict: dict[str, T] = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer: list[tuple[str, str]] = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map: list[tuple[str, str]] = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in _make_sdxl_unet_conversion_map()
}
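
The bisect-based prefix lookup in `convert_sdxl_keys_to_diffusers_format` may be easier to follow with a toy example (keys shortened and hypothetical):

```python
import bisect

keys = sorted(["input_blocks_4_1", "input_blocks_5_0", "middle_block_1"])
search = "input_blocks_4_1_proj_in"
# bisect_right gives the insertion point; the entry just before it is the
# only sorted entry that can be a prefix of the search key.
candidate = keys[bisect.bisect_right(keys, search) - 1]
assert candidate == "input_blocks_4_1" and search.startswith(candidate)
```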

View File

@@ -1,5 +1,6 @@
 from typing import Union

+from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.lora.layers.full_layer import FullLayer
 from invokeai.backend.lora.layers.ia3_layer import IA3Layer
 from invokeai.backend.lora.layers.loha_layer import LoHALayer
@@ -7,4 +8,4 @@ from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
 from invokeai.backend.lora.layers.norm_layer import NormLayer

-AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
+AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]

View File

@@ -0,0 +1,55 @@
from typing import Optional, Sequence
import torch
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class ConcatenatedLoRALayer(LoRALayerBase):
"""A LoRA layer that is composed of multiple LoRA layers concatenated along a specified axis.
This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices are
stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
"""
def __init__(self, lora_layers: Sequence[LoRALayer], concat_axis: int = 0):
super().__init__(alpha=None, bias=None)
self.lora_layers = lora_layers
self.concat_axis = concat_axis
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
# TODO(ryand): Currently, we pass orig_weight=None to the sub-layers. If we want to support sub-layers that
# require this value, we will need to implement chunking of the original weight tensor here.
# Note that we must apply the sub-layer scales here.
layer_weights = [lora_layer.get_weight(None) * lora_layer.scale() for lora_layer in self.lora_layers] # pyright: ignore[reportArgumentType]
return torch.cat(layer_weights, dim=self.concat_axis)
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
# TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
# require this value, we will need to implement chunking of the original bias tensor here.
# Note that we must apply the sub-layer scales here.
layer_biases: list[torch.Tensor] = []
for lora_layer in self.lora_layers:
layer_bias = lora_layer.get_bias(None)
if layer_bias is not None:
layer_biases.append(layer_bias * lora_layer.scale())
if len(layer_biases) == 0:
return None
assert len(layer_biases) == len(self.lora_layers)
return torch.cat(layer_biases, dim=self.concat_axis)
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
for lora_layer in self.lora_layers:
lora_layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return super().calc_size() + sum(lora_layer.calc_size() for lora_layer in self.lora_layers)
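
To make the Q/K/V handling concrete, the tensor arithmetic behind `ConcatenatedLoRALayer.get_weight()` amounts to the following (sizes invented; note that for FLUX single blocks the fused layout also folds in the MLP projection, per the conversion code above):

```python
import torch

dim, rank = 3072, 4  # illustrative FLUX-like sizes

def lora_delta() -> torch.Tensor:
    # A rank-4 LoRA delta for one (dim x dim) projection: up @ down.
    return torch.randn(dim, rank) @ torch.randn(rank, dim)

q, k, v = lora_delta(), lora_delta(), lora_delta()
# BFL FLUX fuses Q, K, V along dim 0, so the three separate diffusers
# deltas are concatenated on axis 0 to patch the fused weight.
qkv_delta = torch.cat([q, k, v], dim=0)
assert qkv_delta.shape == (3 * dim, dim)
```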

View File

@@ -3,35 +3,32 @@ from typing import Dict, Optional
 import torch

 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensor_size


 class FullLayer(LoRALayerBase):
-    # bias handled in LoRALayerBase(calc_size, to)
-    # weight: torch.Tensor
-    # bias: Optional[torch.Tensor]
+    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor]):
+        super().__init__(alpha=None, bias=bias)
+        self.weight = torch.nn.Parameter(weight)

-    def __init__(
-        self,
-        layer_key: str,
+    @classmethod
+    def from_state_dict_values(
+        cls,
         values: Dict[str, torch.Tensor],
     ):
-        super().__init__(layer_key, values)
-
-        self.weight = values["diff"]
-        self.bias = values.get("diff_b", None)
-
-        self.rank = None  # unscaled
-        self.check_keys(values, {"diff", "diff_b"})
+        layer = cls(weight=values["diff"], bias=values.get("diff_b", None))
+        cls.warn_on_unhandled_keys(values=values, handled_keys={"diff", "diff_b"})
+        return layer
+
+    def rank(self) -> int | None:
+        return None

     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight

-    def calc_size(self) -> int:
-        model_size = super().calc_size()
-        model_size += self.weight.nelement() * self.weight.element_size()
-        return model_size
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
         super().to(device=device, dtype=dtype)
         self.weight = self.weight.to(device=device, dtype=dtype)
+
+    def calc_size(self) -> int:
+        return super().calc_size() + calc_tensor_size(self.weight)

View File

@@ -6,37 +6,53 @@ from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
 class IA3Layer(LoRALayerBase):
-    # weight: torch.Tensor
-    # on_input: torch.Tensor
+    """IA3 Layer
+
+    Example model for testing this layer type: https://civitai.com/models/123930/gwendolyn-tennyson-ben-10-ia3
+    """

-    def __init__(
-        self,
-        layer_key: str,
-        values: Dict[str, torch.Tensor],
-    ):
-        super().__init__(layer_key, values)
-
-        self.weight = values["weight"]
-        self.on_input = values["on_input"]
-
-        self.rank = None  # unscaled
-        self.check_keys(values, {"weight", "on_input"})
+    def __init__(self, weight: torch.Tensor, on_input: torch.Tensor, bias: Optional[torch.Tensor]):
+        super().__init__(alpha=None, bias=bias)
+        self.weight = weight
+        self.on_input = on_input
+
+    def rank(self) -> int | None:
+        return None
+
+    @classmethod
+    def from_state_dict_values(
+        cls,
+        values: Dict[str, torch.Tensor],
+    ):
+        bias = cls._parse_bias(
+            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
+        )
+        layer = cls(
+            weight=values["weight"],
+            on_input=values["on_input"],
+            bias=bias,
+        )
+        cls.warn_on_unhandled_keys(
+            values=values,
+            handled_keys={
+                # Default keys.
+                "bias_indices",
+                "bias_values",
+                "bias_size",
+                # Layer-specific keys.
+                "weight",
+                "on_input",
+            },
+        )
+        return layer

     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         weight = self.weight
         if not self.on_input:
             weight = weight.reshape(-1, 1)
         assert orig_weight is not None
         return orig_weight * weight

-    def calc_size(self) -> int:
-        model_size = super().calc_size()
-        model_size += self.weight.nelement() * self.weight.element_size()
-        model_size += self.on_input.nelement() * self.on_input.element_size()
-        return model_size
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
-        super().to(device=device, dtype=dtype)
-        self.weight = self.weight.to(device=device, dtype=dtype)
-        self.on_input = self.on_input.to(device=device, dtype=dtype)
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        super().to(device, dtype)
+        self.weight = self.weight.to(device, dtype)
+        self.on_input = self.on_input.to(device, dtype)

View File

@@ -1,32 +1,69 @@
-from typing import Dict, Optional
+from typing import Dict

 import torch

 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size


 class LoHALayer(LoRALayerBase):
-    # w1_a: torch.Tensor
-    # w1_b: torch.Tensor
-    # w2_a: torch.Tensor
-    # w2_b: torch.Tensor
-    # t1: Optional[torch.Tensor] = None
-    # t2: Optional[torch.Tensor] = None
+    """LoHA LyCoris layer.

-    def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
-        super().__init__(layer_key, values)
+    Example model for testing this layer type: https://civitai.com/models/27397/loha-renoir-the-dappled-light-style
+    """

-        self.w1_a = values["hada_w1_a"]
-        self.w1_b = values["hada_w1_b"]
-        self.w2_a = values["hada_w2_a"]
-        self.w2_b = values["hada_w2_b"]
-        self.t1 = values.get("hada_t1", None)
-        self.t2 = values.get("hada_t2", None)
+    def __init__(
+        self,
+        w1_a: torch.Tensor,
+        w1_b: torch.Tensor,
+        w2_a: torch.Tensor,
+        w2_b: torch.Tensor,
+        t1: torch.Tensor | None,
+        t2: torch.Tensor | None,
+        alpha: float | None,
+        bias: torch.Tensor | None,
+    ):
+        super().__init__(alpha=alpha, bias=bias)
+        self.w1_a = w1_a
+        self.w1_b = w1_b
+        self.w2_a = w2_a
+        self.w2_b = w2_b
+        self.t1 = t1
+        self.t2 = t2
+        assert (self.t1 is None) == (self.t2 is None)

-        self.rank = self.w1_b.shape[0]
-        self.check_keys(
-            values,
-            {
+    def rank(self) -> int | None:
+        return self.w1_b.shape[0]
+
+    @classmethod
+    def from_state_dict_values(
+        cls,
+        values: Dict[str, torch.Tensor],
+    ):
+        alpha = cls._parse_alpha(values.get("alpha", None))
+        bias = cls._parse_bias(
+            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
+        )
+        layer = cls(
+            w1_a=values["hada_w1_a"],
+            w1_b=values["hada_w1_b"],
+            w2_a=values["hada_w2_a"],
+            w2_b=values["hada_w2_b"],
+            t1=values.get("hada_t1", None),
+            t2=values.get("hada_t2", None),
+            alpha=alpha,
+            bias=bias,
+        )
+        cls.warn_on_unhandled_keys(
+            values=values,
+            handled_keys={
+                # Default keys.
+                "alpha",
+                "bias_indices",
+                "bias_values",
+                "bias_size",
+                # Layer-specific keys.
                 "hada_w1_a",
                 "hada_w1_b",
                 "hada_w2_a",
@@ -36,10 +73,11 @@ class LoHALayer(LoRALayerBase):
             },
         )
+        return layer

     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.t1 is None:
             weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
         else:
             rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
             rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
@@ -47,22 +85,14 @@ class LoHALayer(LoRALayerBase):
         return weight

-    def calc_size(self) -> int:
-        model_size = super().calc_size()
-        for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
-            if val is not None:
-                model_size += val.nelement() * val.element_size()
-        return model_size
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
         super().to(device=device, dtype=dtype)

         self.w1_a = self.w1_a.to(device=device, dtype=dtype)
         self.w1_b = self.w1_b.to(device=device, dtype=dtype)
-        if self.t1 is not None:
-            self.t1 = self.t1.to(device=device, dtype=dtype)
-
         self.w2_a = self.w2_a.to(device=device, dtype=dtype)
         self.w2_b = self.w2_b.to(device=device, dtype=dtype)
-        if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+        self.t1 = self.t1.to(device=device, dtype=dtype) if self.t1 is not None else self.t1
+        self.t2 = self.t2.to(device=device, dtype=dtype) if self.t2 is not None else self.t2
+
+    def calc_size(self) -> int:
+        return super().calc_size() + calc_tensors_size([self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2])

View File

@@ -1,54 +1,82 @@
-from typing import Dict, Optional
+from typing import Dict

 import torch

 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size


 class LoKRLayer(LoRALayerBase):
-    # w1: Optional[torch.Tensor] = None
-    # w1_a: Optional[torch.Tensor] = None
-    # w1_b: Optional[torch.Tensor] = None
-    # w2: Optional[torch.Tensor] = None
-    # w2_a: Optional[torch.Tensor] = None
-    # w2_b: Optional[torch.Tensor] = None
-    # t2: Optional[torch.Tensor] = None
+    """LoKR LyCoris layer.
+
+    Example model for testing this layer type: https://civitai.com/models/346747/lokrnekopara-allgirl-for-jru2
+    """

     def __init__(
         self,
-        layer_key: str,
-        values: Dict[str, torch.Tensor],
+        w1: torch.Tensor | None,
+        w1_a: torch.Tensor | None,
+        w1_b: torch.Tensor | None,
+        w2: torch.Tensor | None,
+        w2_a: torch.Tensor | None,
+        w2_b: torch.Tensor | None,
+        t2: torch.Tensor | None,
+        alpha: float | None,
+        bias: torch.Tensor | None,
     ):
-        super().__init__(layer_key, values)
+        super().__init__(alpha=alpha, bias=bias)
+        self.w1 = w1
+        self.w1_a = w1_a
+        self.w1_b = w1_b
+        self.w2 = w2
+        self.w2_a = w2_a
+        self.w2_b = w2_b
+        self.t2 = t2
+
+        # Validate parameters.
+        assert (self.w1 is None) != (self.w1_a is None)
+        assert (self.w1_a is None) == (self.w1_b is None)
+        assert (self.w2 is None) != (self.w2_a is None)
+        assert (self.w2_a is None) == (self.w2_b is None)
+
+    def rank(self) -> int | None:
+        if self.w1_b is not None:
+            return self.w1_b.shape[0]
+        elif self.w2_b is not None:
+            return self.w2_b.shape[0]
+        else:
+            return None

-        self.w1 = values.get("lokr_w1", None)
-        if self.w1 is None:
-            self.w1_a = values["lokr_w1_a"]
-            self.w1_b = values["lokr_w1_b"]
-        else:
-            self.w1_b = None
-            self.w1_a = None
-
-        self.w2 = values.get("lokr_w2", None)
-        if self.w2 is None:
-            self.w2_a = values["lokr_w2_a"]
-            self.w2_b = values["lokr_w2_b"]
-        else:
-            self.w2_a = None
-            self.w2_b = None
-
-        self.t2 = values.get("lokr_t2", None)
-
-        if self.w1_b is not None:
-            self.rank = self.w1_b.shape[0]
-        elif self.w2_b is not None:
-            self.rank = self.w2_b.shape[0]
-        else:
-            self.rank = None  # unscaled
-
-        self.check_keys(
+    @classmethod
+    def from_state_dict_values(
+        cls,
+        values: Dict[str, torch.Tensor],
+    ):
+        alpha = cls._parse_alpha(values.get("alpha", None))
+        bias = cls._parse_bias(
+            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
+        )
+        layer = cls(
+            w1=values.get("lokr_w1", None),
+            w1_a=values.get("lokr_w1_a", None),
+            w1_b=values.get("lokr_w1_b", None),
+            w2=values.get("lokr_w2", None),
+            w2_a=values.get("lokr_w2_a", None),
+            w2_b=values.get("lokr_w2_b", None),
+            t2=values.get("lokr_t2", None),
+            alpha=alpha,
+            bias=bias,
+        )
+        cls.warn_on_unhandled_keys(
             values,
             {
+                # Default keys.
+                "alpha",
+                "bias_indices",
+                "bias_values",
+                "bias_size",
+                # Layer-specific keys.
                 "lokr_w1",
                 "lokr_w1_a",
                 "lokr_w1_b",
@@ -59,8 +87,10 @@ class LoKRLayer(LoRALayerBase):
             },
         )
+        return layer

     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
-        w1: Optional[torch.Tensor] = self.w1
+        w1 = self.w1
         if w1 is None:
             assert self.w1_a is not None
             assert self.w1_b is not None
@@ -78,37 +108,20 @@ class LoKRLayer(LoRALayerBase):
         if len(w2.shape) == 4:
             w1 = w1.unsqueeze(2).unsqueeze(2)
         w2 = w2.contiguous()
         assert w1 is not None
         assert w2 is not None
         weight = torch.kron(w1, w2)

         return weight

-    def calc_size(self) -> int:
-        model_size = super().calc_size()
-        for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
-            if val is not None:
-                model_size += val.nelement() * val.element_size()
-        return model_size
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
         super().to(device=device, dtype=dtype)

-        if self.w1 is not None:
-            self.w1 = self.w1.to(device=device, dtype=dtype)
-        else:
-            assert self.w1_a is not None
-            assert self.w1_b is not None
-            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-            self.w1_b = self.w1_b.to(device=device, dtype=dtype)
-
-        if self.w2 is not None:
-            self.w2 = self.w2.to(device=device, dtype=dtype)
-        else:
-            assert self.w2_a is not None
-            assert self.w2_b is not None
-            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-            self.w2_b = self.w2_b.to(device=device, dtype=dtype)
-
-        if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+        self.w1 = self.w1.to(device=device, dtype=dtype) if self.w1 is not None else self.w1
+        self.w1_a = self.w1_a.to(device=device, dtype=dtype) if self.w1_a is not None else self.w1_a
+        self.w1_b = self.w1_b.to(device=device, dtype=dtype) if self.w1_b is not None else self.w1_b
+        self.w2 = self.w2.to(device=device, dtype=dtype) if self.w2 is not None else self.w2
+        self.w2_a = self.w2_a.to(device=device, dtype=dtype) if self.w2_a is not None else self.w2_a
+        self.w2_b = self.w2_b.to(device=device, dtype=dtype) if self.w2_b is not None else self.w2_b
+        self.t2 = self.t2.to(device=device, dtype=dtype) if self.t2 is not None else self.t2
+
+    def calc_size(self) -> int:
+        return super().calc_size() + calc_tensors_size(
+            [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]
+        )

View File

@@ -3,35 +3,61 @@ from typing import Dict, Optional
 import torch

 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size


-# TODO: find and debug lora/locon with bias
 class LoRALayer(LoRALayerBase):
-    # up: torch.Tensor
-    # mid: Optional[torch.Tensor]
-    # down: torch.Tensor
-
     def __init__(
         self,
-        layer_key: str,
+        up: torch.Tensor,
+        mid: Optional[torch.Tensor],
+        down: torch.Tensor,
+        alpha: float | None,
+        bias: Optional[torch.Tensor],
+    ):
+        super().__init__(alpha, bias)
+        self.up = up
+        self.mid = mid
+        self.down = down
+
+    @classmethod
+    def from_state_dict_values(
+        cls,
         values: Dict[str, torch.Tensor],
     ):
-        super().__init__(layer_key, values)
+        alpha = cls._parse_alpha(values.get("alpha", None))
+        bias = cls._parse_bias(
+            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
+        )

-        self.up = values["lora_up.weight"]
-        self.down = values["lora_down.weight"]
-        self.mid = values.get("lora_mid.weight", None)
+        layer = cls(
+            up=values["lora_up.weight"],
+            down=values["lora_down.weight"],
+            mid=values.get("lora_mid.weight", None),
+            alpha=alpha,
+            bias=bias,
+        )

-        self.rank = self.down.shape[0]
-        self.check_keys(
-            values,
-            {
+        cls.warn_on_unhandled_keys(
+            values=values,
+            handled_keys={
+                # Default keys.
+                "alpha",
+                "bias_indices",
+                "bias_values",
+                "bias_size",
+                # Layer-specific keys.
                 "lora_up.weight",
                 "lora_down.weight",
                 "lora_mid.weight",
             },
         )
+        return layer
+
+    def rank(self) -> int:
+        return self.down.shape[0]

     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
@@ -42,18 +68,12 @@ class LoRALayer(LoRALayerBase):
         return weight

-    def calc_size(self) -> int:
-        model_size = super().calc_size()
-        for val in [self.up, self.mid, self.down]:
-            if val is not None:
-                model_size += val.nelement() * val.element_size()
-        return model_size
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
         super().to(device=device, dtype=dtype)

         self.up = self.up.to(device=device, dtype=dtype)
-        self.down = self.down.to(device=device, dtype=dtype)
         if self.mid is not None:
             self.mid = self.mid.to(device=device, dtype=dtype)
+        self.down = self.down.to(device=device, dtype=dtype)
+
+    def calc_size(self) -> int:
+        return super().calc_size() + calc_tensors_size([self.up, self.mid, self.down])

View File

@@ -3,40 +3,48 @@ from typing import Dict, Optional, Set
 import torch

 import invokeai.backend.util.logging as logger
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size


 class LoRALayerBase:
-    # rank: Optional[int]
-    # alpha: Optional[float]
-    # bias: Optional[torch.Tensor]
-    # layer_key: str
+    """Base class for all LoRA-like patching layers."""

-    # @property
-    # def scale(self):
-    #     return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
+    # Note: It is tempting to make this a torch.nn.Module sub-class and make all tensors 'torch.nn.Parameter's. Then we
+    # could inherit automatic .to(...) behavior for this class, its subclasses, and all sidecar layers that wrap a
+    # LoRALayerBase. We would also be able to implement a single calc_size() method that could be inherited by all
+    # subclasses. But, it turns out that the speed overhead of the default .to(...) implementation in torch.nn.Module is
+    # noticeable, so for now we have opted not to use torch.nn.Module.

-    def __init__(
-        self,
-        layer_key: str,
-        values: Dict[str, torch.Tensor],
-    ):
-        if "alpha" in values:
-            self.alpha = values["alpha"].item()
-        else:
-            self.alpha = None
+    def __init__(self, alpha: float | None, bias: torch.Tensor | None):
+        self._alpha = alpha
+        self.bias = bias

-        if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
-            self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
-                values["bias_indices"],
-                values["bias_values"],
-                tuple(values["bias_size"]),
-            )
-        else:
-            self.bias = None
+    @classmethod
+    def _parse_bias(
+        cls, bias_indices: torch.Tensor | None, bias_values: torch.Tensor | None, bias_size: torch.Tensor | None
+    ) -> torch.Tensor | None:
+        assert (bias_indices is None) == (bias_values is None) == (bias_size is None)

-        self.rank = None  # set in layer implementation
-        self.layer_key = layer_key
+        bias = None
+        if bias_indices is not None:
+            bias = torch.sparse_coo_tensor(bias_indices, bias_values, tuple(bias_size))
+        return bias
+
+    @classmethod
+    def _parse_alpha(
+        cls,
+        alpha: torch.Tensor | None,
+    ) -> float | None:
+        return alpha.item() if alpha is not None else None
+
+    def rank(self) -> int | None:
+        raise NotImplementedError()
+
+    def scale(self) -> float:
+        rank = self.rank()
+        if self._alpha is None or rank is None:
+            return 1.0
+        return self._alpha / rank

     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()
@@ -51,24 +59,18 @@ class LoRALayerBase:
             params["bias"] = bias
         return params

-    def calc_size(self) -> int:
-        model_size = 0
-        for val in [self.bias]:
-            if val is not None:
-                model_size += val.nelement() * val.element_size()
-        return model_size
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
-        if self.bias is not None:
-            self.bias = self.bias.to(device=device, dtype=dtype)
-
-    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+    @classmethod
+    def warn_on_unhandled_keys(cls, values: Dict[str, torch.Tensor], handled_keys: Set[str]):
         """Log a warning if values contains unhandled keys."""
-        # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
-        # `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
-        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
-        unknown_keys = set(values.keys()) - all_known_keys
+        unknown_keys = set(values.keys()) - handled_keys
         if unknown_keys:
             logger.warning(
-                f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
+                f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Unexpected keys: {unknown_keys}"
             )
+
+    def calc_size(self) -> int:
+        return calc_tensors_size([self.bias])
+
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        if self.bias is not None:
+            self.bias = self.bias.to(device=device, dtype=dtype)
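
The `scale()` method above implements the standard LoRA alpha/rank convention; as a quick sanity check (values invented for illustration):

```python
# A LoRA trained with alpha=16 at rank 8 applies its delta at
# scale = alpha / rank = 2.0; if alpha or rank is unknown, scale()
# falls back to 1.0 so the raw delta is used as-is.
alpha, rank = 16.0, 8
scale = alpha / rank if (alpha is not None and rank is not None) else 1.0
assert scale == 2.0
```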

View File

@@ -1,37 +1,34 @@
-from typing import Dict, Optional
+from typing import Dict

 import torch

 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensor_size


 class NormLayer(LoRALayerBase):
-    # bias handled in LoRALayerBase(calc_size, to)
-    # weight: torch.Tensor
-    # bias: Optional[torch.Tensor]
+    def __init__(self, weight: torch.Tensor, bias: torch.Tensor | None):
+        super().__init__(alpha=None, bias=bias)
+        self.weight = weight

-    def __init__(
-        self,
-        layer_key: str,
+    @classmethod
+    def from_state_dict_values(
+        cls,
         values: Dict[str, torch.Tensor],
     ):
-        super().__init__(layer_key, values)
+        layer = cls(weight=values["w_norm"], bias=values.get("b_norm", None))
+        cls.warn_on_unhandled_keys(values, {"w_norm", "b_norm"})
+        return layer

-        self.weight = values["w_norm"]
-        self.bias = values.get("b_norm", None)
-
-        self.rank = None  # unscaled
-        self.check_keys(values, {"w_norm", "b_norm"})
+    def rank(self) -> int | None:
+        return None

     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight

-    def calc_size(self) -> int:
-        model_size = super().calc_size()
-        model_size += self.weight.nelement() * self.weight.element_size()
-        return model_size
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
         super().to(device=device, dtype=dtype)
         self.weight = self.weight.to(device=device, dtype=dtype)
+
+    def calc_size(self) -> int:
+        return super().calc_size() + calc_tensor_size(self.weight)

View File

@@ -0,0 +1,33 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
# Detect layers according to LyCORIS detection logic (`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
if "lora_up.weight" in state_dict:
# LoRA a.k.a LoCon
return LoRALayer.from_state_dict_values(state_dict)
elif "hada_w1_a" in state_dict:
return LoHALayer.from_state_dict_values(state_dict)
elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
return LoKRLayer.from_state_dict_values(state_dict)
elif "diff" in state_dict:
# Full a.k.a Diff
return FullLayer.from_state_dict_values(state_dict)
elif "on_input" in state_dict:
return IA3Layer.from_state_dict_values(state_dict)
elif "w_norm" in state_dict:
return NormLayer.from_state_dict_values(state_dict)
else:
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")

View File

@@ -1,43 +1,17 @@
 # Copyright (c) 2024 The InvokeAI Development team
 """LoRA model support."""

-import bisect
-from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Mapping, Optional

 import torch
-from safetensors.torch import load_file
-from typing_extensions import Self

 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
-from invokeai.backend.lora.layers.full_layer import FullLayer
-from invokeai.backend.lora.layers.ia3_layer import IA3Layer
-from invokeai.backend.lora.layers.loha_layer import LoHALayer
-from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
-from invokeai.backend.lora.layers.lora_layer import LoRALayer
-from invokeai.backend.lora.layers.norm_layer import NormLayer
-from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.raw_model import RawModel


 class LoRAModelRaw(RawModel):  # (torch.nn.Module):
-    _name: str
-    layers: Dict[str, AnyLoRALayer]
-
-    def __init__(
-        self,
-        name: str,
-        layers: Dict[str, AnyLoRALayer],
-    ):
-        self._name = name
+    def __init__(self, layers: Mapping[str, AnyLoRALayer]):
         self.layers = layers

-    @property
-    def name(self) -> str:
-        return self._name
-
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         # TODO: try revert if exception?
         for _key, layer in self.layers.items():
             layer.to(device=device, dtype=dtype)
@@ -46,234 +20,3 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
) -> Self:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
name=file_path.stem,
layers={},
)
if file_path.suffix == ".safetensors":
sd = load_file(file_path.absolute().as_posix(), device="cpu")
else:
sd = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(sd)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
# Detect layers according to LyCORIS detection logic (`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
# lora and locon
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_a" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1" in values or "lokr_w1_a" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "on_input" in values:
layer = IA3Layer(layer_key, values)
# norms
elif "w_norm" in values:
layer = NormLayer(layer_key, values)
else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_grouped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_grouped:
state_dict_grouped[stem] = {}
state_dict_grouped[stem][leaf] = value
return state_dict_grouped
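# For example (illustrative keys), a flat state dict {"foo.lora_up.weight": t1, "foo.alpha": t2}
# is grouped into {"foo": {"lora_up.weight": t1, "alpha": t2}}.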
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: (commented out for sdxl)
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
}
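# Illustrative example of the resulting mapping: the Stability AI prefix "input_blocks_4_1" maps to
# the diffusers prefix "down_blocks_1_attentions_0", so the LoRA key
# "lora_unet_input_blocks_4_1_proj_in" is rewritten to "lora_unet_down_blocks_1_attentions_0_proj_in".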

View File

@@ -0,0 +1,302 @@
from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
ConcatenatedLoRALinearSidecarLayer,
)
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
class LoRAPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
"""Apply one or more LoRA patches to a model within a context manager.
Args:
model (torch.nn.Module): The model to patch.
patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
cached_weights (Optional[Dict[str, torch.Tensor]], optional): Read-only copy of the model's state dict in
CPU RAM, for efficient unpatching purposes.
"""
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
LoRAPatcher.apply_lora_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_weights=original_weights,
)
del patch
yield
finally:
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
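# Example usage (an illustrative sketch; `unet` and `lora_model` are assumed to be loaded elsewhere):
#
#     with LoRAPatcher.apply_lora_patches(
#         model=unet,
#         patches=[(lora_model, 0.7)],
#         prefix="lora_unet_",
#     ):
#         ...  # run inference with the patched model; original weights are restored on exit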
@staticmethod
@torch.no_grad()
def apply_lora_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
"""Apply a single LoRA patch to a model.
Args:
model (torch.nn.Module): The model to patch.
prefix (str): A string prefix that precedes keys used in the LoRA's weight layers.
patch (LoRAModelRaw): The LoRA model to patch in.
patch_weight (float): The weight of the LoRA patch.
original_weights (OriginalWeightsStorage): Storage for the original weights of the model, for unpatching.
"""
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
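# For example, the non-flattened key "single_transformer_blocks.0.attn.to_q" appears in flattened
# form as "single_transformer_blocks_0_attn_to_q" (illustrative key).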
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LoRAPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
layer_scale = layer.scale()
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# Save original weight
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)
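# Net effect per patched parameter: param += patch_weight * layer.scale() * lora_param_weight,
# computed in float32 on the module's device and cast back to the module dtype before the add.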
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_sidecar_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
dtype: torch.dtype,
):
"""Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
overhead compared to normal LoRA patching, but they allow LoRA layers to be applied to base layers in any
quantization format.
Args:
model (torch.nn.Module): The model to patch.
patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
different from their compute dtype.
"""
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LoRAPatcher._apply_lora_sidecar_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_modules=original_modules,
dtype=dtype,
)
yield
finally:
# Restore original modules.
# Note: This logic assumes no nested modules in original_modules.
for module_key, orig_module in original_modules.items():
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
parent_module = model.get_submodule(module_parent_key)
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
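# Example usage (an illustrative sketch; `transformer` is assumed to be a quantized base model,
# `lora_model` a LoRAModelRaw loaded elsewhere, and the prefix value is a placeholder):
#
#     with LoRAPatcher.apply_lora_sidecar_patches(
#         model=transformer,
#         patches=[(lora_model, 0.7)],
#         prefix="lora_transformer_",
#         dtype=torch.bfloat16,
#     ):
#         ...  # run inference; the original modules are restored on exit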
@staticmethod
def _apply_lora_sidecar_patch(
model: torch.nn.Module,
patch: LoRAModelRaw,
patch_weight: float,
prefix: str,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA sidecar patch to a model."""
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LoRAPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# Initialize the LoRA sidecar layer.
lora_sidecar_layer = LoRAPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)
# Replace the original module with a LoRASidecarModule if it has not already been done.
if module_key in original_modules:
# The module has already been patched with a LoRASidecarModule. Append to it.
assert isinstance(module, LoRASidecarModule)
lora_sidecar_module = module
else:
# The module has not yet been patched with a LoRASidecarModule. Create one.
lora_sidecar_module = LoRASidecarModule(module, [])
original_modules[module_key] = module
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
module_parent = model.get_submodule(module_parent_key)
LoRAPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
# Move the LoRA sidecar layer to the same device/dtype as the orig module.
# TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
lora_sidecar_layer.to(device=lora_sidecar_module.orig_module.weight.device, dtype=dtype)
# Add the LoRA sidecar layer to the LoRASidecarModule.
lora_sidecar_module.add_lora_layer(lora_sidecar_layer)
@staticmethod
def _split_parent_key(module_key: str) -> tuple[str, str]:
"""Split a module key into its parent key and module name.
Args:
module_key (str): The module key to split.
Returns:
tuple[str, str]: A tuple containing the parent key and module name.
"""
split_key = module_key.rsplit(".", 1)
if len(split_key) == 2:
return tuple(split_key)
elif len(split_key) == 1:
return "", split_key[0]
else:
raise ValueError(f"Invalid module key: {module_key}")
@staticmethod
def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
# TODO(ryand): Add support for more original layer types and LoRA layer types.
if isinstance(orig_layer, torch.nn.Linear) or (
isinstance(orig_layer, LoRASidecarModule) and isinstance(orig_layer.orig_module, torch.nn.Linear)
):
if isinstance(lora_layer, LoRALayer):
return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
elif isinstance(lora_layer, ConcatenatedLoRALayer):
return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
else:
raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
else:
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")
@staticmethod
def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
try:
submodule_index = int(module_name)
# If the module name is an integer, then we use the __setitem__ method to set the submodule.
parent_module[submodule_index] = submodule # type: ignore
except ValueError:
# If the module name is not an integer, then we use the setattr method to set the submodule.
setattr(parent_module, module_name, submodule)
@staticmethod
def _get_submodule(
model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
) -> tuple[str, torch.nn.Module]:
"""Get the submodule corresponding to the given layer key.
Args:
model (torch.nn.Module): The model to search.
layer_key (str): The layer key to search for.
layer_key_is_flattened (bool): Whether the layer key is flattened. If flattened, then all '.' have been
replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed
directly without searching, but some legacy code still uses flattened keys.
Returns:
tuple[str, torch.nn.Module]: A tuple containing the module key and the submodule.
"""
if not layer_key_is_flattened:
return layer_key, model.get_submodule(layer_key)
# Handle flattened keys.
assert "." not in layer_key
module = model
module_key = ""
key_parts = layer_key.split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return module_key, module
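# For example, given the flattened key "single_transformer_blocks_0_attn_to_q" (illustrative), the
# loop above greedily extends each "_"-separated candidate name until get_submodule() succeeds,
# ultimately returning the module key "single_transformer_blocks.0.attn.to_q".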

View File

@@ -0,0 +1,34 @@
import torch
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
class ConcatenatedLoRALinearSidecarLayer(torch.nn.Module):
def __init__(
self,
concatenated_lora_layer: ConcatenatedLoRALayer,
weight: float,
):
super().__init__()
self._concatenated_lora_layer = concatenated_lora_layer
self._weight = weight
def forward(self, input: torch.Tensor) -> torch.Tensor:
x_chunks: list[torch.Tensor] = []
for lora_layer in self._concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= self._weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert self._concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x
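# Note: each chunk applies its own LoRA branch (down, optional mid, up) to the same input, scales
# by weight * scale(), and the chunk outputs are concatenated along the output feature dimension.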
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self._concatenated_lora_layer.to(device=device, dtype=dtype)
return self

View File

@@ -0,0 +1,27 @@
import torch
from invokeai.backend.lora.layers.lora_layer import LoRALayer
class LoRALinearSidecarLayer(torch.nn.Module):
def __init__(
self,
lora_layer: LoRALayer,
weight: float,
):
super().__init__()
self._lora_layer = lora_layer
self._weight = weight
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.linear(x, self._lora_layer.down)
if self._lora_layer.mid is not None:
x = torch.nn.functional.linear(x, self._lora_layer.mid)
x = torch.nn.functional.linear(x, self._lora_layer.up, bias=self._lora_layer.bias)
x *= self._weight * self._lora_layer.scale()
return x
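# In effect this evaluates the low-rank LoRA delta as a separate branch: x -> down -> (mid) -> up
# (plus bias, if present), scaled by weight * scale(), instead of fusing it into the base weights.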
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self._lora_layer.to(device=device, dtype=dtype)
return self

View File

@@ -0,0 +1,24 @@
import torch
class LoRASidecarModule(torch.nn.Module):
"""A LoRA sidecar module that wraps an original module and adds LoRA layers to it."""
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
super().__init__()
self.orig_module = orig_module
self._lora_layers = lora_layers
def add_lora_layer(self, lora_layer: torch.nn.Module):
self._lora_layers.append(lora_layer)
def forward(self, input: torch.Tensor) -> torch.Tensor:
x = self.orig_module(input)
for lora_layer in self._lora_layers:
x += lora_layer(input)
return x
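# i.e. output = orig_module(input) + sum of lora_layer(input) over all sidecar layers. Because the
# LoRA branches read the raw input and the base weights are never modified, the wrapped module can
# remain in any quantization format.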
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.orig_module.to(device=device, dtype=dtype)
for lora_layer in self._lora_layers:
lora_layer.to(device=device, dtype=dtype)

View File

@@ -5,8 +5,18 @@ from logging import Logger
from pathlib import Path
from typing import Optional
import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -45,14 +55,38 @@ class LoRALoader(ModelLoader):
raise ValueError("There are no submodels in a LoRA model.")
model_path = Path(config.path)
assert self._model_base is not None
# Load the state dict from the model file.
if model_path.suffix == ".safetensors":
state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(model_path, map_location="cpu")
# Apply state_dict key conversions, if necessary.
if self._model_base == BaseModelType.StableDiffusionXL:
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
model = lora_model_from_sd_state_dict(state_dict=state_dict)
elif self._model_base == BaseModelType.Flux:
if config.format == ModelFormat.Diffusers:
# HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically
# distributed as a single file without the associated metadata containing the alpha value. We chose
# alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. alpha=rank
# is a popular choice. For example, in the diffusers training scripts:
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
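# Concretely: the scale is computed as alpha / rank, so treating alpha=None as alpha=rank yields a
# scale of 1.0, and the patch strength is then controlled entirely by the user-supplied LoRA weight.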
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
elif config.format == ModelFormat.LyCORIS:
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
# Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
model = lora_model_from_sd_state_dict(state_dict=state_dict)
else:
raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
model.to(dtype=self._torch_dtype)
return model
# override
def _get_model_path(self, config: AnyModelConfig) -> Path:
# cheating a little - we remember this variable for use in the subsequent call to _load_model()
self._model_base = config.base

View File

@@ -20,6 +20,7 @@ from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
@@ -83,10 +84,9 @@ def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
def calc_module_size(model: torch.nn.Module) -> int:
"""Calculate the size (in bytes) of a torch.nn.Module."""
mem_params = sum([calc_tensor_size(param) for param in model.parameters()])
mem_bufs = sum([calc_tensor_size(buf) for buf in model.buffers()])
return mem_params + mem_bufs
def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:

View File

@@ -10,6 +10,10 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
AnyModelConfig,
@@ -244,7 +248,9 @@ class ModelProbe(object):
return ModelType.VAE
elif key.startswith(("lora_te_", "lora_unet_")):
return ModelType.LoRA
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return ModelType.LoRA
elif key.startswith(("controlnet", "control_model", "input_blocks")):
return ModelType.ControlNet
@@ -554,12 +560,21 @@ class LoRACheckpointProbe(CheckpointProbeBase):
"""Class for LoRA checkpoints."""
def get_format(self) -> ModelFormat:
return ModelFormat("lycoris")
if is_state_dict_likely_in_flux_diffusers_format(self.checkpoint):
# TODO(ryand): This is an unusual case. In other places throughout the codebase, we treat
# ModelFormat.Diffusers as meaning that the model is in a directory. In this case, the model is a single
# file, but the weight keys are in the diffusers format.
return ModelFormat.Diffusers
return ModelFormat.LyCORIS
def get_base_type(self) -> BaseModelType:
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
self.checkpoint
):
return BaseModelType.Flux
# If we've gotten here, we assume that the model is a Stable Diffusion model.
token_vector_length = lora_token_vector_length(self.checkpoint)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:

View File

@@ -5,32 +5,18 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
import numpy as np
import torch
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
"""
loras = [
(lora_model1, 0.7),
(lora_model2, 0.4),
]
with LoRAHelper.apply_lora_unet(unet, loras):
# unet with applied loras
# unmodified unet
"""
class ModelPatcher:
@@ -54,95 +40,6 @@ class ModelPatcher:
finally:
unet.set_attn_processor(unet_orig_processors)
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(
unet,
loras=loras,
prefix="lora_unet_",
cached_weights=cached_weights,
):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
yield
@classmethod
@contextmanager
def apply_lora(
cls,
model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
:param prefix: A string prefix that precedes keys used in the LoRA's weight layers.
:param cached_weights: Read-only copy of the model's state dict in CPU RAM, for unpatching purposes.
"""
original_weights = OriginalWeightsStorage(cached_weights)
try:
for lora_model, lora_weight in loras:
LoRAExt.patch_model(
model=model,
prefix=prefix,
lora=lora_model,
lora_weight=lora_weight,
original_weights=original_weights,
)
del lora_model
yield
finally:
with torch.no_grad():
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
@classmethod
@contextmanager
def apply_ti(
@@ -282,26 +179,6 @@ class ModelPatcher:
class ONNXModelPatcher:
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: OnnxRuntimeModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
# based on
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
@classmethod

View File

@@ -1,14 +1,13 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING
import torch
from diffusers import UNet2DConditionModel
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.model import ModelIdentifierField
@@ -31,107 +30,14 @@ class LoRAExt(ExtensionBase):
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
assert isinstance(lora_model, LoRAModelRaw)
LoRAPatcher.apply_lora_patch(
model=unet,
prefix="lora_unet_",
patch=lora_model,
patch_weight=self._weight,
original_weights=original_weights,
)
del lora_model
yield
@classmethod
@torch.no_grad()
def patch_model(
cls,
model: torch.nn.Module,
prefix: str,
lora: LoRAModelRaw,
lora_weight: float,
original_weights: OriginalWeightsStorage,
):
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param lora: LoRA model to patch in.
:param lora_weight: LoRA patch weight.
:param prefix: A string prefix that precedes keys used in the LoRA's weight layers.
:param original_weights: Storage for the original weights of the model, populated with the weights that the LoRA patches; used for unpatching.
"""
if lora_weight == 0:
return
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# save original weight
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
# TODO: debug on lycoris
lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= lora_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)

View File

@@ -10,6 +10,7 @@ from transformers import CLIPTokenizer
from typing_extensions import Self
from invokeai.backend.raw_model import RawModel
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
class TextualInversionModelRaw(RawModel):
@@ -74,11 +75,7 @@ class TextualInversionModelRaw(RawModel):
def calc_size(self) -> int:
"""Get the size of this model in bytes."""
return calc_tensors_size([self.embedding, self.embedding_2])
class TextualInversionManager(BaseTextualInversionManager):

View File

@@ -0,0 +1,11 @@
import torch
def calc_tensor_size(t: torch.Tensor) -> int:
"""Calculate the size of a tensor in bytes."""
return t.nelement() * t.element_size()
def calc_tensors_size(tensors: list[torch.Tensor | None]) -> int:
"""Calculate the size of a list of tensors in bytes."""
return sum(calc_tensor_size(t) for t in tensors if t is not None)
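# For example, a (1000, 1000) float16 tensor occupies 1000 * 1000 * 2 = 2,000,000 bytes; None
# entries in the list (e.g. a missing embedding_2) are skipped.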

View File

@@ -0,0 +1,993 @@
# A sample state dict in the Diffusers FLUX LoRA format.
# These keys are based on the LoRA model here:
# https://civitai.com/models/200255/hands-xl-sd-15-flux1-dev?modelVersionId=781855
state_dict_keys = [
"transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.0.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.0.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.0.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.0.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.0.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.1.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.1.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.1.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.1.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.1.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.1.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.1.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.1.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.1.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.1.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.1.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.1.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.10.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.10.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.10.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.10.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.10.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.10.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.10.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.10.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.10.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.10.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.10.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.10.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.11.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.11.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.11.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.11.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.11.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.11.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.11.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.11.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.11.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.11.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.11.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.11.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.12.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.12.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.12.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.12.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.12.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.12.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.12.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.12.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.12.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.12.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.12.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.12.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.13.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.13.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.13.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.13.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.13.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.13.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.13.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.13.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.13.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.13.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.13.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.13.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.14.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.14.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.14.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.14.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.14.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.14.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.14.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.14.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.14.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.14.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.14.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.14.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.15.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.15.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.15.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.15.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.15.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.15.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.15.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.15.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.15.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.15.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.15.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.15.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.16.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.16.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.16.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.16.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.16.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.16.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.16.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.16.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.16.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.16.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.16.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.16.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.17.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.17.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.17.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.17.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.17.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.17.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.17.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.17.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.17.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.17.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.17.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.17.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.18.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.18.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.18.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.18.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.18.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.18.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.18.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.18.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.18.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.18.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.18.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.18.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.19.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.19.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.19.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.19.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.19.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.19.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.19.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.19.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.19.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.19.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.19.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.19.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.2.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.2.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.2.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.2.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.2.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.2.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.2.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.2.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.2.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.2.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.2.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.2.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.20.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.20.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.20.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.20.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.20.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.20.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.20.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.20.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.20.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.20.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.20.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.20.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.21.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.21.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.21.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.21.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.21.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.21.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.21.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.21.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.21.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.21.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.21.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.21.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.22.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.22.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.22.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.22.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.22.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.22.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.22.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.22.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.22.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.22.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.22.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.22.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.23.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.23.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.23.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.23.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.23.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.23.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.23.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.23.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.23.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.23.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.23.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.23.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.24.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.24.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.24.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.24.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.24.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.24.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.24.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.24.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.24.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.24.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.24.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.24.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.25.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.25.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.25.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.25.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.25.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.25.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.25.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.25.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.25.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.25.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.25.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.25.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.26.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.26.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.26.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.26.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.26.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.26.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.26.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.26.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.26.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.26.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.26.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.26.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.27.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.27.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.27.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.27.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.27.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.27.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.27.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.27.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.27.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.27.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.27.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.27.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.28.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.28.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.28.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.28.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.28.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.28.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.28.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.28.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.28.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.28.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.28.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.28.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.29.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.29.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.29.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.29.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.29.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.29.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.29.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.29.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.29.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.29.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.29.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.29.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.3.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.3.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.3.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.3.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.3.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.3.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.3.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.3.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.3.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.3.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.3.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.3.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.30.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.30.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.30.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.30.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.30.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.30.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.30.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.30.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.30.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.30.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.30.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.30.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.31.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.31.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.31.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.31.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.31.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.31.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.31.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.31.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.31.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.31.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.31.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.31.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.32.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.32.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.32.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.32.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.32.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.32.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.32.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.32.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.32.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.32.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.32.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.32.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.33.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.33.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.33.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.33.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.33.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.33.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.33.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.33.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.33.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.33.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.33.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.33.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.34.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.34.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.34.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.34.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.34.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.34.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.34.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.34.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.34.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.34.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.34.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.34.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.35.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.35.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.35.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.35.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.35.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.35.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.35.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.35.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.35.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.35.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.35.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.35.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.36.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.36.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.36.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.36.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.36.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.36.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.36.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.36.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.36.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.36.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.36.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.36.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.37.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.37.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.37.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.37.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.37.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.37.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.37.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.37.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.37.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.37.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.37.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.37.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.4.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.4.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.4.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.4.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.4.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.4.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.4.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.4.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.4.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.4.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.4.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.4.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.5.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.5.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.5.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.5.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.5.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.5.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.5.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.5.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.5.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.5.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.5.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.5.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.6.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.6.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.6.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.6.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.6.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.6.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.6.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.6.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.6.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.6.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.6.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.6.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.7.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.7.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.7.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.7.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.7.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.7.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.7.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.7.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.7.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.7.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.7.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.7.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.8.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.8.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.8.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.8.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.8.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.8.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.8.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.8.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.8.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.8.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.8.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.8.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.9.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.9.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.9.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.9.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.9.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.9.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.9.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.9.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.9.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.9.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.9.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.9.proj_out.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.0.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.0.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.0.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.0.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.0.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.0.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.0.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.0.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.0.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.0.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.0.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.1.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.1.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.1.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.1.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.1.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.1.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.1.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.1.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.1.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.1.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.1.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.1.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.1.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.1.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.1.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.1.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.1.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.1.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.10.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.10.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.10.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.10.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.10.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.10.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.10.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.10.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.10.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.10.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.10.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.10.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.10.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.10.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.10.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.10.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.10.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.10.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.11.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.11.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.11.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.11.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.11.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.11.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.11.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.11.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.11.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.11.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.11.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.11.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.11.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.11.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.11.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.11.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.11.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.11.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.12.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.12.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.12.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.12.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.12.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.12.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.12.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.12.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.12.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.12.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.12.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.12.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.12.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.12.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.12.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.12.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.12.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.12.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.13.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.13.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.13.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.13.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.13.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.13.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.13.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.13.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.13.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.13.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.13.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.13.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.13.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.13.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.13.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.13.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.13.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.13.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.14.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.14.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.14.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.14.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.14.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.14.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.14.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.14.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.14.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.14.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.14.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.14.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.14.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.14.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.14.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.14.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.14.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.14.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.15.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.15.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.15.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.15.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.15.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.15.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.15.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.15.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.15.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.15.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.15.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.15.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.15.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.15.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.15.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.15.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.15.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.15.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.16.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.16.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.16.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.16.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.16.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.16.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.16.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.16.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.16.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.16.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.16.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.16.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.16.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.16.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.16.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.16.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.16.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.16.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.17.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.17.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.17.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.17.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.17.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.17.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.17.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.17.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.17.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.17.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.17.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.17.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.17.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.17.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.17.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.17.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.17.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.17.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.18.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.18.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.18.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.18.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.18.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.18.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.18.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.18.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.18.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.18.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.18.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.18.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.18.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.18.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.18.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.18.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.18.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.18.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.2.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.2.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.2.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.2.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.2.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.2.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.2.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.2.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.2.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.2.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.2.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.2.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.2.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.2.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.2.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.2.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.2.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.2.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.3.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.3.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.3.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.3.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.3.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.3.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.3.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.3.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.3.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.3.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.3.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.3.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.3.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.3.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.3.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.3.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.3.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.3.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.4.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.4.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.4.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.4.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.4.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.4.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.4.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.4.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.4.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.4.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.4.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.4.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.4.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.4.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.4.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.4.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.4.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.4.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.5.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.5.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.5.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.5.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.5.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.5.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.5.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.5.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.5.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.5.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.5.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.5.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.5.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.5.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.5.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.5.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.5.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.5.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.6.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.6.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.6.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.6.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.6.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.6.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.6.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.6.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.6.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.6.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.6.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.6.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.6.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.6.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.6.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.6.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.6.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.6.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.7.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.7.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.7.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.7.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.7.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.7.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.7.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.7.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.7.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.7.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.7.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.7.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.7.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.7.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.7.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.7.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.7.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.7.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.8.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.8.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.8.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.8.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.8.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.8.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.8.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.8.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.8.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.8.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.8.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.8.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.8.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.8.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.8.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.8.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.8.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.8.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.9.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.9.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.9.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.9.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.9.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.9.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.9.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.9.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.9.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.9.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.9.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.9.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.9.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.9.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.9.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.9.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.9.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.9.norm1_context.linear.lora_B.weight",
]
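
The fixture above lists diffusers-format keys (paired `transformer.….lora_A.weight` / `….lora_B.weight` entries, with no `alpha` keys), while the fixture that follows lists kohya-format keys (`lora_unet_….lora_down.weight` / `….lora_up.weight`, plus one `.alpha` per layer). As a minimal sketch of how the two formats can be told apart from their keys alone — the function name is hypothetical, and this is not the probe logic this PR actually ships:

```python
# Illustrative sketch only -- not the actual format probe in invokeai/backend/lora.
def guess_flux_lora_format(keys: list[str]) -> str:
    """Classify a FLUX LoRA state dict as 'diffusers' or 'kohya' from its keys."""
    # diffusers: transformer.<module path>.lora_A.weight / .lora_B.weight
    if all(k.startswith("transformer.") for k in keys):
        return "diffusers"
    # kohya: lora_unet_<flattened module name>.{lora_down,lora_up}.weight / .alpha
    if all(k.startswith("lora_unet_") for k in keys):
        return "kohya"
    raise ValueError("Unrecognized FLUX LoRA key format")
```

Applied to these fixtures, `guess_flux_lora_format(state_dict_keys)` would return `"diffusers"` for the key list above and `"kohya"` for the one below.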

View File

@@ -0,0 +1,917 @@
# A sample state dict in the Kohya FLUX LoRA format.
# These keys are based on the LoRA model here:
# https://civitai.com/models/159333/pokemon-trainer-sprite-pixelart?modelVersionId=779247
state_dict_keys = [
"lora_unet_double_blocks_0_img_attn_proj.alpha",
"lora_unet_double_blocks_0_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_0_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_0_img_attn_qkv.alpha",
"lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_0_img_mlp_0.alpha",
"lora_unet_double_blocks_0_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_0_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_0_img_mlp_2.alpha",
"lora_unet_double_blocks_0_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_0_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_0_img_mod_lin.alpha",
"lora_unet_double_blocks_0_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_0_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_0_txt_attn_proj.alpha",
"lora_unet_double_blocks_0_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_0_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_0_txt_attn_qkv.alpha",
"lora_unet_double_blocks_0_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_0_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_0_txt_mlp_0.alpha",
"lora_unet_double_blocks_0_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_0_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_0_txt_mlp_2.alpha",
"lora_unet_double_blocks_0_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_0_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_0_txt_mod_lin.alpha",
"lora_unet_double_blocks_0_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_0_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_10_img_attn_proj.alpha",
"lora_unet_double_blocks_10_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_10_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_10_img_attn_qkv.alpha",
"lora_unet_double_blocks_10_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_10_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_10_img_mlp_0.alpha",
"lora_unet_double_blocks_10_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_10_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_10_img_mlp_2.alpha",
"lora_unet_double_blocks_10_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_10_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_10_img_mod_lin.alpha",
"lora_unet_double_blocks_10_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_10_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_10_txt_attn_proj.alpha",
"lora_unet_double_blocks_10_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_10_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_10_txt_attn_qkv.alpha",
"lora_unet_double_blocks_10_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_10_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_10_txt_mlp_0.alpha",
"lora_unet_double_blocks_10_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_10_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_10_txt_mlp_2.alpha",
"lora_unet_double_blocks_10_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_10_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_10_txt_mod_lin.alpha",
"lora_unet_double_blocks_10_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_10_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_11_img_attn_proj.alpha",
"lora_unet_double_blocks_11_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_11_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_11_img_attn_qkv.alpha",
"lora_unet_double_blocks_11_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_11_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_11_img_mlp_0.alpha",
"lora_unet_double_blocks_11_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_11_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_11_img_mlp_2.alpha",
"lora_unet_double_blocks_11_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_11_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_11_img_mod_lin.alpha",
"lora_unet_double_blocks_11_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_11_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_11_txt_attn_proj.alpha",
"lora_unet_double_blocks_11_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_11_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_11_txt_attn_qkv.alpha",
"lora_unet_double_blocks_11_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_11_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_11_txt_mlp_0.alpha",
"lora_unet_double_blocks_11_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_11_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_11_txt_mlp_2.alpha",
"lora_unet_double_blocks_11_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_11_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_11_txt_mod_lin.alpha",
"lora_unet_double_blocks_11_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_11_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_12_img_attn_proj.alpha",
"lora_unet_double_blocks_12_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_12_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_12_img_attn_qkv.alpha",
"lora_unet_double_blocks_12_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_12_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_12_img_mlp_0.alpha",
"lora_unet_double_blocks_12_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_12_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_12_img_mlp_2.alpha",
"lora_unet_double_blocks_12_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_12_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_12_img_mod_lin.alpha",
"lora_unet_double_blocks_12_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_12_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_12_txt_attn_proj.alpha",
"lora_unet_double_blocks_12_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_12_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_12_txt_attn_qkv.alpha",
"lora_unet_double_blocks_12_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_12_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_12_txt_mlp_0.alpha",
"lora_unet_double_blocks_12_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_12_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_12_txt_mlp_2.alpha",
"lora_unet_double_blocks_12_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_12_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_12_txt_mod_lin.alpha",
"lora_unet_double_blocks_12_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_12_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_13_img_attn_proj.alpha",
"lora_unet_double_blocks_13_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_13_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_13_img_attn_qkv.alpha",
"lora_unet_double_blocks_13_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_13_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_13_img_mlp_0.alpha",
"lora_unet_double_blocks_13_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_13_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_13_img_mlp_2.alpha",
"lora_unet_double_blocks_13_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_13_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_13_img_mod_lin.alpha",
"lora_unet_double_blocks_13_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_13_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_13_txt_attn_proj.alpha",
"lora_unet_double_blocks_13_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_13_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_13_txt_attn_qkv.alpha",
"lora_unet_double_blocks_13_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_13_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_13_txt_mlp_0.alpha",
"lora_unet_double_blocks_13_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_13_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_13_txt_mlp_2.alpha",
"lora_unet_double_blocks_13_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_13_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_13_txt_mod_lin.alpha",
"lora_unet_double_blocks_13_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_13_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_14_img_attn_proj.alpha",
"lora_unet_double_blocks_14_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_14_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_14_img_attn_qkv.alpha",
"lora_unet_double_blocks_14_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_14_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_14_img_mlp_0.alpha",
"lora_unet_double_blocks_14_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_14_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_14_img_mlp_2.alpha",
"lora_unet_double_blocks_14_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_14_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_14_img_mod_lin.alpha",
"lora_unet_double_blocks_14_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_14_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_14_txt_attn_proj.alpha",
"lora_unet_double_blocks_14_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_14_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_14_txt_attn_qkv.alpha",
"lora_unet_double_blocks_14_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_14_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_14_txt_mlp_0.alpha",
"lora_unet_double_blocks_14_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_14_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_14_txt_mlp_2.alpha",
"lora_unet_double_blocks_14_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_14_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_14_txt_mod_lin.alpha",
"lora_unet_double_blocks_14_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_14_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_15_img_attn_proj.alpha",
"lora_unet_double_blocks_15_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_15_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_15_img_attn_qkv.alpha",
"lora_unet_double_blocks_15_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_15_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_15_img_mlp_0.alpha",
"lora_unet_double_blocks_15_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_15_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_15_img_mlp_2.alpha",
"lora_unet_double_blocks_15_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_15_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_15_img_mod_lin.alpha",
"lora_unet_double_blocks_15_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_15_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_15_txt_attn_proj.alpha",
"lora_unet_double_blocks_15_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_15_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_15_txt_attn_qkv.alpha",
"lora_unet_double_blocks_15_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_15_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_15_txt_mlp_0.alpha",
"lora_unet_double_blocks_15_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_15_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_15_txt_mlp_2.alpha",
"lora_unet_double_blocks_15_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_15_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_15_txt_mod_lin.alpha",
"lora_unet_double_blocks_15_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_15_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_16_img_attn_proj.alpha",
"lora_unet_double_blocks_16_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_16_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_16_img_attn_qkv.alpha",
"lora_unet_double_blocks_16_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_16_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_16_img_mlp_0.alpha",
"lora_unet_double_blocks_16_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_16_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_16_img_mlp_2.alpha",
"lora_unet_double_blocks_16_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_16_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_16_img_mod_lin.alpha",
"lora_unet_double_blocks_16_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_16_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_16_txt_attn_proj.alpha",
"lora_unet_double_blocks_16_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_16_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_16_txt_attn_qkv.alpha",
"lora_unet_double_blocks_16_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_16_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_16_txt_mlp_0.alpha",
"lora_unet_double_blocks_16_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_16_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_16_txt_mlp_2.alpha",
"lora_unet_double_blocks_16_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_16_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_16_txt_mod_lin.alpha",
"lora_unet_double_blocks_16_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_16_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_17_img_attn_proj.alpha",
"lora_unet_double_blocks_17_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_17_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_17_img_attn_qkv.alpha",
"lora_unet_double_blocks_17_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_17_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_17_img_mlp_0.alpha",
"lora_unet_double_blocks_17_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_17_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_17_img_mlp_2.alpha",
"lora_unet_double_blocks_17_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_17_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_17_img_mod_lin.alpha",
"lora_unet_double_blocks_17_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_17_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_17_txt_attn_proj.alpha",
"lora_unet_double_blocks_17_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_17_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_17_txt_attn_qkv.alpha",
"lora_unet_double_blocks_17_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_17_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_17_txt_mlp_0.alpha",
"lora_unet_double_blocks_17_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_17_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_17_txt_mlp_2.alpha",
"lora_unet_double_blocks_17_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_17_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_17_txt_mod_lin.alpha",
"lora_unet_double_blocks_17_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_17_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_18_img_attn_proj.alpha",
"lora_unet_double_blocks_18_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_18_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_18_img_attn_qkv.alpha",
"lora_unet_double_blocks_18_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_18_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_18_img_mlp_0.alpha",
"lora_unet_double_blocks_18_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_18_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_18_img_mlp_2.alpha",
"lora_unet_double_blocks_18_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_18_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_18_img_mod_lin.alpha",
"lora_unet_double_blocks_18_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_18_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_18_txt_attn_proj.alpha",
"lora_unet_double_blocks_18_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_18_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_18_txt_attn_qkv.alpha",
"lora_unet_double_blocks_18_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_18_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_18_txt_mlp_0.alpha",
"lora_unet_double_blocks_18_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_18_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_18_txt_mlp_2.alpha",
"lora_unet_double_blocks_18_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_18_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_18_txt_mod_lin.alpha",
"lora_unet_double_blocks_18_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_18_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_1_img_attn_proj.alpha",
"lora_unet_double_blocks_1_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_1_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_1_img_attn_qkv.alpha",
"lora_unet_double_blocks_1_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_1_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_1_img_mlp_0.alpha",
"lora_unet_double_blocks_1_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_1_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_1_img_mlp_2.alpha",
"lora_unet_double_blocks_1_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_1_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_1_img_mod_lin.alpha",
"lora_unet_double_blocks_1_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_1_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_1_txt_attn_proj.alpha",
"lora_unet_double_blocks_1_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_1_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_1_txt_attn_qkv.alpha",
"lora_unet_double_blocks_1_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_1_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_1_txt_mlp_0.alpha",
"lora_unet_double_blocks_1_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_1_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_1_txt_mlp_2.alpha",
"lora_unet_double_blocks_1_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_1_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_1_txt_mod_lin.alpha",
"lora_unet_double_blocks_1_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_1_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_2_img_attn_proj.alpha",
"lora_unet_double_blocks_2_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_2_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_2_img_attn_qkv.alpha",
"lora_unet_double_blocks_2_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_2_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_2_img_mlp_0.alpha",
"lora_unet_double_blocks_2_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_2_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_2_img_mlp_2.alpha",
"lora_unet_double_blocks_2_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_2_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_2_img_mod_lin.alpha",
"lora_unet_double_blocks_2_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_2_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_2_txt_attn_proj.alpha",
"lora_unet_double_blocks_2_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_2_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_2_txt_attn_qkv.alpha",
"lora_unet_double_blocks_2_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_2_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_2_txt_mlp_0.alpha",
"lora_unet_double_blocks_2_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_2_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_2_txt_mlp_2.alpha",
"lora_unet_double_blocks_2_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_2_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_2_txt_mod_lin.alpha",
"lora_unet_double_blocks_2_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_2_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_3_img_attn_proj.alpha",
"lora_unet_double_blocks_3_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_3_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_3_img_attn_qkv.alpha",
"lora_unet_double_blocks_3_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_3_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_3_img_mlp_0.alpha",
"lora_unet_double_blocks_3_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_3_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_3_img_mlp_2.alpha",
"lora_unet_double_blocks_3_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_3_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_3_img_mod_lin.alpha",
"lora_unet_double_blocks_3_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_3_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_3_txt_attn_proj.alpha",
"lora_unet_double_blocks_3_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_3_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_3_txt_attn_qkv.alpha",
"lora_unet_double_blocks_3_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_3_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_3_txt_mlp_0.alpha",
"lora_unet_double_blocks_3_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_3_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_3_txt_mlp_2.alpha",
"lora_unet_double_blocks_3_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_3_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_3_txt_mod_lin.alpha",
"lora_unet_double_blocks_3_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_3_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_4_img_attn_proj.alpha",
"lora_unet_double_blocks_4_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_4_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_4_img_attn_qkv.alpha",
"lora_unet_double_blocks_4_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_4_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_4_img_mlp_0.alpha",
"lora_unet_double_blocks_4_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_4_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_4_img_mlp_2.alpha",
"lora_unet_double_blocks_4_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_4_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_4_img_mod_lin.alpha",
"lora_unet_double_blocks_4_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_4_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_4_txt_attn_proj.alpha",
"lora_unet_double_blocks_4_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_4_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_4_txt_attn_qkv.alpha",
"lora_unet_double_blocks_4_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_4_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_4_txt_mlp_0.alpha",
"lora_unet_double_blocks_4_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_4_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_4_txt_mlp_2.alpha",
"lora_unet_double_blocks_4_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_4_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_4_txt_mod_lin.alpha",
"lora_unet_double_blocks_4_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_4_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_5_img_attn_proj.alpha",
"lora_unet_double_blocks_5_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_5_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_5_img_attn_qkv.alpha",
"lora_unet_double_blocks_5_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_5_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_5_img_mlp_0.alpha",
"lora_unet_double_blocks_5_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_5_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_5_img_mlp_2.alpha",
"lora_unet_double_blocks_5_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_5_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_5_img_mod_lin.alpha",
"lora_unet_double_blocks_5_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_5_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_5_txt_attn_proj.alpha",
"lora_unet_double_blocks_5_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_5_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_5_txt_attn_qkv.alpha",
"lora_unet_double_blocks_5_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_5_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_5_txt_mlp_0.alpha",
"lora_unet_double_blocks_5_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_5_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_5_txt_mlp_2.alpha",
"lora_unet_double_blocks_5_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_5_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_5_txt_mod_lin.alpha",
"lora_unet_double_blocks_5_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_5_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_6_img_attn_proj.alpha",
"lora_unet_double_blocks_6_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_6_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_6_img_attn_qkv.alpha",
"lora_unet_double_blocks_6_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_6_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_6_img_mlp_0.alpha",
"lora_unet_double_blocks_6_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_6_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_6_img_mlp_2.alpha",
"lora_unet_double_blocks_6_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_6_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_6_img_mod_lin.alpha",
"lora_unet_double_blocks_6_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_6_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_6_txt_attn_proj.alpha",
"lora_unet_double_blocks_6_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_6_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_6_txt_attn_qkv.alpha",
"lora_unet_double_blocks_6_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_6_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_6_txt_mlp_0.alpha",
"lora_unet_double_blocks_6_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_6_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_6_txt_mlp_2.alpha",
"lora_unet_double_blocks_6_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_6_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_6_txt_mod_lin.alpha",
"lora_unet_double_blocks_6_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_6_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_7_img_attn_proj.alpha",
"lora_unet_double_blocks_7_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_7_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_7_img_attn_qkv.alpha",
"lora_unet_double_blocks_7_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_7_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_7_img_mlp_0.alpha",
"lora_unet_double_blocks_7_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_7_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_7_img_mlp_2.alpha",
"lora_unet_double_blocks_7_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_7_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_7_img_mod_lin.alpha",
"lora_unet_double_blocks_7_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_7_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_7_txt_attn_proj.alpha",
"lora_unet_double_blocks_7_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_7_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_7_txt_attn_qkv.alpha",
"lora_unet_double_blocks_7_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_7_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_7_txt_mlp_0.alpha",
"lora_unet_double_blocks_7_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_7_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_7_txt_mlp_2.alpha",
"lora_unet_double_blocks_7_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_7_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_7_txt_mod_lin.alpha",
"lora_unet_double_blocks_7_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_7_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_8_img_attn_proj.alpha",
"lora_unet_double_blocks_8_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_8_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_8_img_attn_qkv.alpha",
"lora_unet_double_blocks_8_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_8_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_8_img_mlp_0.alpha",
"lora_unet_double_blocks_8_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_8_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_8_img_mlp_2.alpha",
"lora_unet_double_blocks_8_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_8_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_8_img_mod_lin.alpha",
"lora_unet_double_blocks_8_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_8_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_8_txt_attn_proj.alpha",
"lora_unet_double_blocks_8_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_8_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_8_txt_attn_qkv.alpha",
"lora_unet_double_blocks_8_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_8_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_8_txt_mlp_0.alpha",
"lora_unet_double_blocks_8_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_8_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_8_txt_mlp_2.alpha",
"lora_unet_double_blocks_8_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_8_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_8_txt_mod_lin.alpha",
"lora_unet_double_blocks_8_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_8_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_9_img_attn_proj.alpha",
"lora_unet_double_blocks_9_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_9_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_9_img_attn_qkv.alpha",
"lora_unet_double_blocks_9_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_9_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_9_img_mlp_0.alpha",
"lora_unet_double_blocks_9_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_9_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_9_img_mlp_2.alpha",
"lora_unet_double_blocks_9_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_9_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_9_img_mod_lin.alpha",
"lora_unet_double_blocks_9_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_9_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_9_txt_attn_proj.alpha",
"lora_unet_double_blocks_9_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_9_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_9_txt_attn_qkv.alpha",
"lora_unet_double_blocks_9_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_9_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_9_txt_mlp_0.alpha",
"lora_unet_double_blocks_9_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_9_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_9_txt_mlp_2.alpha",
"lora_unet_double_blocks_9_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_9_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_9_txt_mod_lin.alpha",
"lora_unet_double_blocks_9_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_9_txt_mod_lin.lora_up.weight",
"lora_unet_single_blocks_0_linear1.alpha",
"lora_unet_single_blocks_0_linear1.lora_down.weight",
"lora_unet_single_blocks_0_linear1.lora_up.weight",
"lora_unet_single_blocks_0_linear2.alpha",
"lora_unet_single_blocks_0_linear2.lora_down.weight",
"lora_unet_single_blocks_0_linear2.lora_up.weight",
"lora_unet_single_blocks_0_modulation_lin.alpha",
"lora_unet_single_blocks_0_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_0_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_10_linear1.alpha",
"lora_unet_single_blocks_10_linear1.lora_down.weight",
"lora_unet_single_blocks_10_linear1.lora_up.weight",
"lora_unet_single_blocks_10_linear2.alpha",
"lora_unet_single_blocks_10_linear2.lora_down.weight",
"lora_unet_single_blocks_10_linear2.lora_up.weight",
"lora_unet_single_blocks_10_modulation_lin.alpha",
"lora_unet_single_blocks_10_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_10_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_11_linear1.alpha",
"lora_unet_single_blocks_11_linear1.lora_down.weight",
"lora_unet_single_blocks_11_linear1.lora_up.weight",
"lora_unet_single_blocks_11_linear2.alpha",
"lora_unet_single_blocks_11_linear2.lora_down.weight",
"lora_unet_single_blocks_11_linear2.lora_up.weight",
"lora_unet_single_blocks_11_modulation_lin.alpha",
"lora_unet_single_blocks_11_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_11_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_12_linear1.alpha",
"lora_unet_single_blocks_12_linear1.lora_down.weight",
"lora_unet_single_blocks_12_linear1.lora_up.weight",
"lora_unet_single_blocks_12_linear2.alpha",
"lora_unet_single_blocks_12_linear2.lora_down.weight",
"lora_unet_single_blocks_12_linear2.lora_up.weight",
"lora_unet_single_blocks_12_modulation_lin.alpha",
"lora_unet_single_blocks_12_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_12_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_13_linear1.alpha",
"lora_unet_single_blocks_13_linear1.lora_down.weight",
"lora_unet_single_blocks_13_linear1.lora_up.weight",
"lora_unet_single_blocks_13_linear2.alpha",
"lora_unet_single_blocks_13_linear2.lora_down.weight",
"lora_unet_single_blocks_13_linear2.lora_up.weight",
"lora_unet_single_blocks_13_modulation_lin.alpha",
"lora_unet_single_blocks_13_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_13_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_14_linear1.alpha",
"lora_unet_single_blocks_14_linear1.lora_down.weight",
"lora_unet_single_blocks_14_linear1.lora_up.weight",
"lora_unet_single_blocks_14_linear2.alpha",
"lora_unet_single_blocks_14_linear2.lora_down.weight",
"lora_unet_single_blocks_14_linear2.lora_up.weight",
"lora_unet_single_blocks_14_modulation_lin.alpha",
"lora_unet_single_blocks_14_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_14_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_15_linear1.alpha",
"lora_unet_single_blocks_15_linear1.lora_down.weight",
"lora_unet_single_blocks_15_linear1.lora_up.weight",
"lora_unet_single_blocks_15_linear2.alpha",
"lora_unet_single_blocks_15_linear2.lora_down.weight",
"lora_unet_single_blocks_15_linear2.lora_up.weight",
"lora_unet_single_blocks_15_modulation_lin.alpha",
"lora_unet_single_blocks_15_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_15_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_16_linear1.alpha",
"lora_unet_single_blocks_16_linear1.lora_down.weight",
"lora_unet_single_blocks_16_linear1.lora_up.weight",
"lora_unet_single_blocks_16_linear2.alpha",
"lora_unet_single_blocks_16_linear2.lora_down.weight",
"lora_unet_single_blocks_16_linear2.lora_up.weight",
"lora_unet_single_blocks_16_modulation_lin.alpha",
"lora_unet_single_blocks_16_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_16_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_17_linear1.alpha",
"lora_unet_single_blocks_17_linear1.lora_down.weight",
"lora_unet_single_blocks_17_linear1.lora_up.weight",
"lora_unet_single_blocks_17_linear2.alpha",
"lora_unet_single_blocks_17_linear2.lora_down.weight",
"lora_unet_single_blocks_17_linear2.lora_up.weight",
"lora_unet_single_blocks_17_modulation_lin.alpha",
"lora_unet_single_blocks_17_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_17_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_18_linear1.alpha",
"lora_unet_single_blocks_18_linear1.lora_down.weight",
"lora_unet_single_blocks_18_linear1.lora_up.weight",
"lora_unet_single_blocks_18_linear2.alpha",
"lora_unet_single_blocks_18_linear2.lora_down.weight",
"lora_unet_single_blocks_18_linear2.lora_up.weight",
"lora_unet_single_blocks_18_modulation_lin.alpha",
"lora_unet_single_blocks_18_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_18_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_19_linear1.alpha",
"lora_unet_single_blocks_19_linear1.lora_down.weight",
"lora_unet_single_blocks_19_linear1.lora_up.weight",
"lora_unet_single_blocks_19_linear2.alpha",
"lora_unet_single_blocks_19_linear2.lora_down.weight",
"lora_unet_single_blocks_19_linear2.lora_up.weight",
"lora_unet_single_blocks_19_modulation_lin.alpha",
"lora_unet_single_blocks_19_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_19_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_1_linear1.alpha",
"lora_unet_single_blocks_1_linear1.lora_down.weight",
"lora_unet_single_blocks_1_linear1.lora_up.weight",
"lora_unet_single_blocks_1_linear2.alpha",
"lora_unet_single_blocks_1_linear2.lora_down.weight",
"lora_unet_single_blocks_1_linear2.lora_up.weight",
"lora_unet_single_blocks_1_modulation_lin.alpha",
"lora_unet_single_blocks_1_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_1_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_20_linear1.alpha",
"lora_unet_single_blocks_20_linear1.lora_down.weight",
"lora_unet_single_blocks_20_linear1.lora_up.weight",
"lora_unet_single_blocks_20_linear2.alpha",
"lora_unet_single_blocks_20_linear2.lora_down.weight",
"lora_unet_single_blocks_20_linear2.lora_up.weight",
"lora_unet_single_blocks_20_modulation_lin.alpha",
"lora_unet_single_blocks_20_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_20_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_21_linear1.alpha",
"lora_unet_single_blocks_21_linear1.lora_down.weight",
"lora_unet_single_blocks_21_linear1.lora_up.weight",
"lora_unet_single_blocks_21_linear2.alpha",
"lora_unet_single_blocks_21_linear2.lora_down.weight",
"lora_unet_single_blocks_21_linear2.lora_up.weight",
"lora_unet_single_blocks_21_modulation_lin.alpha",
"lora_unet_single_blocks_21_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_21_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_22_linear1.alpha",
"lora_unet_single_blocks_22_linear1.lora_down.weight",
"lora_unet_single_blocks_22_linear1.lora_up.weight",
"lora_unet_single_blocks_22_linear2.alpha",
"lora_unet_single_blocks_22_linear2.lora_down.weight",
"lora_unet_single_blocks_22_linear2.lora_up.weight",
"lora_unet_single_blocks_22_modulation_lin.alpha",
"lora_unet_single_blocks_22_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_22_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_23_linear1.alpha",
"lora_unet_single_blocks_23_linear1.lora_down.weight",
"lora_unet_single_blocks_23_linear1.lora_up.weight",
"lora_unet_single_blocks_23_linear2.alpha",
"lora_unet_single_blocks_23_linear2.lora_down.weight",
"lora_unet_single_blocks_23_linear2.lora_up.weight",
"lora_unet_single_blocks_23_modulation_lin.alpha",
"lora_unet_single_blocks_23_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_23_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_24_linear1.alpha",
"lora_unet_single_blocks_24_linear1.lora_down.weight",
"lora_unet_single_blocks_24_linear1.lora_up.weight",
"lora_unet_single_blocks_24_linear2.alpha",
"lora_unet_single_blocks_24_linear2.lora_down.weight",
"lora_unet_single_blocks_24_linear2.lora_up.weight",
"lora_unet_single_blocks_24_modulation_lin.alpha",
"lora_unet_single_blocks_24_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_24_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_25_linear1.alpha",
"lora_unet_single_blocks_25_linear1.lora_down.weight",
"lora_unet_single_blocks_25_linear1.lora_up.weight",
"lora_unet_single_blocks_25_linear2.alpha",
"lora_unet_single_blocks_25_linear2.lora_down.weight",
"lora_unet_single_blocks_25_linear2.lora_up.weight",
"lora_unet_single_blocks_25_modulation_lin.alpha",
"lora_unet_single_blocks_25_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_25_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_26_linear1.alpha",
"lora_unet_single_blocks_26_linear1.lora_down.weight",
"lora_unet_single_blocks_26_linear1.lora_up.weight",
"lora_unet_single_blocks_26_linear2.alpha",
"lora_unet_single_blocks_26_linear2.lora_down.weight",
"lora_unet_single_blocks_26_linear2.lora_up.weight",
"lora_unet_single_blocks_26_modulation_lin.alpha",
"lora_unet_single_blocks_26_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_26_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_27_linear1.alpha",
"lora_unet_single_blocks_27_linear1.lora_down.weight",
"lora_unet_single_blocks_27_linear1.lora_up.weight",
"lora_unet_single_blocks_27_linear2.alpha",
"lora_unet_single_blocks_27_linear2.lora_down.weight",
"lora_unet_single_blocks_27_linear2.lora_up.weight",
"lora_unet_single_blocks_27_modulation_lin.alpha",
"lora_unet_single_blocks_27_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_27_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_28_linear1.alpha",
"lora_unet_single_blocks_28_linear1.lora_down.weight",
"lora_unet_single_blocks_28_linear1.lora_up.weight",
"lora_unet_single_blocks_28_linear2.alpha",
"lora_unet_single_blocks_28_linear2.lora_down.weight",
"lora_unet_single_blocks_28_linear2.lora_up.weight",
"lora_unet_single_blocks_28_modulation_lin.alpha",
"lora_unet_single_blocks_28_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_28_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_29_linear1.alpha",
"lora_unet_single_blocks_29_linear1.lora_down.weight",
"lora_unet_single_blocks_29_linear1.lora_up.weight",
"lora_unet_single_blocks_29_linear2.alpha",
"lora_unet_single_blocks_29_linear2.lora_down.weight",
"lora_unet_single_blocks_29_linear2.lora_up.weight",
"lora_unet_single_blocks_29_modulation_lin.alpha",
"lora_unet_single_blocks_29_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_29_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_2_linear1.alpha",
"lora_unet_single_blocks_2_linear1.lora_down.weight",
"lora_unet_single_blocks_2_linear1.lora_up.weight",
"lora_unet_single_blocks_2_linear2.alpha",
"lora_unet_single_blocks_2_linear2.lora_down.weight",
"lora_unet_single_blocks_2_linear2.lora_up.weight",
"lora_unet_single_blocks_2_modulation_lin.alpha",
"lora_unet_single_blocks_2_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_2_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_30_linear1.alpha",
"lora_unet_single_blocks_30_linear1.lora_down.weight",
"lora_unet_single_blocks_30_linear1.lora_up.weight",
"lora_unet_single_blocks_30_linear2.alpha",
"lora_unet_single_blocks_30_linear2.lora_down.weight",
"lora_unet_single_blocks_30_linear2.lora_up.weight",
"lora_unet_single_blocks_30_modulation_lin.alpha",
"lora_unet_single_blocks_30_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_30_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_31_linear1.alpha",
"lora_unet_single_blocks_31_linear1.lora_down.weight",
"lora_unet_single_blocks_31_linear1.lora_up.weight",
"lora_unet_single_blocks_31_linear2.alpha",
"lora_unet_single_blocks_31_linear2.lora_down.weight",
"lora_unet_single_blocks_31_linear2.lora_up.weight",
"lora_unet_single_blocks_31_modulation_lin.alpha",
"lora_unet_single_blocks_31_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_31_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_32_linear1.alpha",
"lora_unet_single_blocks_32_linear1.lora_down.weight",
"lora_unet_single_blocks_32_linear1.lora_up.weight",
"lora_unet_single_blocks_32_linear2.alpha",
"lora_unet_single_blocks_32_linear2.lora_down.weight",
"lora_unet_single_blocks_32_linear2.lora_up.weight",
"lora_unet_single_blocks_32_modulation_lin.alpha",
"lora_unet_single_blocks_32_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_32_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_33_linear1.alpha",
"lora_unet_single_blocks_33_linear1.lora_down.weight",
"lora_unet_single_blocks_33_linear1.lora_up.weight",
"lora_unet_single_blocks_33_linear2.alpha",
"lora_unet_single_blocks_33_linear2.lora_down.weight",
"lora_unet_single_blocks_33_linear2.lora_up.weight",
"lora_unet_single_blocks_33_modulation_lin.alpha",
"lora_unet_single_blocks_33_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_33_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_34_linear1.alpha",
"lora_unet_single_blocks_34_linear1.lora_down.weight",
"lora_unet_single_blocks_34_linear1.lora_up.weight",
"lora_unet_single_blocks_34_linear2.alpha",
"lora_unet_single_blocks_34_linear2.lora_down.weight",
"lora_unet_single_blocks_34_linear2.lora_up.weight",
"lora_unet_single_blocks_34_modulation_lin.alpha",
"lora_unet_single_blocks_34_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_34_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_35_linear1.alpha",
"lora_unet_single_blocks_35_linear1.lora_down.weight",
"lora_unet_single_blocks_35_linear1.lora_up.weight",
"lora_unet_single_blocks_35_linear2.alpha",
"lora_unet_single_blocks_35_linear2.lora_down.weight",
"lora_unet_single_blocks_35_linear2.lora_up.weight",
"lora_unet_single_blocks_35_modulation_lin.alpha",
"lora_unet_single_blocks_35_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_35_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_36_linear1.alpha",
"lora_unet_single_blocks_36_linear1.lora_down.weight",
"lora_unet_single_blocks_36_linear1.lora_up.weight",
"lora_unet_single_blocks_36_linear2.alpha",
"lora_unet_single_blocks_36_linear2.lora_down.weight",
"lora_unet_single_blocks_36_linear2.lora_up.weight",
"lora_unet_single_blocks_36_modulation_lin.alpha",
"lora_unet_single_blocks_36_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_36_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_37_linear1.alpha",
"lora_unet_single_blocks_37_linear1.lora_down.weight",
"lora_unet_single_blocks_37_linear1.lora_up.weight",
"lora_unet_single_blocks_37_linear2.alpha",
"lora_unet_single_blocks_37_linear2.lora_down.weight",
"lora_unet_single_blocks_37_linear2.lora_up.weight",
"lora_unet_single_blocks_37_modulation_lin.alpha",
"lora_unet_single_blocks_37_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_37_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_3_linear1.alpha",
"lora_unet_single_blocks_3_linear1.lora_down.weight",
"lora_unet_single_blocks_3_linear1.lora_up.weight",
"lora_unet_single_blocks_3_linear2.alpha",
"lora_unet_single_blocks_3_linear2.lora_down.weight",
"lora_unet_single_blocks_3_linear2.lora_up.weight",
"lora_unet_single_blocks_3_modulation_lin.alpha",
"lora_unet_single_blocks_3_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_3_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_4_linear1.alpha",
"lora_unet_single_blocks_4_linear1.lora_down.weight",
"lora_unet_single_blocks_4_linear1.lora_up.weight",
"lora_unet_single_blocks_4_linear2.alpha",
"lora_unet_single_blocks_4_linear2.lora_down.weight",
"lora_unet_single_blocks_4_linear2.lora_up.weight",
"lora_unet_single_blocks_4_modulation_lin.alpha",
"lora_unet_single_blocks_4_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_4_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_5_linear1.alpha",
"lora_unet_single_blocks_5_linear1.lora_down.weight",
"lora_unet_single_blocks_5_linear1.lora_up.weight",
"lora_unet_single_blocks_5_linear2.alpha",
"lora_unet_single_blocks_5_linear2.lora_down.weight",
"lora_unet_single_blocks_5_linear2.lora_up.weight",
"lora_unet_single_blocks_5_modulation_lin.alpha",
"lora_unet_single_blocks_5_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_5_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_6_linear1.alpha",
"lora_unet_single_blocks_6_linear1.lora_down.weight",
"lora_unet_single_blocks_6_linear1.lora_up.weight",
"lora_unet_single_blocks_6_linear2.alpha",
"lora_unet_single_blocks_6_linear2.lora_down.weight",
"lora_unet_single_blocks_6_linear2.lora_up.weight",
"lora_unet_single_blocks_6_modulation_lin.alpha",
"lora_unet_single_blocks_6_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_6_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_7_linear1.alpha",
"lora_unet_single_blocks_7_linear1.lora_down.weight",
"lora_unet_single_blocks_7_linear1.lora_up.weight",
"lora_unet_single_blocks_7_linear2.alpha",
"lora_unet_single_blocks_7_linear2.lora_down.weight",
"lora_unet_single_blocks_7_linear2.lora_up.weight",
"lora_unet_single_blocks_7_modulation_lin.alpha",
"lora_unet_single_blocks_7_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_7_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_8_linear1.alpha",
"lora_unet_single_blocks_8_linear1.lora_down.weight",
"lora_unet_single_blocks_8_linear1.lora_up.weight",
"lora_unet_single_blocks_8_linear2.alpha",
"lora_unet_single_blocks_8_linear2.lora_down.weight",
"lora_unet_single_blocks_8_linear2.lora_up.weight",
"lora_unet_single_blocks_8_modulation_lin.alpha",
"lora_unet_single_blocks_8_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_8_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_9_linear1.alpha",
"lora_unet_single_blocks_9_linear1.lora_down.weight",
"lora_unet_single_blocks_9_linear1.lora_up.weight",
"lora_unet_single_blocks_9_linear2.alpha",
"lora_unet_single_blocks_9_linear2.lora_down.weight",
"lora_unet_single_blocks_9_linear2.lora_up.weight",
"lora_unet_single_blocks_9_modulation_lin.alpha",
"lora_unet_single_blocks_9_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_9_modulation_lin.lora_up.weight",
]
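
Note for reviewers: every entry in the fixture above follows the same kohya naming scheme: a `lora_unet_` prefix, an underscore-joined module path into the FLUX transformer, and one of the `.alpha` / `.lora_down.weight` / `.lora_up.weight` suffixes. A minimal sketch of splitting a key into those parts (the helper name is illustrative, not part of this PR):

```python
# Illustrative only (not part of this PR): decompose a kohya FLUX LoRA key into
# its module path and parameter suffix.
def split_kohya_key(key: str) -> tuple[str, str]:
    module_path, _, suffix = key.partition(".")
    assert module_path.startswith("lora_unet_")
    return module_path[len("lora_unet_") :], suffix


# "lora_unet_double_blocks_14_txt_attn_proj.lora_down.weight"
# -> ("double_blocks_14_txt_attn_proj", "lora_down.weight")
print(split_kohya_key("lora_unet_double_blocks_14_txt_attn_proj.lora_down.weight"))
```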

View File

@@ -0,0 +1,8 @@
import torch


def keys_to_mock_state_dict(keys: list[str]) -> dict[str, torch.Tensor]:
    state_dict: dict[str, torch.Tensor] = {}
    for k in keys:
        state_dict[k] = torch.empty(1)
    return state_dict
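
The 1-element placeholder tensors are deliberate: the format-detection and conversion tests below only inspect key names, never weight values, so real weights aren't needed. A self-contained sketch of the same idea:

```python
import torch

# 1-element placeholder tensors are enough because the format-detection tests
# below never read weight values, only key names.
keys = ["lora_unet_single_blocks_0_linear1.alpha"]
mock_sd = {k: torch.empty(1) for k in keys}
assert set(mock_sd.keys()) == set(keys)
assert mock_sd[keys[0]].shape == (1,)
```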

View File

@@ -0,0 +1,66 @@
import pytest
import torch

from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
    is_state_dict_likely_in_flux_diffusers_format,
    lora_model_from_flux_diffusers_state_dict,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
    state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
    state_dict_keys as flux_kohya_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict


def test_is_state_dict_likely_in_flux_diffusers_format_true():
    """Test that is_state_dict_likely_in_flux_diffusers_format() can identify a state dict in the Diffusers FLUX LoRA format."""
    # Construct a state dict that is in the Diffusers FLUX LoRA format.
    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)

    assert is_state_dict_likely_in_flux_diffusers_format(state_dict)


def test_is_state_dict_likely_in_flux_diffusers_format_false():
    """Test that is_state_dict_likely_in_flux_diffusers_format() returns False for a state dict that is in the Kohya
    FLUX LoRA format.
    """
    # Construct a state dict that is in the Kohya FLUX LoRA format.
    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)

    assert not is_state_dict_likely_in_flux_diffusers_format(state_dict)


def test_lora_model_from_flux_diffusers_state_dict():
    """Test that lora_model_from_flux_diffusers_state_dict() can load a state dict in the Diffusers FLUX LoRA format."""
    # Construct a state dict that is in the Diffusers FLUX LoRA format.
    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)

    # Load the state dict into a LoRAModelRaw object.
    model = lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)

    # Check that the model has the correct number of LoRA layers.
    expected_lora_layers: set[str] = set()
    for k in flux_diffusers_state_dict_keys:
        k = k.replace("lora_A.weight", "")
        k = k.replace("lora_B.weight", "")
        expected_lora_layers.add(k)
    # Drop the K/V/proj_mlp weights because these are all concatenated into a single layer in the BFL format (we keep
    # the Q weights so that we count these layers once).
    concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
    expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
    assert len(model.layers) == len(expected_lora_layers)


def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
    """Test that lora_model_from_flux_diffusers_state_dict() raises an error if the input state_dict contains
    unexpected keys that we don't handle.
    """
    # Construct a state dict that is in the Diffusers FLUX LoRA format.
    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)

    # Add an unexpected key.
    state_dict["transformer.single_transformer_blocks.0.unexpected_key.lora_A.weight"] = torch.empty(1)

    # Check that an error is raised.
    with pytest.raises(AssertionError):
        lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)
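
The layer-counting step above is worth unpacking: in the BFL FLUX layout, the attention Q/K/V projections (and proj_mlp in the single blocks) are fused into one linear layer, so several Diffusers sub-layers collapse into a single LoRA layer after conversion. A standalone sketch of the same filtering logic, using illustrative key names:

```python
# Why to_k/to_v (etc.) are dropped when counting expected layers: in the BFL
# layout these land in the same fused linear as to_q, so only the to_q entry
# should count. The key names below are illustrative.
diffusers_layer_prefixes = [
    "transformer.transformer_blocks.0.attn.to_q.",
    "transformer.transformer_blocks.0.attn.to_k.",
    "transformer.transformer_blocks.0.attn.to_v.",
]
concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
kept = [k for k in diffusers_layer_prefixes if not any(w in k for w in concatenated_weights)]
assert kept == ["transformer.transformer_blocks.0.attn.to_q."]  # Fused layer counted once.
```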

View File

@@ -0,0 +1,99 @@
import accelerate
import pytest
import torch

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
    convert_flux_kohya_state_dict_to_invoke_format,
    is_state_dict_likely_in_flux_kohya_format,
    lora_model_from_flux_kohya_state_dict,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
    state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
    state_dict_keys as flux_kohya_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict


def test_is_state_dict_likely_in_flux_kohya_format_true():
    """Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
    # Construct a state dict that is in the Kohya FLUX LoRA format.
    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)

    assert is_state_dict_likely_in_flux_kohya_format(state_dict)


def test_is_state_dict_likely_in_flux_kohya_format_false():
    """Test that is_state_dict_likely_in_flux_kohya_format() returns False for a state dict that is in the Diffusers
    FLUX LoRA format.
    """
    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)

    assert not is_state_dict_likely_in_flux_kohya_format(state_dict)


def test_convert_flux_kohya_state_dict_to_invoke_format():
    # Construct state_dict from state_dict_keys.
    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)

    converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)

    # Extract the prefixes from the converted state dict (i.e. without the .lora_up.weight, .lora_down.weight, and
    # .alpha suffixes).
    converted_key_prefixes: list[str] = []
    for k in converted_state_dict.keys():
        k = k.replace(".lora_up.weight", "")
        k = k.replace(".lora_down.weight", "")
        k = k.replace(".alpha", "")
        converted_key_prefixes.append(k)

    # Initialize a FLUX model on the meta device.
    with accelerate.init_empty_weights():
        model = Flux(params["flux-dev"])
    model_keys = set(model.state_dict().keys())

    # Assert that the converted state dict matches the keys in the actual model.
    for converted_key_prefix in converted_key_prefixes:
        found_match = False
        for model_key in model_keys:
            if model_key.startswith(converted_key_prefix):
                found_match = True
                break
        if not found_match:
            raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")


def test_convert_flux_kohya_state_dict_to_invoke_format_error():
    """Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict
    contains unexpected keys.
    """
    state_dict = {
        "unexpected_key.lora_up.weight": torch.empty(1),
    }

    with pytest.raises(ValueError):
        convert_flux_kohya_state_dict_to_invoke_format(state_dict)


def test_lora_model_from_flux_kohya_state_dict():
    """Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
    # Construct a state dict that is in the Kohya FLUX LoRA format.
    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)

    lora_model = lora_model_from_flux_kohya_state_dict(state_dict)

    # Prepare expected layer keys.
    expected_layer_keys: set[str] = set()
    for k in flux_kohya_state_dict_keys:
        k = k.replace("lora_unet_", "")
        k = k.replace(".lora_up.weight", "")
        k = k.replace(".lora_down.weight", "")
        k = k.replace(".alpha", "")
        expected_layer_keys.add(k)

    # Assert that the lora_model has the expected layers.
    lora_model_keys = set(lora_model.layers.keys())
    lora_model_keys = {k.replace(".", "_") for k in lora_model_keys}
    assert lora_model_keys == expected_layer_keys
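
To make the prefix-matching assertion in test_convert_flux_kohya_state_dict_to_invoke_format concrete: the conversion strips the `lora_unet_` prefix and rewrites kohya's underscore-joined module path into the dotted path used by the Invoke `Flux` model. The before/after pair below is inferred from the tests above, not copied from the conversion code:

```python
# Inferred example of the kohya -> Invoke key conversion checked above.
kohya_key = "lora_unet_double_blocks_14_txt_attn_proj.lora_up.weight"
converted_prefix = "double_blocks.14.txt_attn.proj"
# The test asserts that some key in Flux.state_dict() starts with this prefix:
flux_model_key = "double_blocks.14.txt_attn.proj.weight"
assert flux_model_key.startswith(converted_prefix)
```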

View File

@@ -0,0 +1,49 @@
import copy

import torch

from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
    ConcatenatedLoRALinearSidecarLayer,
)
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule


def test_concatenated_lora_linear_sidecar_layer():
    """Test that a ConcatenatedLoRALinearSidecarLayer is equivalent to patching a linear layer with the
    ConcatenatedLoRA layer.
    """
    # Create a linear layer.
    in_features = 5
    sub_layer_out_features = [5, 10, 15]
    linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))

    # Create a ConcatenatedLoRA layer.
    rank = 4
    sub_layers: list[LoRALayer] = []
    for out_features in sub_layer_out_features:
        down = torch.randn(rank, in_features)
        up = torch.randn(out_features, rank)
        bias = torch.randn(out_features)
        sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
    concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)

    # Patch the ConcatenatedLoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += (
        concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
    )
    linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()

    # Create a ConcatenatedLoRALinearSidecarLayer.
    concatenated_lora_linear_sidecar_layer = ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer, weight=1.0)
    linear_with_sidecar = LoRASidecarModule(linear, [concatenated_lora_linear_sidecar_layer])

    # Run the ConcatenatedLoRA-patched linear layer and the ConcatenatedLoRALinearSidecarLayer and assert they are
    # equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_sidecar = linear_with_sidecar(input)
    assert torch.allclose(output_patched, output_sidecar, atol=1e-6)

View File

@@ -0,0 +1,38 @@
import copy

import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule


@torch.no_grad()
def test_lora_linear_sidecar_layer():
    """Test that a LoRALinearSidecarLayer is equivalent to patching a linear layer with the LoRA layer."""
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create a LoRA layer.
    rank = 4
    down = torch.randn(rank, in_features)
    up = torch.randn(out_features, rank)
    bias = torch.randn(out_features)
    lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)

    # Patch the LoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LoRALinearSidecarLayer.
    lora_linear_sidecar_layer = LoRALinearSidecarLayer(lora_layer, weight=1.0)
    linear_with_sidecar = LoRASidecarModule(linear, [lora_linear_sidecar_layer])

    # Run the LoRA-patched linear layer and the LoRALinearSidecarLayer and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_sidecar = linear_with_sidecar(input)
    assert torch.allclose(output_patched, output_sidecar, atol=1e-6)
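
Both sidecar tests above lean on the same identity: for a linear layer, folding the LoRA delta into the weights and adding the LoRA branch's output to the unmodified layer's output are mathematically equivalent. This is what lets sidecar layers serve quantized base models, whose packed weights can't be patched in place. A self-contained check of the identity in plain torch (no Invoke imports):

```python
import torch

torch.manual_seed(0)
in_f, out_f, rank = 5, 7, 4
linear = torch.nn.Linear(in_f, out_f)
down = torch.randn(rank, in_f)
up = torch.randn(out_f, rank)
x = torch.randn(1, in_f)

with torch.no_grad():
    # Fused: fold the low-rank delta (up @ down) into the weight matrix.
    fused = x @ (linear.weight + up @ down).T + linear.bias
    # Sidecar: leave the base layer untouched and add the LoRA branch's output.
    sidecar = linear(x) + (x @ down.T) @ up.T

assert torch.allclose(fused, sidecar, atol=1e-5)
```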

View File

@@ -0,0 +1,195 @@
import pytest
import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher


class DummyModule(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
        super().__init__()
        self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layer_1(x)


@pytest.mark.parametrize(
    ["device", "num_layers"],
    [
        ("cpu", 1),
        pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
        ("cpu", 2),
        pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
    ],
)
@torch.no_grad()
def test_apply_lora_patches(device: str, num_layers: int):
    """Test the basic behavior of LoRAPatcher.apply_lora_patches(...). Check that patching and unpatching produce the
    correct result, and that model/LoRA tensors are moved between devices as expected.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=torch.float16)

    # Initialize num_layers LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[LoRAModelRaw, float]] = []
    for _ in range(num_layers):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = LoRAModelRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
    # With all-ones up/down matrices, up @ down is a constant matrix whose entries equal lora_rank, so each stacked
    # LoRA shifts every weight element by lora_rank * lora_weight.
    expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_layers)

    with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        for lora, _ in lora_models:
            assert lora.layers["linear_layer_1"].up.device.type == "cpu"
            assert lora.layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on its original device.
        assert model.linear_layer_1.weight.data.device.type == device

        torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight)

    # After unpatching, the original model weights should have been restored on the original device.
    assert model.linear_layer_1.weight.data.device.type == device
    torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_lora_patches_change_device():
    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
    still behaves correctly.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2
    # Initialize the model on the CPU.
    model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)

    lora_layers = {
        "linear_layer_1": LoRALayer.from_state_dict_values(
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = LoRAModelRaw(lora_layers)

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()

    with LoRAPatcher.apply_lora_patches(model=model, patches=[(lora, 0.5)], prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on the CPU.
        assert model.linear_layer_1.weight.data.device.type == "cpu"

        # Move the model to the GPU.
        assert model.to("cuda")

    # After unpatching, the original model weights should have been restored on the GPU.
    assert model.linear_layer_1.weight.data.device.type == "cuda"
    torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False)


@pytest.mark.parametrize(
    ["device", "num_layers"],
    [
        ("cpu", 1),
        pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
        ("cpu", 2),
        pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
    ],
)
def test_apply_lora_sidecar_patches(device: str, num_layers: int):
    """Test the basic behavior of LoRAPatcher.apply_lora_sidecar_patches(...). Check that unpatching works correctly."""
    dtype = torch.float16
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=dtype)

    # Initialize num_layers LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[LoRAModelRaw, float]] = []
    for _ in range(num_layers):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = LoRAModelRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    # Run inference before patching the model.
    input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
    output_before_patch = model(input)

    # Patch the model and run inference during the patch.
    with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
        output_during_patch = model(input)

    # Run inference after unpatching.
    output_after_patch = model(input)

    # Check that the output before patching is different from the output during patching.
    assert not torch.allclose(output_before_patch, output_during_patch)

    # Check that the output before patching is the same as the output after patching.
    assert torch.allclose(output_before_patch, output_after_patch)


@torch.no_grad()
@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
    """Test that apply_lora_sidecar_patches(...) produces the same model outputs as apply_lora_patches(...)."""
    dtype = torch.float32
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=dtype)

    # Initialize num_layers LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[LoRAModelRaw, float]] = []
    for _ in range(num_layers):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = LoRAModelRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)

    with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
        output_lora_patches = model(input)

    with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
        output_lora_sidecar_patches = model(input)

    assert torch.allclose(output_lora_patches, output_lora_sidecar_patches)
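
A note on the `expected_patched_linear_weight` expression used above: with all-ones `lora_up`/`lora_down` matrices, `up @ down` is a constant matrix whose entries equal the rank, and (as the expectation implies) the layer scale works out to 1.0 when no alpha key is present in the state dict, so each stacked LoRA shifts every weight element by `lora_rank * lora_weight`. A quick standalone check of the arithmetic:

```python
import torch

lora_rank, in_f, out_f = 2, 4, 8
lora_weight, num_layers = 0.5, 2
up = torch.ones(out_f, lora_rank)
down = torch.ones(lora_rank, in_f)

# Each LoRA's contribution: every entry of (up @ down) equals lora_rank.
delta = (up @ down) * lora_weight
assert torch.all(delta == lora_rank * lora_weight)

# num_layers identical stacked LoRAs scale the shift linearly.
total_delta = num_layers * delta
assert torch.all(total_delta == lora_rank * lora_weight * num_layers)
```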

View File

@@ -1,103 +0,0 @@
# test that if the model's device changes while the lora is applied, the weights can still be restored

# test that LoRA patching works on both CPU and CUDA
import pytest
import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher


@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
    ],
)
@torch.no_grad()
def test_apply_lora(device):
    """Test the basic behavior of ModelPatcher.apply_lora(...). Check that patching and unpatching produce the correct
    result, and that model/LoRA tensors are moved between devices as expected.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2
    model = torch.nn.ModuleDict(
        {"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device=device, dtype=torch.float16)}
    )

    lora_layers = {
        "linear_layer_1": LoRALayer(
            layer_key="linear_layer_1",
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = LoRAModelRaw("lora_name", lora_layers)

    lora_weight = 0.5
    orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
    expected_patched_linear_weight = orig_linear_weight + (lora_dim * lora_weight)

    with ModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on its original device.
        assert model["linear_layer_1"].weight.data.device.type == device

        torch.testing.assert_close(model["linear_layer_1"].weight.data, expected_patched_linear_weight)

    # After unpatching, the original model weights should have been restored on the original device.
    assert model["linear_layer_1"].weight.data.device.type == device
    torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_lora_change_device():
    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
    still behaves correctly.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2
    # Initialize the model on the CPU.
    model = torch.nn.ModuleDict(
        {"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)}
    )

    lora_layers = {
        "linear_layer_1": LoRALayer(
            layer_key="linear_layer_1",
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = LoRAModelRaw("lora_name", lora_layers)

    orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()

    with ModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on the CPU.
        assert model["linear_layer_1"].weight.data.device.type == "cpu"

        # Move the model to the GPU.
        assert model.to("cuda")

    # After unpatching, the original model weights should have been restored on the GPU.
    assert model["linear_layer_1"].weight.data.device.type == "cuda"
    torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight, check_device=False)