Compare commits

...

31 Commits

Author SHA1 Message Date
Ryan Dick
e0c2b13558 Remove unused layer_key property from LoRALayerBase. 2024-09-10 15:17:12 +00:00
Ryan Dick
475e92e305 Consolidate all LoRA patching logic in the LoraPatcher. 2024-09-10 15:09:46 +00:00
Ryan Dick
99a12fb604 Rename peft -> lora in a bunch of places. 2024-09-10 14:00:41 +00:00
Ryan Dick
1a12c48f6e Rename lora.py -> lora_model_raw.py. 2024-09-10 13:54:17 +00:00
Ryan Dick
f1c8fd16b5 Rename peft/ -> lora/ 2024-09-10 13:52:06 +00:00
Ryan Dick
f44a6a7fb2 General cleanup/documentation. 2024-09-09 21:52:05 +00:00
Ryan Dick
da780c2243 Add a check that all keys are handled in the FLUX Diffusers LoRA loading code. 2024-09-09 21:13:04 +00:00
Ryan Dick
742f6781d5 Add model probe support for FLUX LoRA models in Diffusers format. 2024-09-09 21:00:40 +00:00
Ryan Dick
d2ffabf276 Add utility test function for creating a dummy state_dict. 2024-09-09 21:00:40 +00:00
Ryan Dick
39e28d5e24 Add is_state_dict_likely_in_flux_diffusers_format(...) function with unit test. 2024-09-09 21:00:40 +00:00
Ryan Dick
3f7c233f4d Add unit test for lora_model_from_flux_diffusers_state_dict(...). 2024-09-09 21:00:40 +00:00
Ryan Dick
af75a8ea99 First draft of lora_model_from_flux_diffusers_state_dict(...). 2024-09-09 21:00:40 +00:00
Ryan Dick
ddfe57e648 (minor) Rename test file. 2024-09-09 21:00:40 +00:00
Ryan Dick
7e8dd9e8ed Add ConcatenateLoRALayer class. 2024-09-09 21:00:40 +00:00
Ryan Dick
50f8d6db1b WIP on supporting diffusers format FLUX LoRAs. 2024-09-09 21:00:40 +00:00
Ryan Dick
961fbe1ba4 Rename flux_kohya_lora_conversion_utils.py 2024-09-09 21:00:40 +00:00
Ryan Dick
7342d18734 Fixup FLUX LoRA unit tests. 2024-09-09 21:00:40 +00:00
Ryan Dick
554a4dc592 WIP 2024-09-09 21:00:40 +00:00
Ryan Dick
90e486c976 WIP - add invocations to support FLUX LORAs. 2024-09-09 21:00:35 +00:00
Ryan Dick
ceb5d50568 Get probing of FLUX LoRA kohya models working. 2024-09-09 20:58:27 +00:00
Ryan Dick
e4cca62a90 Add utility function for detecting whether a state_dict is in the FLUX kohya LoRA format. 2024-09-09 20:58:27 +00:00
Ryan Dick
008f672e47 Update convert_flux_kohya_state_dict_to_invoke_format() to raise an exception if an unexpected key is encountered, and add a corresponding unit test. 2024-09-09 20:58:27 +00:00
Ryan Dick
53ae86068c Move the responsibilities of 1) state_dict loading from file, and 2) SDXL lora key conversions, out of LoRAModelRaw and into LoRALoader. 2024-09-09 20:58:27 +00:00
Ryan Dick
7de3c1943f Remove unused LoRAModelRaw.name attribute. 2024-09-09 20:58:27 +00:00
Ryan Dick
a0f36dea31 Fix type errors in sdxl_conversion_utils.py 2024-09-09 20:58:27 +00:00
Ryan Dick
69a2f8d53d Start moving SDXL-specific LoRA conversions out of the general-purpose LoRAModelRaw class. 2024-09-09 20:58:27 +00:00
Ryan Dick
3f2a61e0a6 Get convert_flux_kohya_state_dict_to_invoke_format(...) working, with unit tests. 2024-09-09 20:58:27 +00:00
Ryan Dick
f45b925bbf WIP - FLUX LoRA conversion logic. 2024-09-09 20:58:27 +00:00
Ryan Dick
d2870a512d Add state_dict keys for two FLUX LoRA formats to be used in unit tests. 2024-09-09 20:58:27 +00:00
Ryan Dick
e1e5f970e6 Move lora.py to peft/ subdir. 2024-09-09 20:58:27 +00:00
Ryan Dick
0712684dc9 Split PEFT layer implementations into separate files. 2024-09-09 20:58:27 +00:00
37 changed files with 3408 additions and 940 deletions
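
Every file below follows the same migration: call sites that previously used the model-specific ModelPatcher.apply_lora_unet / apply_lora_text_encoder / apply_lora helpers now go through the consolidated LoraPatcher.apply_lora_patches context manager. A minimal sketch of the new call shape (illustrative only, not part of the diff; the wrapper function and its name are invented here, while the import path, argument names, and prefixes are taken from the hunks below):

from typing import Iterator, Tuple

import torch

from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher

def run_with_lora_patches(
    model: torch.nn.Module,
    patches: Iterator[Tuple[LoRAModelRaw, float]],
    prefix: str,  # e.g. "lora_unet_", "lora_te_", or "" for the FLUX transformer
) -> None:
    # apply_lora_patches is used as a context manager in the diffs below: layers are patched
    # while the body runs and restored on exit (optionally via cached_weights, omitted here).
    with LoraPatcher.apply_lora_patches(model=model, patches=patches, prefix=prefix):
        ...  # run inference while the LoRA patches are applied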

View File

@@ -19,7 +19,8 @@ from invokeai.app.invocations.model import CLIPField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
@@ -82,9 +83,10 @@ class CompelInvocation(BaseInvocation):
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
LoraPatcher.apply_lora_patches(
model=text_encoder,
patches=_lora_loader(),
prefix="lora_te_",
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -177,9 +179,9 @@ class SDXLPromptInvocationBase:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(
LoraPatcher.apply_lora_patches(
text_encoder,
loras=_lora_loader(),
patches=_lora_loader(),
prefix=lora_prefix,
cached_weights=cached_weights,
),

View File

@@ -36,7 +36,8 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState
@@ -979,9 +980,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
LoraPatcher.apply_lora_patches(
model=unet,
patches=_lora_loader(),
prefix="lora_unet_",
cached_weights=cached_weights,
),
):

View File

@@ -1,4 +1,4 @@
from typing import Callable, Optional
from typing import Callable, Iterator, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
@@ -29,6 +29,8 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -187,7 +189,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
noise=noise,
)
with transformer_info as transformer:
with (
transformer_info.model_on_device() as (cached_weights, transformer),
# Apply the LoRA after transformer has been moved to its target device for faster patching.
LoraPatcher.apply_lora_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix="",
cached_weights=cached_weights,
),
):
assert isinstance(transformer, Flux)
x = denoise(
@@ -241,6 +252,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# `latents`.
return mask.expand_as(latents)
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()

View File

@@ -0,0 +1,53 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_lora_loader_output")
class FluxLoRALoaderOutput(BaseInvocationOutput):
"""FLUX LoRA Loader Output"""
transformer: TransformerField = OutputField(
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
@invocation(
"flux_lora_loader",
title="FLUX LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.0.0",
)
class FluxLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField = InputField(
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
if any(lora.lora.key == lora_key for lora in self.transformer.loras):
raise Exception(f'LoRA "{lora_key}" already applied to transformer.')
transformer = self.transformer.model_copy(deep=True)
transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
return FluxLoRALoaderOutput(transformer=transformer)

View File

@@ -69,6 +69,7 @@ class CLIPField(BaseModel):
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class T5EncoderField(BaseModel):
@@ -202,7 +203,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),

View File

@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
@@ -204,7 +204,11 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
with (
ExitStack() as exit_stack,
unet_info as unet,
LoraPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:

View File

@@ -1,672 +0,0 @@
# Copyright (c) 2024 The InvokeAI Development team
"""LoRA model support."""
import bisect
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
from safetensors.torch import load_file
from typing_extensions import Self
import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
return params
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
logger.warning(
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
self.mid = values.get("lora_mid.weight", None)
self.rank = self.down.shape[0]
self.check_keys(
values,
{
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
self.t1 = values.get("hada_t1", None)
self.t2 = values.get("hada_t2", None)
self.rank = self.w1_b.shape[0]
self.check_keys(
values,
{
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.w1 = values.get("lokr_w1", None)
if self.w1 is None:
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
else:
self.w1_b = None
self.w1_a = None
self.w2 = values.get("lokr_w2", None)
if self.w2 is None:
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
else:
self.w2_a = None
self.w2_b = None
self.t2 = values.get("lokr_t2", None)
if self.w1_b is not None:
self.rank = self.w1_b.shape[0]
elif self.w2_b is not None:
self.rank = self.w2_b.shape[0]
else:
self.rank = None # unscaled
self.check_keys(
values,
{
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
assert self.w1_b is not None
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
assert self.w2_a is not None
assert self.w2_b is not None
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
assert w1 is not None
assert w2 is not None
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class FullLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["diff"]
self.bias = values.get("diff_b", None)
self.rank = None # unscaled
self.check_keys(values, {"diff", "diff_b"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["weight"]
self.on_input = values["on_input"]
self.rank = None # unscaled
self.check_keys(values, {"weight", "on_input"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
assert orig_weight is not None
return orig_weight * weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
class NormLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["w_norm"]
self.bias = values.get("b_norm", None)
self.rank = None # unscaled
self.check_keys(values, {"w_norm", "b_norm"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
class LoRAModelRaw(RawModel): # (torch.nn.Module):
_name: str
layers: Dict[str, AnyLoRALayer]
def __init__(
self,
name: str,
layers: Dict[str, AnyLoRALayer],
):
self._name = name
self.layers = layers
@property
def name(self) -> str:
return self._name
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
# TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
) -> Self:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
name=file_path.stem,
layers={},
)
if file_path.suffix == ".safetensors":
sd = load_file(file_path.absolute().as_posix(), device="cpu")
else:
sd = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(sd)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
# lora and locon
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_a" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1" in values or "lokr_w1_a" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "on_input" in values:
layer = IA3Layer(layer_key, values)
# norms
elif "w_norm" in values:
layer = NormLayer(layer_key, values)
else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
}

View File

View File

@@ -0,0 +1,210 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
"""Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
# First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
# Next, check that this is likely a FLUX model by spot-checking a few keys.
expected_keys = [
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
]
all_expected_keys_present = all(k in state_dict for k in expected_keys)
return all_keys_in_peft_format and all_expected_keys_present
# TODO(ryand): What alpha should we use? 1.0? Rank of the LoRA?
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float = 1.0) -> LoRAModelRaw: # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
This function is based on:
https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
"""
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
# Remove the "transformer." prefix from all keys.
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
# Constants for FLUX.1
num_double_layers = 19
num_single_layers = 38
# inner_dim = 3072
# mlp_ratio = 4.0
layers: dict[str, AnyLoRALayer] = {}
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
if src_key in grouped_state_dict:
src_layer_dict = grouped_state_dict.pop(src_key)
layers[dst_key] = LoRALayer(
values={
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
"alpha": torch.tensor(alpha),
},
)
assert len(src_layer_dict) == 0
def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
"""Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
"""
# We expect that either all src keys are present or none of them are. Verify this.
keys_present = [key in grouped_state_dict for key in src_keys]
assert all(keys_present) or not any(keys_present)
# If none of the keys are present, return early.
if not any(keys_present):
return
src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
sub_layers: list[LoRALayerBase] = []
for src_layer_dict in src_layer_dicts:
sub_layers.append(
LoRALayer(
values={
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
"alpha": torch.tensor(alpha),
},
)
)
assert len(src_layer_dict) == 0
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)
# time_text_embed.timestep_embedder -> time_in.
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_2", "time_in.out_layer")
# time_text_embed.text_embedder -> vector_in.
add_lora_layer_if_present("time_text_embed.text_embedder.linear_1", "vector_in.in_layer")
add_lora_layer_if_present("time_text_embed.text_embedder.linear_2", "vector_in.out_layer")
# time_text_embed.guidance_embedder -> guidance_in.
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in")
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in")
# context_embedder -> txt_in.
add_lora_layer_if_present("context_embedder", "txt_in")
# x_embedder -> img_in.
add_lora_layer_if_present("x_embedder", "img_in")
# Double transformer blocks.
for i in range(num_double_layers):
# norms.
add_lora_layer_if_present(f"transformer_blocks.{i}.norm1.linear", f"double_blocks.{i}.img_mod.lin")
add_lora_layer_if_present(f"transformer_blocks.{i}.norm1_context.linear", f"double_blocks.{i}.txt_mod.lin")
# Q, K, V
add_qkv_lora_layer_if_present(
[
f"transformer_blocks.{i}.attn.to_q",
f"transformer_blocks.{i}.attn.to_k",
f"transformer_blocks.{i}.attn.to_v",
],
f"double_blocks.{i}.img_attn.qkv",
)
add_qkv_lora_layer_if_present(
[
f"transformer_blocks.{i}.attn.add_q_proj",
f"transformer_blocks.{i}.attn.add_k_proj",
f"transformer_blocks.{i}.attn.add_v_proj",
],
f"double_blocks.{i}.txt_attn.qkv",
)
# ff img_mlp
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff.net.0.proj",
f"double_blocks.{i}.img_mlp.0",
)
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff.net.2",
f"double_blocks.{i}.img_mlp.2",
)
# ff txt_mlp
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff_context.net.0.proj",
f"double_blocks.{i}.txt_mlp.0",
)
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff_context.net.2",
f"double_blocks.{i}.txt_mlp.2",
)
# output projections.
add_lora_layer_if_present(
f"transformer_blocks.{i}.attn.to_out.0",
f"double_blocks.{i}.img_attn.proj",
)
add_lora_layer_if_present(
f"transformer_blocks.{i}.attn.to_add_out",
f"double_blocks.{i}.txt_attn.proj",
)
# Single transformer blocks.
for i in range(num_single_layers):
# norms
add_lora_layer_if_present(
f"single_transformer_blocks.{i}.norm.linear",
f"single_blocks.{i}.modulation.lin",
)
# Q, K, V, mlp
add_qkv_lora_layer_if_present(
[
f"single_transformer_blocks.{i}.attn.to_q",
f"single_transformer_blocks.{i}.attn.to_k",
f"single_transformer_blocks.{i}.attn.to_v",
f"single_transformer_blocks.{i}.proj_mlp",
],
f"single_blocks.{i}.linear1",
)
# Output projections.
add_lora_layer_if_present(
f"single_transformer_blocks.{i}.proj_out",
f"single_blocks.{i}.linear2",
)
# Final layer.
add_lora_layer_if_present("proj_out", "final_layer.linear")
# Assert that all keys were processed.
assert len(grouped_state_dict) == 0
return LoRAModelRaw(layers=layers)
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
"""Groups the keys in the state dict by layer."""
layer_dict: dict[str, dict[str, torch.Tensor]] = {}
for key in state_dict:
# Split the 'lora_A.weight' or 'lora_B.weight' suffix from the layer name.
parts = key.rsplit(".", maxsplit=2)
layer_name = parts[0]
key_name = ".".join(parts[1:])
if layer_name not in layer_dict:
layer_dict[layer_name] = {}
layer_dict[layer_name][key_name] = state_dict[key]
return layer_dict
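
As a point of reference for the detector and loader above (illustrative, not part of the diff): diffusers/PEFT FLUX LoRA keys carry a "transformer." prefix and end in "lora_A.weight" or "lora_B.weight", and _group_by_layer collects the tensors for each layer under a single key. A small standalone sketch with toy shapes (rank 4, FLUX inner dim 3072):

import torch

# Toy rank-4 LoRA on a single FLUX attention projection. In the PEFT convention,
# lora_A has shape (rank, in_features) and lora_B has shape (out_features, rank).
state_dict = {
    "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 3072),
    "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(3072, 4),
}

# Mirrors _group_by_layer: split off the trailing "lora_A.weight"/"lora_B.weight" suffix.
grouped: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
    parts = key.rsplit(".", maxsplit=2)
    grouped.setdefault(parts[0], {})[".".join(parts[1:])] = value

assert set(grouped) == {"transformer.single_transformer_blocks.0.attn.to_q"}
assert set(grouped["transformer.single_transformer_blocks.0.attn.to_q"]) == {"lora_A.weight", "lora_B.weight"}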

View File

@@ -0,0 +1,80 @@
import re
from typing import Any, Dict, TypeVar
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_unet_double_blocks_0_img_attn_proj.alpha
# lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
# lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
FLUX_KOHYA_KEY_REGEX = (
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
layer_name, param_name = key.split(".", 1)
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Convert the state dict to the InvokeAI format.
grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
# Create LoRA layers.
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)
T = TypeVar("T")
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
Example key conversions:
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"""
def replace_func(match: re.Match[str]) -> str:
s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
if match.group(4):
s += f".{match.group(4)}"
return s
converted_dict: dict[str, T] = {}
for k, v in state_dict.items():
match = re.match(FLUX_KOHYA_KEY_REGEX, k)
if match:
new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
converted_dict[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict
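
The key rewrite above can be exercised standalone; a short sketch (the pattern is repeated verbatim so the snippet runs without importing the new module, and the example keys are the ones from the docstring):

import re

# Same pattern as FLUX_KOHYA_KEY_REGEX above, repeated here so the snippet is self-contained.
FLUX_KOHYA_KEY_REGEX = (
    r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)

def to_invoke_key(kohya_layer_key: str) -> str:
    match = re.match(FLUX_KOHYA_KEY_REGEX, kohya_layer_key)
    assert match is not None
    new_key = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
    if match.group(4):
        new_key += f".{match.group(4)}"
    return new_key

assert to_invoke_key("lora_unet_double_blocks_0_img_attn_proj") == "double_blocks.0.img_attn.proj"
assert to_invoke_key("lora_unet_double_blocks_0_img_attn_qkv") == "double_blocks.0.img_attn.qkv"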

View File

@@ -0,0 +1,29 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
layers: dict[str, AnyLoRALayer] = {}
for layer_key, values in grouped_state_dict.items():
layers[layer_key] = any_lora_layer_from_state_dict(values)
return LoRAModelRaw(layers=layers)
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped

View File

@@ -0,0 +1,154 @@
import bisect
from typing import Dict, List, Tuple, TypeVar
T = TypeVar("T")
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict: dict[str, T] = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer: list[tuple[str, str]] = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map: list[tuple[str, str]] = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in _make_sdxl_unet_conversion_map()
}
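
A quick illustration of the bisect-based prefix lookup used in convert_sdxl_keys_to_diffusers_format (illustrative keys, not part of the diff): bisect_right places the search key immediately after any map key that is a prefix of it, so the candidate mapping is the element just before the insertion point, and a startswith check confirms the match:

import bisect

stability_unet_keys = sorted(["input_blocks_4_1", "input_blocks_5_1", "middle_block_1"])
search_key = "input_blocks_4_1_proj_in"
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
assert map_key == "input_blocks_4_1" and search_key.startswith(map_key)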

View File

View File

@@ -0,0 +1,11 @@
from typing import Union
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]

View File

@@ -0,0 +1,46 @@
from typing import List, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class ConcatenatedLoRALayer(LoRALayerBase):
"""A LoRA layer that is composed of multiple LoRA layers concatenated along a specified axis.
This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices are
stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
"""
def __init__(self, lora_layers: List[LoRALayerBase], concat_axis: int = 0):
# Note: We pass values={} to the base class, because the values are handled by the individual LoRA layers.
super().__init__(values={})
self._lora_layers = lora_layers
self._concat_axis = concat_axis
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
# TODO(ryand): Currently, we pass orig_weight=None to the sub-layers. If we want to support sub-layers that
# require this value, we will need to implement chunking of the original weight tensor here.
layer_weights = [lora_layer.get_weight(None) for lora_layer in self._lora_layers] # pyright: ignore[reportArgumentType]
return torch.cat(layer_weights, dim=self._concat_axis)
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
# TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
# require this value, we will need to implement chunking of the original bias tensor here.
layer_biases = [lora_layer.get_bias(None) for lora_layer in self._lora_layers] # pyright: ignore[reportArgumentType]
layer_bias_is_none = [layer_bias is None for layer_bias in layer_biases]
if any(layer_bias_is_none):
assert all(layer_bias_is_none)
return None
# Ignore the type error, because we have just verified that all layer biases are non-None.
return torch.cat(layer_biases, dim=self._concat_axis)
def calc_size(self) -> int:
return sum(lora_layer.calc_size() for lora_layer in self._lora_layers)
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
for lora_layer in self._lora_layers:
lora_layer.to(device=device, dtype=dtype)
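
A standalone sketch (toy dimensions, not part of the diff) of the property ConcatenatedLoRALayer relies on for the FLUX qkv case: concatenating the per-matrix LoRA weight deltas for Q, K, and V along axis 0 yields the delta for the fused qkv weight that the BFL FLUX format stores:

import torch

rank, dim = 2, 8  # toy sizes; FLUX itself uses dim=3072 and whatever rank the LoRA was trained with
deltas = []
for _ in range(3):  # Q, K, V
    down = torch.randn(rank, dim)  # "lora_down.weight" (PEFT lora_A)
    up = torch.randn(dim, rank)    # "lora_up.weight" (PEFT lora_B)
    deltas.append(up @ down)       # per-matrix weight delta, shape (dim, dim)

fused_delta = torch.cat(deltas, dim=0)  # delta for the fused qkv weight
assert fused_delta.shape == (3 * dim, dim)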

View File

@@ -0,0 +1,36 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class FullLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
values: Dict[str, torch.Tensor],
):
super().__init__(values)
self.weight = values["diff"]
self.bias = values.get("diff_b", None)
self.rank = None # unscaled
self.check_keys(values, {"diff", "diff_b"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)

View File

@@ -0,0 +1,41 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor
def __init__(
self,
values: Dict[str, torch.Tensor],
):
super().__init__(values)
self.weight = values["weight"]
self.on_input = values["on_input"]
self.rank = None # unscaled
self.check_keys(values, {"weight", "on_input"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
assert orig_weight is not None
return orig_weight * weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)

View File

@@ -0,0 +1,68 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(self, values: Dict[str, torch.Tensor]):
super().__init__(values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
self.t1 = values.get("hada_t1", None)
self.t2 = values.get("hada_t2", None)
self.rank = self.w1_b.shape[0]
self.check_keys(
values,
{
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)

View File

@@ -0,0 +1,113 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
values: Dict[str, torch.Tensor],
):
super().__init__(values)
self.w1 = values.get("lokr_w1", None)
if self.w1 is None:
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
else:
self.w1_b = None
self.w1_a = None
self.w2 = values.get("lokr_w2", None)
if self.w2 is None:
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
else:
self.w2_a = None
self.w2_b = None
self.t2 = values.get("lokr_t2", None)
if self.w1_b is not None:
self.rank = self.w1_b.shape[0]
elif self.w2_b is not None:
self.rank = self.w2_b.shape[0]
else:
self.rank = None # unscaled
self.check_keys(
values,
{
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
assert self.w1_b is not None
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
assert self.w2_a is not None
assert self.w2_b is not None
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
assert w1 is not None
assert w2 is not None
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
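A minimal sketch (not part of the diff) of the Kronecker reconstruction above, using illustrative shapes; it follows the branch where lokr_w1 is stored directly and w2 is rebuilt from its low-rank factors:

import torch

# Illustrative shapes: kron(w1, w2) of an (a x b) and a (c x d) matrix is (a*c) x (b*d).
w1 = torch.randn(4, 8)       # lokr_w1, stored directly
w2_a = torch.randn(6, 2)     # lokr_w2_a
w2_b = torch.randn(2, 16)    # lokr_w2_b
w2 = w2_a @ w2_b             # (6, 16), the t2-is-None branch above

delta = torch.kron(w1, w2)   # (4*6) x (8*16) = 24 x 128
assert delta.shape == (24, 128)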

View File

@@ -0,0 +1,58 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
values: Dict[str, torch.Tensor],
):
super().__init__(values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
self.mid = values.get("lora_mid.weight", None)
self.rank = self.down.shape[0]
self.check_keys(
values,
{
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
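A minimal sketch (not part of the diff) of the plain LoRA delta above, including the alpha / rank scale that LoraPatcher applies later in this diff; shapes are illustrative:

import torch

# Illustrative shapes: rank-8 LoRA on a 64x128 linear weight.
out_features, in_features, rank, alpha = 64, 128, 8, 4.0
up = torch.randn(out_features, rank)     # lora_up.weight
down = torch.randn(rank, in_features)    # lora_down.weight

# Same math as LoRALayer.get_weight() when there is no lora_mid.weight tensor:
delta = up.reshape(out_features, -1) @ down.reshape(rank, -1)

# The patcher scales the delta by alpha / rank before adding it to the module weight.
orig_weight = torch.randn(out_features, in_features)
patched_weight = orig_weight + delta * (alpha / rank)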

View File

@@ -0,0 +1,71 @@
from typing import Dict, Optional, Set
import torch
import invokeai.backend.util.logging as logger
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
values: Dict[str, torch.Tensor],
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
return params
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handle.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
logger.warning(
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
)
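To illustrate the check_keys(...) contract (not part of the diff), here is a hypothetical subclass, roughly analogous to the FullLayer/"diff" format referenced later in this diff; it declares the keys it consumes so that anything else in the sub-dict triggers the warning above:

from typing import Dict
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase

class ExampleDiffLayer(LoRALayerBase):
    """Hypothetical layer storing a full weight delta under the 'diff' key."""

    def __init__(self, values: Dict[str, torch.Tensor]):
        super().__init__(values)
        self.weight = values["diff"]
        self.rank = None  # full-rank delta, so alpha/rank scaling is skipped
        # Warns about any key other than 'diff' plus the base-handled alpha/bias_* keys.
        self.check_keys(values, {"diff"})

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return self.weight

layer = ExampleDiffLayer({"diff": torch.zeros(4, 4), "alpha": torch.tensor(4.0)})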

View File

@@ -0,0 +1,36 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class NormLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
values: Dict[str, torch.Tensor],
):
super().__init__(values)
self.weight = values["w_norm"]
self.bias = values.get("b_norm", None)
self.rank = None # unscaled
self.check_keys(values, {"w_norm", "b_norm"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)

View File

@@ -0,0 +1,33 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
# Detect layers according to LyCORIS detection logic (`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
if "lora_up.weight" in state_dict:
# LoRA a.k.a LoCon
return LoRALayer(state_dict)
elif "hada_w1_a" in state_dict:
return LoHALayer(state_dict)
elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
return LoKRLayer(state_dict)
elif "diff" in state_dict:
# Full a.k.a Diff
return FullLayer(state_dict)
elif "on_input" in state_dict:
return IA3Layer(state_dict)
elif "w_norm" in state_dict:
return NormLayer(state_dict)
else:
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")

View File

@@ -0,0 +1,22 @@
# Copyright (c) 2024 The InvokeAI Development team
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.raw_model import RawModel
class LoRAModelRaw(RawModel): # (torch.nn.Module):
def __init__(self, layers: Dict[str, AnyLoRALayer]):
self.layers = layers
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
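A minimal sketch (not part of the diff) of how a flat checkpoint might be grouped into per-layer sub-dicts and wrapped in LoRAModelRaw; the conversion utilities referenced below do this for the real formats, so the layer key used here is purely illustrative:

from typing import Dict
import torch
from invokeai.backend.lora.layers.any_lora_layer import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw

flat_sd = {
    "lora_unet_example.lora_up.weight": torch.randn(16, 4),
    "lora_unet_example.lora_down.weight": torch.randn(4, 32),
    "lora_unet_example.alpha": torch.tensor(4.0),
}

# Group tensors by layer key, then dispatch each group to the matching layer class.
grouped: Dict[str, Dict[str, torch.Tensor]] = {}
for full_key, tensor in flat_sd.items():
    layer_key, param_key = full_key.split(".", 1)
    grouped.setdefault(layer_key, {})[param_key] = tensor

model = LoRAModelRaw(layers={k: any_lora_layer_from_state_dict(v) for k, v in grouped.items()})
print(model.calc_size())  # total size of all layer tensors, in bytes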

View File

@@ -0,0 +1,148 @@
from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple
import torch
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
class LoraPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
"""Apply one or more LoRA patches to a model within a context manager.
:param model: The model to patch.
:param patches: An iterator that returns tuples of LoRA patches and associated weights. An iterator is used so
that the LoRA patches do not need to be loaded into memory all at once.
:param prefix: The keys in the patches will be filtered to only include weights with this prefix.
:param cached_weights: Read-only copy of the model's state dict on the CPU, used for efficient unpatching.
"""
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
LoraPatcher.apply_lora_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_weights=original_weights,
)
del patch
yield
finally:
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
@staticmethod
@torch.no_grad()
def apply_lora_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
"""
Apply a single LoRA patch to a model.
:param model: The model to patch.
:param patch: LoRA model to patch in.
:param patch_weight: LoRA patch weight.
:param prefix: A string prefix that precedes the keys of the LoRA's weight layers.
:param original_weights: Storage for the original model weights; it is populated with the weights that the LoRA patch modifies, and is used for unpatching.
"""
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LoraPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# Save original weight
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
def _get_submodule(
model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
) -> tuple[str, torch.nn.Module]:
"""Get the submodule corresponding to the given layer key.
:param model: The model to search.
:param layer_key: The layer key to search for.
:param layer_key_is_flattened: Whether the layer key is flattened. If flattened, then all '.' have been replaced
with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly without
searching, but some legacy code still uses flattened keys.
:return: A tuple containing the module key and the submodule.
"""
if not layer_key_is_flattened:
return layer_key, model.get_submodule(layer_key)
# Handle flattened keys.
assert "." not in layer_key
module = model
module_key = ""
key_parts = layer_key.split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return module_key, module
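A usage sketch (not part of the diff) of the context manager above; model and lora_model are assumed to exist, with the LoRA's layer keys using the flattened "lora_unet_" prefix convention described in apply_lora_patch:

import torch
from invokeai.backend.lora.lora_patcher import LoraPatcher

def run_with_lora(model: torch.nn.Module, lora_model, x: torch.Tensor) -> torch.Tensor:
    # Weights are patched in place on entry and restored from OriginalWeightsStorage on exit.
    # Patches with weight 0 and layer keys that don't start with the prefix are skipped.
    with LoraPatcher.apply_lora_patches(
        model=model,
        patches=[(lora_model, 0.75)],  # (LoRAModelRaw, patch_weight) pairs
        prefix="lora_unet_",
    ):
        return model(x)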

View File

@@ -5,8 +5,18 @@ from logging import Logger
from pathlib import Path
from typing import Optional
import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -45,14 +55,33 @@ class LoRALoader(ModelLoader):
raise ValueError("There are no submodels in a LoRA model.")
model_path = Path(config.path)
assert self._model_base is not None
model = LoRAModelRaw.from_checkpoint(
file_path=model_path,
dtype=self._torch_dtype,
base_model=self._model_base,
)
# Load the state dict from the model file.
if model_path.suffix == ".safetensors":
state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(model_path, map_location="cpu")
# Apply state_dict key conversions, if necessary.
if self._model_base == BaseModelType.StableDiffusionXL:
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
model = lora_model_from_sd_state_dict(state_dict=state_dict)
elif self._model_base == BaseModelType.Flux:
if config.format == ModelFormat.Diffusers:
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict)
elif config.format == ModelFormat.LyCORIS:
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
# Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
model = lora_model_from_sd_state_dict(state_dict=state_dict)
else:
raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
model.to(dtype=self._torch_dtype)
return model
# override
def _get_model_path(self, config: AnyModelConfig) -> Path:
# cheating a little - we remember this variable for use in the subsequent call to _load_model()
self._model_base = config.base
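For reference (not part of the diff), a sketch of the same FLUX dispatch, but driven by the key-format detection helpers that the model probe imports below; the file-path handling is an assumption:

from pathlib import Path

from safetensors.torch import load_file

from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
    is_state_dict_likely_in_flux_diffusers_format,
    lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
    is_state_dict_likely_in_flux_kohya_format,
    lora_model_from_flux_kohya_state_dict,
)

def load_flux_lora(path: Path):
    """Load a FLUX LoRA .safetensors file, picking the converter by key format."""
    state_dict = load_file(path.as_posix(), device="cpu")
    if is_state_dict_likely_in_flux_kohya_format(state_dict):
        return lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
    if is_state_dict_likely_in_flux_diffusers_format(state_dict):
        return lora_model_from_flux_diffusers_state_dict(state_dict=state_dict)
    raise ValueError("Unrecognized FLUX LoRA key format.")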

View File

@@ -15,7 +15,7 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel

View File

@@ -10,6 +10,10 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
AnyModelConfig,
@@ -244,7 +248,9 @@ class ModelProbe(object):
return ModelType.VAE
elif key.startswith(("lora_te_", "lora_unet_")):
return ModelType.LoRA
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return ModelType.LoRA
elif key.startswith(("controlnet", "control_model", "input_blocks")):
return ModelType.ControlNet
@@ -554,12 +560,21 @@ class LoRACheckpointProbe(CheckpointProbeBase):
"""Class for LoRA checkpoints."""
def get_format(self) -> ModelFormat:
return ModelFormat("lycoris")
if is_state_dict_likely_in_flux_diffusers_format(self.checkpoint):
# TODO(ryand): This is an unusual case. In other places throughout the codebase, we treat
# ModelFormat.Diffusers as meaning that the model is in a directory. In this case, the model is a single
# file, but the weight keys are in the diffusers format.
return ModelFormat.Diffusers
return ModelFormat.LyCORIS
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint)
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
self.checkpoint
):
return BaseModelType.Flux
# If we've gotten here, we assume that the model is a Stable Diffusion model.
token_vector_length = lora_token_vector_length(self.checkpoint)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:

View File

@@ -5,32 +5,18 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
import numpy as np
import torch
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
"""
loras = [
(lora_model1, 0.7),
(lora_model2, 0.4),
]
with LoRAHelper.apply_lora_unet(unet, loras):
# unet with applied loras
# unmodified unet
"""
class ModelPatcher:
@@ -54,95 +40,6 @@ class ModelPatcher:
finally:
unet.set_attn_processor(unet_orig_processors)
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(
unet,
loras=loras,
prefix="lora_unet_",
cached_weights=cached_weights,
):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
yield
@classmethod
@contextmanager
def apply_lora(
cls,
model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
"""
original_weights = OriginalWeightsStorage(cached_weights)
try:
for lora_model, lora_weight in loras:
LoRAExt.patch_model(
model=model,
prefix=prefix,
lora=lora_model,
lora_weight=lora_weight,
original_weights=original_weights,
)
del lora_model
yield
finally:
with torch.no_grad():
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
@classmethod
@contextmanager
def apply_ti(
@@ -282,26 +179,6 @@ class ModelPatcher:
class ONNXModelPatcher:
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: OnnxRuntimeModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
# based on
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
@classmethod

View File

@@ -1,18 +1,17 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING
import torch
from diffusers import UNet2DConditionModel
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
@@ -31,107 +30,14 @@ class LoRAExt(ExtensionBase):
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
self.patch_model(
assert isinstance(lora_model, LoRAModelRaw)
LoraPatcher.apply_lora_patch(
model=unet,
prefix="lora_unet_",
lora=lora_model,
lora_weight=self._weight,
patch=lora_model,
patch_weight=self._weight,
original_weights=original_weights,
)
del lora_model
yield
@classmethod
@torch.no_grad()
def patch_model(
cls,
model: torch.nn.Module,
prefix: str,
lora: LoRAModelRaw,
lora_weight: float,
original_weights: OriginalWeightsStorage,
):
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param lora: LoRA model to patch in.
:param lora_weight: LoRA patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
"""
if lora_weight == 0:
return
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# save original weight
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
# TODO: debug on lycoris
lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= lora_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)

View File

@@ -0,0 +1,993 @@
# A sample state dict in the Diffusers FLUX LoRA format.
# These keys are based on the LoRA model here:
# https://civitai.com/models/200255/hands-xl-sd-15-flux1-dev?modelVersionId=781855
state_dict_keys = [
"transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.0.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.0.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.0.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.0.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.0.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.1.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.1.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.1.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.1.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.1.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.1.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.1.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.1.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.1.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.1.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.1.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.1.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.10.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.10.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.10.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.10.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.10.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.10.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.10.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.10.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.10.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.10.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.10.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.10.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.11.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.11.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.11.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.11.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.11.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.11.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.11.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.11.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.11.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.11.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.11.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.11.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.12.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.12.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.12.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.12.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.12.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.12.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.12.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.12.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.12.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.12.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.12.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.12.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.13.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.13.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.13.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.13.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.13.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.13.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.13.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.13.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.13.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.13.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.13.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.13.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.14.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.14.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.14.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.14.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.14.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.14.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.14.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.14.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.14.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.14.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.14.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.14.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.15.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.15.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.15.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.15.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.15.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.15.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.15.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.15.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.15.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.15.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.15.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.15.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.16.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.16.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.16.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.16.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.16.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.16.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.16.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.16.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.16.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.16.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.16.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.16.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.17.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.17.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.17.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.17.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.17.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.17.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.17.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.17.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.17.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.17.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.17.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.17.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.18.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.18.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.18.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.18.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.18.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.18.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.18.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.18.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.18.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.18.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.18.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.18.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.19.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.19.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.19.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.19.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.19.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.19.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.19.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.19.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.19.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.19.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.19.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.19.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.2.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.2.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.2.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.2.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.2.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.2.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.2.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.2.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.2.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.2.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.2.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.2.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.20.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.20.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.20.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.20.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.20.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.20.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.20.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.20.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.20.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.20.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.20.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.20.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.21.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.21.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.21.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.21.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.21.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.21.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.21.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.21.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.21.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.21.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.21.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.21.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.22.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.22.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.22.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.22.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.22.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.22.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.22.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.22.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.22.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.22.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.22.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.22.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.23.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.23.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.23.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.23.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.23.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.23.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.23.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.23.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.23.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.23.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.23.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.23.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.24.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.24.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.24.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.24.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.24.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.24.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.24.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.24.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.24.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.24.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.24.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.24.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.25.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.25.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.25.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.25.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.25.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.25.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.25.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.25.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.25.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.25.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.25.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.25.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.26.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.26.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.26.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.26.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.26.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.26.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.26.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.26.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.26.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.26.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.26.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.26.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.27.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.27.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.27.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.27.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.27.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.27.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.27.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.27.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.27.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.27.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.27.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.27.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.28.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.28.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.28.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.28.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.28.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.28.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.28.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.28.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.28.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.28.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.28.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.28.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.29.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.29.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.29.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.29.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.29.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.29.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.29.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.29.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.29.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.29.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.29.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.29.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.3.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.3.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.3.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.3.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.3.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.3.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.3.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.3.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.3.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.3.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.3.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.3.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.30.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.30.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.30.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.30.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.30.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.30.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.30.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.30.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.30.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.30.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.30.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.30.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.31.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.31.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.31.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.31.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.31.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.31.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.31.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.31.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.31.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.31.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.31.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.31.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.32.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.32.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.32.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.32.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.32.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.32.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.32.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.32.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.32.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.32.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.32.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.32.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.33.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.33.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.33.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.33.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.33.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.33.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.33.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.33.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.33.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.33.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.33.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.33.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.34.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.34.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.34.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.34.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.34.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.34.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.34.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.34.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.34.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.34.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.34.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.34.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.35.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.35.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.35.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.35.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.35.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.35.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.35.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.35.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.35.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.35.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.35.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.35.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.36.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.36.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.36.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.36.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.36.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.36.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.36.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.36.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.36.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.36.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.36.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.36.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.37.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.37.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.37.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.37.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.37.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.37.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.37.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.37.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.37.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.37.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.37.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.37.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.4.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.4.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.4.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.4.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.4.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.4.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.4.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.4.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.4.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.4.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.4.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.4.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.5.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.5.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.5.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.5.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.5.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.5.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.5.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.5.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.5.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.5.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.5.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.5.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.6.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.6.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.6.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.6.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.6.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.6.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.6.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.6.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.6.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.6.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.6.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.6.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.7.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.7.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.7.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.7.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.7.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.7.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.7.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.7.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.7.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.7.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.7.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.7.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.8.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.8.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.8.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.8.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.8.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.8.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.8.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.8.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.8.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.8.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.8.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.8.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.9.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.9.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.9.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.9.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.9.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.9.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.9.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.9.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.9.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.9.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.9.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.9.proj_out.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.0.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.0.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.0.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.0.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.0.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.0.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.0.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.0.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.0.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.0.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.0.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.1.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.1.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.1.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.1.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.1.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.1.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.1.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.1.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.1.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.1.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.1.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.1.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.1.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.1.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.1.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.1.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.1.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.1.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.10.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.10.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.10.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.10.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.10.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.10.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.10.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.10.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.10.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.10.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.10.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.10.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.10.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.10.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.10.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.10.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.10.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.10.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.11.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.11.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.11.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.11.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.11.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.11.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.11.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.11.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.11.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.11.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.11.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.11.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.11.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.11.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.11.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.11.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.11.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.11.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.12.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.12.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.12.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.12.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.12.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.12.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.12.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.12.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.12.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.12.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.12.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.12.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.12.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.12.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.12.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.12.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.12.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.12.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.13.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.13.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.13.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.13.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.13.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.13.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.13.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.13.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.13.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.13.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.13.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.13.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.13.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.13.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.13.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.13.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.13.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.13.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.14.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.14.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.14.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.14.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.14.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.14.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.14.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.14.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.14.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.14.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.14.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.14.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.14.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.14.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.14.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.14.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.14.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.14.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.15.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.15.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.15.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.15.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.15.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.15.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.15.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.15.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.15.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.15.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.15.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.15.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.15.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.15.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.15.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.15.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.15.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.15.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.16.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.16.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.16.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.16.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.16.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.16.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.16.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.16.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.16.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.16.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.16.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.16.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.16.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.16.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.16.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.16.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.16.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.16.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.17.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.17.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.17.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.17.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.17.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.17.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.17.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.17.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.17.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.17.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.17.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.17.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.17.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.17.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.17.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.17.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.17.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.17.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.18.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.18.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.18.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.18.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.18.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.18.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.18.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.18.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.18.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.18.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.18.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.18.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.18.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.18.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.18.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.18.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.18.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.18.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.2.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.2.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.2.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.2.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.2.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.2.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.2.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.2.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.2.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.2.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.2.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.2.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.2.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.2.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.2.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.2.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.2.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.2.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.3.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.3.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.3.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.3.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.3.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.3.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.3.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.3.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.3.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.3.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.3.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.3.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.3.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.3.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.3.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.3.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.3.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.3.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.4.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.4.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.4.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.4.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.4.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.4.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.4.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.4.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.4.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.4.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.4.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.4.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.4.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.4.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.4.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.4.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.4.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.4.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.5.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.5.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.5.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.5.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.5.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.5.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.5.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.5.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.5.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.5.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.5.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.5.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.5.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.5.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.5.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.5.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.5.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.5.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.6.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.6.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.6.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.6.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.6.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.6.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.6.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.6.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.6.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.6.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.6.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.6.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.6.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.6.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.6.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.6.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.6.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.6.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.7.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.7.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.7.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.7.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.7.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.7.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.7.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.7.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.7.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.7.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.7.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.7.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.7.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.7.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.7.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.7.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.7.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.7.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.8.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.8.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.8.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.8.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.8.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.8.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.8.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.8.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.8.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.8.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.8.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.8.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.8.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.8.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.8.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.8.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.8.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.8.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.9.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.9.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.9.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.9.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.9.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.9.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.9.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.9.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.9.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.9.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.9.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.9.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.9.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.9.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.9.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.9.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.9.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.9.norm1_context.linear.lora_B.weight",
]
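
# Illustrative sketch only (a hypothetical helper, not this repository's actual
# detection logic): every Diffusers-format key above pairs a "lora_A.weight"
# with a "lora_B.weight" entry under a "transformer." prefix, so a rough
# format check over such a key list could look like this.
def _looks_like_flux_diffusers_lora(keys: list[str]) -> bool:
    suffixes = (".lora_A.weight", ".lora_B.weight")
    return bool(keys) and all(
        k.startswith("transformer.") and k.endswith(suffixes) for k in keys
    )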

View File

@@ -0,0 +1,917 @@
# A sample state dict in the Kohya FLUX LoRA format.
# These keys are based on the LoRA model here:
# https://civitai.com/models/159333/pokemon-trainer-sprite-pixelart?modelVersionId=779247
state_dict_keys = [
"lora_unet_double_blocks_0_img_attn_proj.alpha",
"lora_unet_double_blocks_0_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_0_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_0_img_attn_qkv.alpha",
"lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_0_img_mlp_0.alpha",
"lora_unet_double_blocks_0_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_0_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_0_img_mlp_2.alpha",
"lora_unet_double_blocks_0_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_0_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_0_img_mod_lin.alpha",
"lora_unet_double_blocks_0_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_0_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_0_txt_attn_proj.alpha",
"lora_unet_double_blocks_0_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_0_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_0_txt_attn_qkv.alpha",
"lora_unet_double_blocks_0_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_0_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_0_txt_mlp_0.alpha",
"lora_unet_double_blocks_0_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_0_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_0_txt_mlp_2.alpha",
"lora_unet_double_blocks_0_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_0_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_0_txt_mod_lin.alpha",
"lora_unet_double_blocks_0_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_0_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_10_img_attn_proj.alpha",
"lora_unet_double_blocks_10_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_10_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_10_img_attn_qkv.alpha",
"lora_unet_double_blocks_10_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_10_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_10_img_mlp_0.alpha",
"lora_unet_double_blocks_10_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_10_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_10_img_mlp_2.alpha",
"lora_unet_double_blocks_10_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_10_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_10_img_mod_lin.alpha",
"lora_unet_double_blocks_10_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_10_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_10_txt_attn_proj.alpha",
"lora_unet_double_blocks_10_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_10_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_10_txt_attn_qkv.alpha",
"lora_unet_double_blocks_10_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_10_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_10_txt_mlp_0.alpha",
"lora_unet_double_blocks_10_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_10_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_10_txt_mlp_2.alpha",
"lora_unet_double_blocks_10_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_10_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_10_txt_mod_lin.alpha",
"lora_unet_double_blocks_10_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_10_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_11_img_attn_proj.alpha",
"lora_unet_double_blocks_11_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_11_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_11_img_attn_qkv.alpha",
"lora_unet_double_blocks_11_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_11_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_11_img_mlp_0.alpha",
"lora_unet_double_blocks_11_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_11_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_11_img_mlp_2.alpha",
"lora_unet_double_blocks_11_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_11_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_11_img_mod_lin.alpha",
"lora_unet_double_blocks_11_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_11_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_11_txt_attn_proj.alpha",
"lora_unet_double_blocks_11_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_11_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_11_txt_attn_qkv.alpha",
"lora_unet_double_blocks_11_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_11_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_11_txt_mlp_0.alpha",
"lora_unet_double_blocks_11_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_11_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_11_txt_mlp_2.alpha",
"lora_unet_double_blocks_11_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_11_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_11_txt_mod_lin.alpha",
"lora_unet_double_blocks_11_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_11_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_12_img_attn_proj.alpha",
"lora_unet_double_blocks_12_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_12_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_12_img_attn_qkv.alpha",
"lora_unet_double_blocks_12_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_12_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_12_img_mlp_0.alpha",
"lora_unet_double_blocks_12_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_12_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_12_img_mlp_2.alpha",
"lora_unet_double_blocks_12_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_12_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_12_img_mod_lin.alpha",
"lora_unet_double_blocks_12_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_12_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_12_txt_attn_proj.alpha",
"lora_unet_double_blocks_12_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_12_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_12_txt_attn_qkv.alpha",
"lora_unet_double_blocks_12_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_12_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_12_txt_mlp_0.alpha",
"lora_unet_double_blocks_12_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_12_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_12_txt_mlp_2.alpha",
"lora_unet_double_blocks_12_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_12_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_12_txt_mod_lin.alpha",
"lora_unet_double_blocks_12_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_12_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_13_img_attn_proj.alpha",
"lora_unet_double_blocks_13_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_13_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_13_img_attn_qkv.alpha",
"lora_unet_double_blocks_13_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_13_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_13_img_mlp_0.alpha",
"lora_unet_double_blocks_13_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_13_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_13_img_mlp_2.alpha",
"lora_unet_double_blocks_13_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_13_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_13_img_mod_lin.alpha",
"lora_unet_double_blocks_13_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_13_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_13_txt_attn_proj.alpha",
"lora_unet_double_blocks_13_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_13_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_13_txt_attn_qkv.alpha",
"lora_unet_double_blocks_13_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_13_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_13_txt_mlp_0.alpha",
"lora_unet_double_blocks_13_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_13_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_13_txt_mlp_2.alpha",
"lora_unet_double_blocks_13_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_13_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_13_txt_mod_lin.alpha",
"lora_unet_double_blocks_13_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_13_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_14_img_attn_proj.alpha",
"lora_unet_double_blocks_14_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_14_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_14_img_attn_qkv.alpha",
"lora_unet_double_blocks_14_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_14_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_14_img_mlp_0.alpha",
"lora_unet_double_blocks_14_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_14_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_14_img_mlp_2.alpha",
"lora_unet_double_blocks_14_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_14_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_14_img_mod_lin.alpha",
"lora_unet_double_blocks_14_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_14_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_14_txt_attn_proj.alpha",
"lora_unet_double_blocks_14_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_14_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_14_txt_attn_qkv.alpha",
"lora_unet_double_blocks_14_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_14_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_14_txt_mlp_0.alpha",
"lora_unet_double_blocks_14_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_14_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_14_txt_mlp_2.alpha",
"lora_unet_double_blocks_14_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_14_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_14_txt_mod_lin.alpha",
"lora_unet_double_blocks_14_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_14_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_15_img_attn_proj.alpha",
"lora_unet_double_blocks_15_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_15_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_15_img_attn_qkv.alpha",
"lora_unet_double_blocks_15_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_15_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_15_img_mlp_0.alpha",
"lora_unet_double_blocks_15_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_15_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_15_img_mlp_2.alpha",
"lora_unet_double_blocks_15_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_15_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_15_img_mod_lin.alpha",
"lora_unet_double_blocks_15_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_15_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_15_txt_attn_proj.alpha",
"lora_unet_double_blocks_15_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_15_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_15_txt_attn_qkv.alpha",
"lora_unet_double_blocks_15_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_15_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_15_txt_mlp_0.alpha",
"lora_unet_double_blocks_15_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_15_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_15_txt_mlp_2.alpha",
"lora_unet_double_blocks_15_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_15_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_15_txt_mod_lin.alpha",
"lora_unet_double_blocks_15_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_15_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_16_img_attn_proj.alpha",
"lora_unet_double_blocks_16_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_16_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_16_img_attn_qkv.alpha",
"lora_unet_double_blocks_16_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_16_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_16_img_mlp_0.alpha",
"lora_unet_double_blocks_16_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_16_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_16_img_mlp_2.alpha",
"lora_unet_double_blocks_16_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_16_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_16_img_mod_lin.alpha",
"lora_unet_double_blocks_16_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_16_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_16_txt_attn_proj.alpha",
"lora_unet_double_blocks_16_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_16_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_16_txt_attn_qkv.alpha",
"lora_unet_double_blocks_16_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_16_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_16_txt_mlp_0.alpha",
"lora_unet_double_blocks_16_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_16_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_16_txt_mlp_2.alpha",
"lora_unet_double_blocks_16_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_16_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_16_txt_mod_lin.alpha",
"lora_unet_double_blocks_16_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_16_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_17_img_attn_proj.alpha",
"lora_unet_double_blocks_17_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_17_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_17_img_attn_qkv.alpha",
"lora_unet_double_blocks_17_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_17_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_17_img_mlp_0.alpha",
"lora_unet_double_blocks_17_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_17_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_17_img_mlp_2.alpha",
"lora_unet_double_blocks_17_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_17_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_17_img_mod_lin.alpha",
"lora_unet_double_blocks_17_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_17_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_17_txt_attn_proj.alpha",
"lora_unet_double_blocks_17_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_17_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_17_txt_attn_qkv.alpha",
"lora_unet_double_blocks_17_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_17_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_17_txt_mlp_0.alpha",
"lora_unet_double_blocks_17_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_17_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_17_txt_mlp_2.alpha",
"lora_unet_double_blocks_17_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_17_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_17_txt_mod_lin.alpha",
"lora_unet_double_blocks_17_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_17_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_18_img_attn_proj.alpha",
"lora_unet_double_blocks_18_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_18_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_18_img_attn_qkv.alpha",
"lora_unet_double_blocks_18_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_18_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_18_img_mlp_0.alpha",
"lora_unet_double_blocks_18_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_18_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_18_img_mlp_2.alpha",
"lora_unet_double_blocks_18_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_18_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_18_img_mod_lin.alpha",
"lora_unet_double_blocks_18_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_18_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_18_txt_attn_proj.alpha",
"lora_unet_double_blocks_18_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_18_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_18_txt_attn_qkv.alpha",
"lora_unet_double_blocks_18_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_18_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_18_txt_mlp_0.alpha",
"lora_unet_double_blocks_18_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_18_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_18_txt_mlp_2.alpha",
"lora_unet_double_blocks_18_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_18_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_18_txt_mod_lin.alpha",
"lora_unet_double_blocks_18_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_18_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_1_img_attn_proj.alpha",
"lora_unet_double_blocks_1_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_1_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_1_img_attn_qkv.alpha",
"lora_unet_double_blocks_1_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_1_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_1_img_mlp_0.alpha",
"lora_unet_double_blocks_1_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_1_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_1_img_mlp_2.alpha",
"lora_unet_double_blocks_1_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_1_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_1_img_mod_lin.alpha",
"lora_unet_double_blocks_1_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_1_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_1_txt_attn_proj.alpha",
"lora_unet_double_blocks_1_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_1_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_1_txt_attn_qkv.alpha",
"lora_unet_double_blocks_1_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_1_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_1_txt_mlp_0.alpha",
"lora_unet_double_blocks_1_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_1_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_1_txt_mlp_2.alpha",
"lora_unet_double_blocks_1_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_1_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_1_txt_mod_lin.alpha",
"lora_unet_double_blocks_1_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_1_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_2_img_attn_proj.alpha",
"lora_unet_double_blocks_2_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_2_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_2_img_attn_qkv.alpha",
"lora_unet_double_blocks_2_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_2_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_2_img_mlp_0.alpha",
"lora_unet_double_blocks_2_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_2_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_2_img_mlp_2.alpha",
"lora_unet_double_blocks_2_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_2_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_2_img_mod_lin.alpha",
"lora_unet_double_blocks_2_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_2_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_2_txt_attn_proj.alpha",
"lora_unet_double_blocks_2_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_2_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_2_txt_attn_qkv.alpha",
"lora_unet_double_blocks_2_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_2_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_2_txt_mlp_0.alpha",
"lora_unet_double_blocks_2_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_2_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_2_txt_mlp_2.alpha",
"lora_unet_double_blocks_2_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_2_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_2_txt_mod_lin.alpha",
"lora_unet_double_blocks_2_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_2_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_3_img_attn_proj.alpha",
"lora_unet_double_blocks_3_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_3_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_3_img_attn_qkv.alpha",
"lora_unet_double_blocks_3_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_3_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_3_img_mlp_0.alpha",
"lora_unet_double_blocks_3_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_3_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_3_img_mlp_2.alpha",
"lora_unet_double_blocks_3_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_3_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_3_img_mod_lin.alpha",
"lora_unet_double_blocks_3_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_3_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_3_txt_attn_proj.alpha",
"lora_unet_double_blocks_3_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_3_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_3_txt_attn_qkv.alpha",
"lora_unet_double_blocks_3_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_3_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_3_txt_mlp_0.alpha",
"lora_unet_double_blocks_3_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_3_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_3_txt_mlp_2.alpha",
"lora_unet_double_blocks_3_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_3_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_3_txt_mod_lin.alpha",
"lora_unet_double_blocks_3_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_3_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_4_img_attn_proj.alpha",
"lora_unet_double_blocks_4_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_4_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_4_img_attn_qkv.alpha",
"lora_unet_double_blocks_4_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_4_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_4_img_mlp_0.alpha",
"lora_unet_double_blocks_4_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_4_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_4_img_mlp_2.alpha",
"lora_unet_double_blocks_4_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_4_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_4_img_mod_lin.alpha",
"lora_unet_double_blocks_4_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_4_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_4_txt_attn_proj.alpha",
"lora_unet_double_blocks_4_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_4_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_4_txt_attn_qkv.alpha",
"lora_unet_double_blocks_4_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_4_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_4_txt_mlp_0.alpha",
"lora_unet_double_blocks_4_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_4_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_4_txt_mlp_2.alpha",
"lora_unet_double_blocks_4_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_4_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_4_txt_mod_lin.alpha",
"lora_unet_double_blocks_4_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_4_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_5_img_attn_proj.alpha",
"lora_unet_double_blocks_5_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_5_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_5_img_attn_qkv.alpha",
"lora_unet_double_blocks_5_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_5_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_5_img_mlp_0.alpha",
"lora_unet_double_blocks_5_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_5_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_5_img_mlp_2.alpha",
"lora_unet_double_blocks_5_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_5_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_5_img_mod_lin.alpha",
"lora_unet_double_blocks_5_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_5_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_5_txt_attn_proj.alpha",
"lora_unet_double_blocks_5_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_5_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_5_txt_attn_qkv.alpha",
"lora_unet_double_blocks_5_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_5_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_5_txt_mlp_0.alpha",
"lora_unet_double_blocks_5_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_5_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_5_txt_mlp_2.alpha",
"lora_unet_double_blocks_5_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_5_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_5_txt_mod_lin.alpha",
"lora_unet_double_blocks_5_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_5_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_6_img_attn_proj.alpha",
"lora_unet_double_blocks_6_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_6_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_6_img_attn_qkv.alpha",
"lora_unet_double_blocks_6_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_6_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_6_img_mlp_0.alpha",
"lora_unet_double_blocks_6_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_6_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_6_img_mlp_2.alpha",
"lora_unet_double_blocks_6_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_6_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_6_img_mod_lin.alpha",
"lora_unet_double_blocks_6_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_6_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_6_txt_attn_proj.alpha",
"lora_unet_double_blocks_6_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_6_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_6_txt_attn_qkv.alpha",
"lora_unet_double_blocks_6_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_6_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_6_txt_mlp_0.alpha",
"lora_unet_double_blocks_6_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_6_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_6_txt_mlp_2.alpha",
"lora_unet_double_blocks_6_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_6_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_6_txt_mod_lin.alpha",
"lora_unet_double_blocks_6_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_6_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_7_img_attn_proj.alpha",
"lora_unet_double_blocks_7_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_7_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_7_img_attn_qkv.alpha",
"lora_unet_double_blocks_7_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_7_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_7_img_mlp_0.alpha",
"lora_unet_double_blocks_7_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_7_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_7_img_mlp_2.alpha",
"lora_unet_double_blocks_7_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_7_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_7_img_mod_lin.alpha",
"lora_unet_double_blocks_7_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_7_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_7_txt_attn_proj.alpha",
"lora_unet_double_blocks_7_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_7_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_7_txt_attn_qkv.alpha",
"lora_unet_double_blocks_7_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_7_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_7_txt_mlp_0.alpha",
"lora_unet_double_blocks_7_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_7_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_7_txt_mlp_2.alpha",
"lora_unet_double_blocks_7_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_7_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_7_txt_mod_lin.alpha",
"lora_unet_double_blocks_7_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_7_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_8_img_attn_proj.alpha",
"lora_unet_double_blocks_8_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_8_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_8_img_attn_qkv.alpha",
"lora_unet_double_blocks_8_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_8_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_8_img_mlp_0.alpha",
"lora_unet_double_blocks_8_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_8_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_8_img_mlp_2.alpha",
"lora_unet_double_blocks_8_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_8_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_8_img_mod_lin.alpha",
"lora_unet_double_blocks_8_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_8_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_8_txt_attn_proj.alpha",
"lora_unet_double_blocks_8_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_8_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_8_txt_attn_qkv.alpha",
"lora_unet_double_blocks_8_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_8_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_8_txt_mlp_0.alpha",
"lora_unet_double_blocks_8_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_8_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_8_txt_mlp_2.alpha",
"lora_unet_double_blocks_8_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_8_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_8_txt_mod_lin.alpha",
"lora_unet_double_blocks_8_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_8_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_9_img_attn_proj.alpha",
"lora_unet_double_blocks_9_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_9_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_9_img_attn_qkv.alpha",
"lora_unet_double_blocks_9_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_9_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_9_img_mlp_0.alpha",
"lora_unet_double_blocks_9_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_9_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_9_img_mlp_2.alpha",
"lora_unet_double_blocks_9_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_9_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_9_img_mod_lin.alpha",
"lora_unet_double_blocks_9_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_9_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_9_txt_attn_proj.alpha",
"lora_unet_double_blocks_9_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_9_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_9_txt_attn_qkv.alpha",
"lora_unet_double_blocks_9_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_9_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_9_txt_mlp_0.alpha",
"lora_unet_double_blocks_9_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_9_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_9_txt_mlp_2.alpha",
"lora_unet_double_blocks_9_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_9_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_9_txt_mod_lin.alpha",
"lora_unet_double_blocks_9_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_9_txt_mod_lin.lora_up.weight",
"lora_unet_single_blocks_0_linear1.alpha",
"lora_unet_single_blocks_0_linear1.lora_down.weight",
"lora_unet_single_blocks_0_linear1.lora_up.weight",
"lora_unet_single_blocks_0_linear2.alpha",
"lora_unet_single_blocks_0_linear2.lora_down.weight",
"lora_unet_single_blocks_0_linear2.lora_up.weight",
"lora_unet_single_blocks_0_modulation_lin.alpha",
"lora_unet_single_blocks_0_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_0_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_10_linear1.alpha",
"lora_unet_single_blocks_10_linear1.lora_down.weight",
"lora_unet_single_blocks_10_linear1.lora_up.weight",
"lora_unet_single_blocks_10_linear2.alpha",
"lora_unet_single_blocks_10_linear2.lora_down.weight",
"lora_unet_single_blocks_10_linear2.lora_up.weight",
"lora_unet_single_blocks_10_modulation_lin.alpha",
"lora_unet_single_blocks_10_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_10_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_11_linear1.alpha",
"lora_unet_single_blocks_11_linear1.lora_down.weight",
"lora_unet_single_blocks_11_linear1.lora_up.weight",
"lora_unet_single_blocks_11_linear2.alpha",
"lora_unet_single_blocks_11_linear2.lora_down.weight",
"lora_unet_single_blocks_11_linear2.lora_up.weight",
"lora_unet_single_blocks_11_modulation_lin.alpha",
"lora_unet_single_blocks_11_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_11_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_12_linear1.alpha",
"lora_unet_single_blocks_12_linear1.lora_down.weight",
"lora_unet_single_blocks_12_linear1.lora_up.weight",
"lora_unet_single_blocks_12_linear2.alpha",
"lora_unet_single_blocks_12_linear2.lora_down.weight",
"lora_unet_single_blocks_12_linear2.lora_up.weight",
"lora_unet_single_blocks_12_modulation_lin.alpha",
"lora_unet_single_blocks_12_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_12_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_13_linear1.alpha",
"lora_unet_single_blocks_13_linear1.lora_down.weight",
"lora_unet_single_blocks_13_linear1.lora_up.weight",
"lora_unet_single_blocks_13_linear2.alpha",
"lora_unet_single_blocks_13_linear2.lora_down.weight",
"lora_unet_single_blocks_13_linear2.lora_up.weight",
"lora_unet_single_blocks_13_modulation_lin.alpha",
"lora_unet_single_blocks_13_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_13_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_14_linear1.alpha",
"lora_unet_single_blocks_14_linear1.lora_down.weight",
"lora_unet_single_blocks_14_linear1.lora_up.weight",
"lora_unet_single_blocks_14_linear2.alpha",
"lora_unet_single_blocks_14_linear2.lora_down.weight",
"lora_unet_single_blocks_14_linear2.lora_up.weight",
"lora_unet_single_blocks_14_modulation_lin.alpha",
"lora_unet_single_blocks_14_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_14_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_15_linear1.alpha",
"lora_unet_single_blocks_15_linear1.lora_down.weight",
"lora_unet_single_blocks_15_linear1.lora_up.weight",
"lora_unet_single_blocks_15_linear2.alpha",
"lora_unet_single_blocks_15_linear2.lora_down.weight",
"lora_unet_single_blocks_15_linear2.lora_up.weight",
"lora_unet_single_blocks_15_modulation_lin.alpha",
"lora_unet_single_blocks_15_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_15_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_16_linear1.alpha",
"lora_unet_single_blocks_16_linear1.lora_down.weight",
"lora_unet_single_blocks_16_linear1.lora_up.weight",
"lora_unet_single_blocks_16_linear2.alpha",
"lora_unet_single_blocks_16_linear2.lora_down.weight",
"lora_unet_single_blocks_16_linear2.lora_up.weight",
"lora_unet_single_blocks_16_modulation_lin.alpha",
"lora_unet_single_blocks_16_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_16_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_17_linear1.alpha",
"lora_unet_single_blocks_17_linear1.lora_down.weight",
"lora_unet_single_blocks_17_linear1.lora_up.weight",
"lora_unet_single_blocks_17_linear2.alpha",
"lora_unet_single_blocks_17_linear2.lora_down.weight",
"lora_unet_single_blocks_17_linear2.lora_up.weight",
"lora_unet_single_blocks_17_modulation_lin.alpha",
"lora_unet_single_blocks_17_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_17_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_18_linear1.alpha",
"lora_unet_single_blocks_18_linear1.lora_down.weight",
"lora_unet_single_blocks_18_linear1.lora_up.weight",
"lora_unet_single_blocks_18_linear2.alpha",
"lora_unet_single_blocks_18_linear2.lora_down.weight",
"lora_unet_single_blocks_18_linear2.lora_up.weight",
"lora_unet_single_blocks_18_modulation_lin.alpha",
"lora_unet_single_blocks_18_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_18_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_19_linear1.alpha",
"lora_unet_single_blocks_19_linear1.lora_down.weight",
"lora_unet_single_blocks_19_linear1.lora_up.weight",
"lora_unet_single_blocks_19_linear2.alpha",
"lora_unet_single_blocks_19_linear2.lora_down.weight",
"lora_unet_single_blocks_19_linear2.lora_up.weight",
"lora_unet_single_blocks_19_modulation_lin.alpha",
"lora_unet_single_blocks_19_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_19_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_1_linear1.alpha",
"lora_unet_single_blocks_1_linear1.lora_down.weight",
"lora_unet_single_blocks_1_linear1.lora_up.weight",
"lora_unet_single_blocks_1_linear2.alpha",
"lora_unet_single_blocks_1_linear2.lora_down.weight",
"lora_unet_single_blocks_1_linear2.lora_up.weight",
"lora_unet_single_blocks_1_modulation_lin.alpha",
"lora_unet_single_blocks_1_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_1_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_20_linear1.alpha",
"lora_unet_single_blocks_20_linear1.lora_down.weight",
"lora_unet_single_blocks_20_linear1.lora_up.weight",
"lora_unet_single_blocks_20_linear2.alpha",
"lora_unet_single_blocks_20_linear2.lora_down.weight",
"lora_unet_single_blocks_20_linear2.lora_up.weight",
"lora_unet_single_blocks_20_modulation_lin.alpha",
"lora_unet_single_blocks_20_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_20_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_21_linear1.alpha",
"lora_unet_single_blocks_21_linear1.lora_down.weight",
"lora_unet_single_blocks_21_linear1.lora_up.weight",
"lora_unet_single_blocks_21_linear2.alpha",
"lora_unet_single_blocks_21_linear2.lora_down.weight",
"lora_unet_single_blocks_21_linear2.lora_up.weight",
"lora_unet_single_blocks_21_modulation_lin.alpha",
"lora_unet_single_blocks_21_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_21_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_22_linear1.alpha",
"lora_unet_single_blocks_22_linear1.lora_down.weight",
"lora_unet_single_blocks_22_linear1.lora_up.weight",
"lora_unet_single_blocks_22_linear2.alpha",
"lora_unet_single_blocks_22_linear2.lora_down.weight",
"lora_unet_single_blocks_22_linear2.lora_up.weight",
"lora_unet_single_blocks_22_modulation_lin.alpha",
"lora_unet_single_blocks_22_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_22_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_23_linear1.alpha",
"lora_unet_single_blocks_23_linear1.lora_down.weight",
"lora_unet_single_blocks_23_linear1.lora_up.weight",
"lora_unet_single_blocks_23_linear2.alpha",
"lora_unet_single_blocks_23_linear2.lora_down.weight",
"lora_unet_single_blocks_23_linear2.lora_up.weight",
"lora_unet_single_blocks_23_modulation_lin.alpha",
"lora_unet_single_blocks_23_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_23_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_24_linear1.alpha",
"lora_unet_single_blocks_24_linear1.lora_down.weight",
"lora_unet_single_blocks_24_linear1.lora_up.weight",
"lora_unet_single_blocks_24_linear2.alpha",
"lora_unet_single_blocks_24_linear2.lora_down.weight",
"lora_unet_single_blocks_24_linear2.lora_up.weight",
"lora_unet_single_blocks_24_modulation_lin.alpha",
"lora_unet_single_blocks_24_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_24_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_25_linear1.alpha",
"lora_unet_single_blocks_25_linear1.lora_down.weight",
"lora_unet_single_blocks_25_linear1.lora_up.weight",
"lora_unet_single_blocks_25_linear2.alpha",
"lora_unet_single_blocks_25_linear2.lora_down.weight",
"lora_unet_single_blocks_25_linear2.lora_up.weight",
"lora_unet_single_blocks_25_modulation_lin.alpha",
"lora_unet_single_blocks_25_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_25_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_26_linear1.alpha",
"lora_unet_single_blocks_26_linear1.lora_down.weight",
"lora_unet_single_blocks_26_linear1.lora_up.weight",
"lora_unet_single_blocks_26_linear2.alpha",
"lora_unet_single_blocks_26_linear2.lora_down.weight",
"lora_unet_single_blocks_26_linear2.lora_up.weight",
"lora_unet_single_blocks_26_modulation_lin.alpha",
"lora_unet_single_blocks_26_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_26_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_27_linear1.alpha",
"lora_unet_single_blocks_27_linear1.lora_down.weight",
"lora_unet_single_blocks_27_linear1.lora_up.weight",
"lora_unet_single_blocks_27_linear2.alpha",
"lora_unet_single_blocks_27_linear2.lora_down.weight",
"lora_unet_single_blocks_27_linear2.lora_up.weight",
"lora_unet_single_blocks_27_modulation_lin.alpha",
"lora_unet_single_blocks_27_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_27_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_28_linear1.alpha",
"lora_unet_single_blocks_28_linear1.lora_down.weight",
"lora_unet_single_blocks_28_linear1.lora_up.weight",
"lora_unet_single_blocks_28_linear2.alpha",
"lora_unet_single_blocks_28_linear2.lora_down.weight",
"lora_unet_single_blocks_28_linear2.lora_up.weight",
"lora_unet_single_blocks_28_modulation_lin.alpha",
"lora_unet_single_blocks_28_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_28_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_29_linear1.alpha",
"lora_unet_single_blocks_29_linear1.lora_down.weight",
"lora_unet_single_blocks_29_linear1.lora_up.weight",
"lora_unet_single_blocks_29_linear2.alpha",
"lora_unet_single_blocks_29_linear2.lora_down.weight",
"lora_unet_single_blocks_29_linear2.lora_up.weight",
"lora_unet_single_blocks_29_modulation_lin.alpha",
"lora_unet_single_blocks_29_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_29_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_2_linear1.alpha",
"lora_unet_single_blocks_2_linear1.lora_down.weight",
"lora_unet_single_blocks_2_linear1.lora_up.weight",
"lora_unet_single_blocks_2_linear2.alpha",
"lora_unet_single_blocks_2_linear2.lora_down.weight",
"lora_unet_single_blocks_2_linear2.lora_up.weight",
"lora_unet_single_blocks_2_modulation_lin.alpha",
"lora_unet_single_blocks_2_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_2_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_30_linear1.alpha",
"lora_unet_single_blocks_30_linear1.lora_down.weight",
"lora_unet_single_blocks_30_linear1.lora_up.weight",
"lora_unet_single_blocks_30_linear2.alpha",
"lora_unet_single_blocks_30_linear2.lora_down.weight",
"lora_unet_single_blocks_30_linear2.lora_up.weight",
"lora_unet_single_blocks_30_modulation_lin.alpha",
"lora_unet_single_blocks_30_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_30_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_31_linear1.alpha",
"lora_unet_single_blocks_31_linear1.lora_down.weight",
"lora_unet_single_blocks_31_linear1.lora_up.weight",
"lora_unet_single_blocks_31_linear2.alpha",
"lora_unet_single_blocks_31_linear2.lora_down.weight",
"lora_unet_single_blocks_31_linear2.lora_up.weight",
"lora_unet_single_blocks_31_modulation_lin.alpha",
"lora_unet_single_blocks_31_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_31_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_32_linear1.alpha",
"lora_unet_single_blocks_32_linear1.lora_down.weight",
"lora_unet_single_blocks_32_linear1.lora_up.weight",
"lora_unet_single_blocks_32_linear2.alpha",
"lora_unet_single_blocks_32_linear2.lora_down.weight",
"lora_unet_single_blocks_32_linear2.lora_up.weight",
"lora_unet_single_blocks_32_modulation_lin.alpha",
"lora_unet_single_blocks_32_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_32_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_33_linear1.alpha",
"lora_unet_single_blocks_33_linear1.lora_down.weight",
"lora_unet_single_blocks_33_linear1.lora_up.weight",
"lora_unet_single_blocks_33_linear2.alpha",
"lora_unet_single_blocks_33_linear2.lora_down.weight",
"lora_unet_single_blocks_33_linear2.lora_up.weight",
"lora_unet_single_blocks_33_modulation_lin.alpha",
"lora_unet_single_blocks_33_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_33_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_34_linear1.alpha",
"lora_unet_single_blocks_34_linear1.lora_down.weight",
"lora_unet_single_blocks_34_linear1.lora_up.weight",
"lora_unet_single_blocks_34_linear2.alpha",
"lora_unet_single_blocks_34_linear2.lora_down.weight",
"lora_unet_single_blocks_34_linear2.lora_up.weight",
"lora_unet_single_blocks_34_modulation_lin.alpha",
"lora_unet_single_blocks_34_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_34_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_35_linear1.alpha",
"lora_unet_single_blocks_35_linear1.lora_down.weight",
"lora_unet_single_blocks_35_linear1.lora_up.weight",
"lora_unet_single_blocks_35_linear2.alpha",
"lora_unet_single_blocks_35_linear2.lora_down.weight",
"lora_unet_single_blocks_35_linear2.lora_up.weight",
"lora_unet_single_blocks_35_modulation_lin.alpha",
"lora_unet_single_blocks_35_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_35_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_36_linear1.alpha",
"lora_unet_single_blocks_36_linear1.lora_down.weight",
"lora_unet_single_blocks_36_linear1.lora_up.weight",
"lora_unet_single_blocks_36_linear2.alpha",
"lora_unet_single_blocks_36_linear2.lora_down.weight",
"lora_unet_single_blocks_36_linear2.lora_up.weight",
"lora_unet_single_blocks_36_modulation_lin.alpha",
"lora_unet_single_blocks_36_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_36_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_37_linear1.alpha",
"lora_unet_single_blocks_37_linear1.lora_down.weight",
"lora_unet_single_blocks_37_linear1.lora_up.weight",
"lora_unet_single_blocks_37_linear2.alpha",
"lora_unet_single_blocks_37_linear2.lora_down.weight",
"lora_unet_single_blocks_37_linear2.lora_up.weight",
"lora_unet_single_blocks_37_modulation_lin.alpha",
"lora_unet_single_blocks_37_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_37_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_3_linear1.alpha",
"lora_unet_single_blocks_3_linear1.lora_down.weight",
"lora_unet_single_blocks_3_linear1.lora_up.weight",
"lora_unet_single_blocks_3_linear2.alpha",
"lora_unet_single_blocks_3_linear2.lora_down.weight",
"lora_unet_single_blocks_3_linear2.lora_up.weight",
"lora_unet_single_blocks_3_modulation_lin.alpha",
"lora_unet_single_blocks_3_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_3_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_4_linear1.alpha",
"lora_unet_single_blocks_4_linear1.lora_down.weight",
"lora_unet_single_blocks_4_linear1.lora_up.weight",
"lora_unet_single_blocks_4_linear2.alpha",
"lora_unet_single_blocks_4_linear2.lora_down.weight",
"lora_unet_single_blocks_4_linear2.lora_up.weight",
"lora_unet_single_blocks_4_modulation_lin.alpha",
"lora_unet_single_blocks_4_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_4_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_5_linear1.alpha",
"lora_unet_single_blocks_5_linear1.lora_down.weight",
"lora_unet_single_blocks_5_linear1.lora_up.weight",
"lora_unet_single_blocks_5_linear2.alpha",
"lora_unet_single_blocks_5_linear2.lora_down.weight",
"lora_unet_single_blocks_5_linear2.lora_up.weight",
"lora_unet_single_blocks_5_modulation_lin.alpha",
"lora_unet_single_blocks_5_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_5_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_6_linear1.alpha",
"lora_unet_single_blocks_6_linear1.lora_down.weight",
"lora_unet_single_blocks_6_linear1.lora_up.weight",
"lora_unet_single_blocks_6_linear2.alpha",
"lora_unet_single_blocks_6_linear2.lora_down.weight",
"lora_unet_single_blocks_6_linear2.lora_up.weight",
"lora_unet_single_blocks_6_modulation_lin.alpha",
"lora_unet_single_blocks_6_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_6_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_7_linear1.alpha",
"lora_unet_single_blocks_7_linear1.lora_down.weight",
"lora_unet_single_blocks_7_linear1.lora_up.weight",
"lora_unet_single_blocks_7_linear2.alpha",
"lora_unet_single_blocks_7_linear2.lora_down.weight",
"lora_unet_single_blocks_7_linear2.lora_up.weight",
"lora_unet_single_blocks_7_modulation_lin.alpha",
"lora_unet_single_blocks_7_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_7_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_8_linear1.alpha",
"lora_unet_single_blocks_8_linear1.lora_down.weight",
"lora_unet_single_blocks_8_linear1.lora_up.weight",
"lora_unet_single_blocks_8_linear2.alpha",
"lora_unet_single_blocks_8_linear2.lora_down.weight",
"lora_unet_single_blocks_8_linear2.lora_up.weight",
"lora_unet_single_blocks_8_modulation_lin.alpha",
"lora_unet_single_blocks_8_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_8_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_9_linear1.alpha",
"lora_unet_single_blocks_9_linear1.lora_down.weight",
"lora_unet_single_blocks_9_linear1.lora_up.weight",
"lora_unet_single_blocks_9_linear2.alpha",
"lora_unet_single_blocks_9_linear2.lora_down.weight",
"lora_unet_single_blocks_9_linear2.lora_up.weight",
"lora_unet_single_blocks_9_modulation_lin.alpha",
"lora_unet_single_blocks_9_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_9_modulation_lin.lora_up.weight",
]
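
For reference, every key in the fixture above follows the same small naming grammar: a "lora_unet_" prefix, a block path ("double_blocks_{i}_{img|txt}_{attn_qkv|attn_proj|mlp_0|mlp_2|mod_lin}" or "single_blocks_{i}_{linear1|linear2|modulation_lin}"), and one of the suffixes ".alpha", ".lora_down.weight", or ".lora_up.weight". The regex below is a rough sketch of how such keys could be recognized; the pattern and helper name are illustrative assumptions, not the repository's is_state_dict_likely_in_flux_kohya_format() implementation (a real Kohya checkpoint may also carry text-encoder keys such as "lora_te1_*", which this fixture omits).

import re

# Hypothetical key grammar for the transformer ("lora_unet_") keys listed above.
FLUX_KOHYA_TRANSFORMER_KEY_RE = re.compile(
    r"^lora_unet_("
    r"double_blocks_\d+_(img|txt)_(attn_(qkv|proj)|mlp_\d+|mod_lin)"
    r"|single_blocks_\d+_(linear[12]|modulation_lin)"
    r")\.(alpha|lora_(up|down)\.weight)$"
)


def looks_like_flux_kohya_transformer_keys(keys: list[str]) -> bool:
    """Return True if every key matches the Kohya FLUX transformer naming scheme sketched above."""
    return all(FLUX_KOHYA_TRANSFORMER_KEY_RE.match(k) for k in keys)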

View File

@@ -0,0 +1,8 @@
import torch


def keys_to_mock_state_dict(keys: list[str]) -> dict[str, torch.Tensor]:
    state_dict: dict[str, torch.Tensor] = {}
    for k in keys:
        state_dict[k] = torch.empty(1)
    return state_dict

View File

@@ -0,0 +1,66 @@
import pytest
import torch

from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
    is_state_dict_likely_in_flux_diffusers_format,
    lora_model_from_flux_diffusers_state_dict,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
    state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
    state_dict_keys as flux_kohya_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict


def test_is_state_dict_likely_in_flux_diffusers_format_true():
    """Test that is_state_dict_likely_in_flux_diffusers_format() can identify a state dict in the Diffusers FLUX LoRA format."""
    # Construct a state dict that is in the Diffusers FLUX LoRA format.
    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)

    assert is_state_dict_likely_in_flux_diffusers_format(state_dict)


def test_is_state_dict_likely_in_flux_diffusers_format_false():
    """Test that is_state_dict_likely_in_flux_diffusers_format() returns False for a state dict that is in the Kohya
    FLUX LoRA format (i.e. not in the Diffusers format).
    """
    # Construct a state dict that is in the Kohya FLUX LoRA format (not the Diffusers format).
    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)

    assert not is_state_dict_likely_in_flux_diffusers_format(state_dict)


def test_lora_model_from_flux_diffusers_state_dict():
    """Test that lora_model_from_flux_diffusers_state_dict() can load a state dict in the Diffusers FLUX LoRA format."""
    # Construct a state dict that is in the Diffusers FLUX LoRA format.
    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)

    # Load the state dict into a LoRAModelRaw object.
    model = lora_model_from_flux_diffusers_state_dict(state_dict)

    # Check that the model has the correct number of LoRA layers.
    expected_lora_layers: set[str] = set()
    for k in flux_diffusers_state_dict_keys:
        k = k.replace("lora_A.weight", "")
        k = k.replace("lora_B.weight", "")
        expected_lora_layers.add(k)
    # Drop the K/V/proj_mlp weights because these are all concatenated into a single layer in the BFL format (we keep
    # the Q weights so that we count these layers once).
    concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
    expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
    assert len(model.layers) == len(expected_lora_layers)


def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
    """Test that lora_model_from_flux_diffusers_state_dict() raises an error if the input state_dict contains
    unexpected keys that we don't handle.
    """
    # Construct a state dict that is in the Diffusers FLUX LoRA format.
    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)

    # Add an unexpected key.
    state_dict["transformer.single_transformer_blocks.0.unexpected_key.lora_A.weight"] = torch.empty(1)

    # Check that an error is raised.
    with pytest.raises(AssertionError):
        lora_model_from_flux_diffusers_state_dict(state_dict)
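
The layer count in test_lora_model_from_flux_diffusers_state_dict() reflects a structural difference between the two formats: in the Diffusers layout, attn.to_q/to_k/to_v (and add_q_proj/add_k_proj/add_v_proj, proj_mlp) are separate linear layers, while the BFL FLUX model fuses them into single projections (attn.qkv, linear1), so the converter must combine several per-projection LoRA layers into one. The snippet below is a minimal sketch of that idea, assuming plain (lora_B, lora_A) factor pairs; it is illustrative only and is not the actual concatenation layer class used by the converter.

import torch


def fused_qkv_lora_delta(
    q: tuple[torch.Tensor, torch.Tensor],  # (lora_B [out_q, rank], lora_A [rank, in])
    k: tuple[torch.Tensor, torch.Tensor],
    v: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    # Each per-projection delta is lora_B @ lora_A; the fused qkv weight stacks the q, k, and v
    # rows along the output dimension, so the fused delta is the same concatenation.
    return torch.cat([b @ a for b, a in (q, k, v)], dim=0)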

View File

@@ -0,0 +1,98 @@
import pytest
import torch

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
    convert_flux_kohya_state_dict_to_invoke_format,
    is_state_dict_likely_in_flux_kohya_format,
    lora_model_from_flux_kohya_state_dict,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
    state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
    state_dict_keys as flux_kohya_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict


def test_is_state_dict_likely_in_flux_kohya_format_true():
    """Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
    # Construct a state dict that is in the Kohya FLUX LoRA format.
    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)

    assert is_state_dict_likely_in_flux_kohya_format(state_dict)


def test_is_state_dict_likely_in_flux_kohya_format_false():
    """Test that is_state_dict_likely_in_flux_kohya_format() returns False for a state dict that is in the Diffusers
    FLUX LoRA format.
    """
    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)

    assert not is_state_dict_likely_in_flux_kohya_format(state_dict)


def test_convert_flux_kohya_state_dict_to_invoke_format():
    # Construct state_dict from state_dict_keys.
    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)

    converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)

    # Extract the prefixes from the converted state dict (i.e. without the .lora_up.weight, .lora_down.weight, and
    # .alpha suffixes).
    converted_key_prefixes: list[str] = []
    for k in converted_state_dict.keys():
        k = k.replace(".lora_up.weight", "")
        k = k.replace(".lora_down.weight", "")
        k = k.replace(".alpha", "")
        converted_key_prefixes.append(k)

    # Initialize a FLUX model on the meta device.
    with torch.device("meta"):
        model = Flux(params["flux-dev"])
    model_keys = set(model.state_dict().keys())

    # Assert that the converted state dict matches the keys in the actual model.
    for converted_key_prefix in converted_key_prefixes:
        found_match = False
        for model_key in model_keys:
            if model_key.startswith(converted_key_prefix):
                found_match = True
                break
        if not found_match:
            raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")


def test_convert_flux_kohya_state_dict_to_invoke_format_error():
    """Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict contains
    unexpected keys.
    """
    state_dict = {
        "unexpected_key.lora_up.weight": torch.empty(1),
    }

    with pytest.raises(ValueError):
        convert_flux_kohya_state_dict_to_invoke_format(state_dict)


def test_lora_model_from_flux_kohya_state_dict():
    """Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
    # Construct a state dict that is in the Kohya FLUX LoRA format.
    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)

    lora_model = lora_model_from_flux_kohya_state_dict(state_dict)

    # Prepare expected layer keys.
    expected_layer_keys: set[str] = set()
    for k in flux_kohya_state_dict_keys:
        k = k.replace("lora_unet_", "")
        k = k.replace(".lora_up.weight", "")
        k = k.replace(".lora_down.weight", "")
        k = k.replace(".alpha", "")
        expected_layer_keys.add(k)

    # Assert that the lora_model has the expected layers.
    lora_model_keys = set(lora_model.layers.keys())
    lora_model_keys = {k.replace(".", "_") for k in lora_model_keys}
    assert lora_model_keys == expected_layer_keys
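
The conversion exercised above essentially rewrites each Kohya key prefix (underscore-separated) into the dotted module path of the corresponding BFL Flux layer, e.g. "lora_unet_double_blocks_1_txt_mod_lin" -> "double_blocks.1.txt_mod.lin". The helper below is a rough sketch of such a mapping; the regexes and function name are assumptions for illustration, not the repository's convert_flux_kohya_state_dict_to_invoke_format() implementation.

import re


def kohya_prefix_to_flux_module_path(prefix: str) -> str:
    """Map a Kohya-style key prefix onto the dotted BFL Flux module path it targets (sketch only)."""
    prefix = prefix.removeprefix("lora_unet_")
    m = re.match(r"^double_blocks_(\d+)_(img|txt)_(attn_qkv|attn_proj|mlp_0|mlp_2|mod_lin)$", prefix)
    if m:
        idx, stream, tail = m.groups()
        tail = {"attn_qkv": "attn.qkv", "attn_proj": "attn.proj", "mlp_0": "mlp.0", "mlp_2": "mlp.2", "mod_lin": "mod.lin"}[tail]
        return f"double_blocks.{idx}.{stream}_{tail}"
    m = re.match(r"^single_blocks_(\d+)_(linear1|linear2|modulation_lin)$", prefix)
    if m:
        idx, tail = m.groups()
        return f"single_blocks.{idx}." + ("modulation.lin" if tail == "modulation_lin" else tail)
    raise ValueError(f"Unexpected Kohya key prefix: {prefix}")


assert kohya_prefix_to_flux_module_path("lora_unet_double_blocks_1_txt_mod_lin") == "double_blocks.1.txt_mod.lin"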

View File

@@ -1,12 +1,9 @@
# test that if the model's device changes while the lora is applied, the weights can still be restored
# test that LoRA patching works on both CPU and CUDA
import pytest
import torch
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
@pytest.mark.parametrize(
@@ -17,7 +14,7 @@ from invokeai.backend.model_patcher import ModelPatcher
    ],
)
@torch.no_grad()
def test_apply_lora(device):
def test_apply_lora(device: str):
    """Test the basic behavior of ModelPatcher.apply_lora(...). Check that patching and unpatching produce the correct
    result, and that model/LoRA tensors are moved between devices as expected.
    """
@@ -31,20 +28,19 @@ def test_apply_lora(device):
    lora_layers = {
        "linear_layer_1": LoRALayer(
            layer_key="linear_layer_1",
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = LoRAModelRaw("lora_name", lora_layers)
    lora = LoRAModelRaw(lora_layers)
    lora_weight = 0.5
    orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
    expected_patched_linear_weight = orig_linear_weight + (lora_dim * lora_weight)
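    # With all-ones lora_up and lora_down factors, every entry of the LoRA delta (lora_up @ lora_down) is a sum of
    # lora_dim ones, so the patched weight should be orig + lora_weight * lora_dim (assuming an effective LoRA scale
    # of 1.0, since no alpha value is supplied above).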
    with ModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
    with LoraPatcher.apply_lora_patches(model=model, patches=[(lora, lora_weight)], prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"
@@ -75,18 +71,17 @@ def test_apply_lora_change_device():
    lora_layers = {
        "linear_layer_1": LoRALayer(
            layer_key="linear_layer_1",
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = LoRAModelRaw("lora_name", lora_layers)
    lora = LoRAModelRaw(lora_layers)
    orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
    with ModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
    with LoraPatcher.apply_lora_patches(model=model, patches=[(lora, 0.5)], prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"