Compare commits

...

15 Commits

Author SHA1 Message Date
Ryan Dick
3ed6e65a6e Enable LoRAPatcher.apply_smart_lora_patches(...) throughout the stack. 2024-12-12 22:41:50 +00:00
Ryan Dick
52c9646f84 (minor) Rename num_layers -> num_loras in unit tests. 2024-12-12 22:41:50 +00:00
Ryan Dick
7662f0522b Add test_apply_smart_lora_patches_to_partially_loaded_model(...). 2024-12-12 22:41:50 +00:00
Ryan Dick
e50fe69839 Add LoRAPatcher.smart_apply_lora_patches() 2024-12-12 22:41:50 +00:00
Ryan Dick
5a9f884620 Refactor LoRAPatcher slightly in preparation for a 'smart' patcher. 2024-12-12 22:41:46 +00:00
Ryan Dick
edc72d1739 Fix LoRAPatcher.apply_lora_wrapper_patches(...) 2024-12-12 22:33:07 +00:00
Ryan Dick
23f521dc7c Finish consolidating LoRA sidecar wrapper implementations. 2024-12-12 22:33:07 +00:00
Ryan Dick
3d6b93efdd Begin to consolidate the LoRA sidecar and LoRA layer wrapper implementations. 2024-12-12 22:33:07 +00:00
Ryan Dick
3f28d3afad Fix bias handling in LoRAModuleWrapper and add unit test that checks that all LoRA patching methods produce the same outputs. 2024-12-12 22:33:07 +00:00
Ryan Dick
9353bfbdd6 Add LoRA wrapper patching to LoRAPatcher. 2024-12-12 22:33:07 +00:00
Ryan Dick
93f2bc6118 Add LoRA wrapper layer. 2024-12-12 22:33:07 +00:00
Ryan Dick
9019026d6d Fixes to get FLUX Control LoRA working. 2024-12-12 00:19:39 +00:00
Brandon Rising
c195b326ec Lots of updates centered around using the lora patcher rather than changing the modules in the transformer model 2024-12-11 14:14:50 -05:00
Brandon Rising
2f460d2a45 Support bnb quantized nf4 flux models, Use controlnet vae, only support 1 structural lora per transformer. various other refractors and bugfixes 2024-12-10 03:26:29 -05:00
Brandon Rising
4473cba512 Initial setup for flux tools control loras 2024-12-09 16:01:29 -05:00
53 changed files with 2580 additions and 314 deletions

View File

@@ -82,10 +82,11 @@ class CompelInvocation(BaseInvocation):
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
LoRAPatcher.apply_lora_patches(
LoRAPatcher.apply_smart_lora_patches(
model=text_encoder,
patches=_lora_loader(),
prefix="lora_te_",
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -179,10 +180,11 @@ class SDXLPromptInvocationBase:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
LoRAPatcher.apply_lora_patches(
LoRAPatcher.apply_smart_lora_patches(
text_encoder,
patches=_lora_loader(),
prefix=lora_prefix,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.

View File

@@ -1003,10 +1003,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
LoRAPatcher.apply_lora_patches(
LoRAPatcher.apply_smart_lora_patches(
model=unet,
patches=_lora_loader(),
prefix="lora_unet_",
dtype=unet.dtype,
cached_weights=cached_weights,
),
):

View File

@@ -56,6 +56,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
CLIPLEmbedModel = "CLIPLEmbedModelField"
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
StructuralLoRAModel = "StructuralLoRAModelField"
# endregion
# region Misc Field Types
@@ -143,6 +144,7 @@ class FieldDescriptions:
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
structural_lora_model = "Structural LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"

View File

@@ -1,5 +1,5 @@
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple
from typing import Callable, Iterator, Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
@@ -8,6 +8,8 @@ import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
DenoiseMaskField,
@@ -22,7 +24,7 @@ from invokeai.app.invocations.fields import (
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.model import TransformerField, VAEField, StructuralLoRAField, LoRAField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@@ -43,6 +45,8 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
@@ -284,6 +288,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
device=x.device,
)
img_cond = None
if struct_lora := self.transformer.structural_lora:
# What should we do when we have multiple of these?
if not self.controlnet_vae:
raise ValueError("controlnet_vae must be set when using a strutural lora")
ae_info = context.models.load(self.controlnet_vae.vae)
img = context.images.get_pil(struct_lora.img.image_name)
with ae_info as ae:
assert isinstance(ae, AutoEncoder)
img_cond = prepare_control(self.height, self.width, self.seed, ae, img)
# Load the transformer model.
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
@@ -296,10 +310,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
if config.format in [ModelFormat.Checkpoint]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
LoRAPatcher.apply_smart_lora_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
)
)
@@ -311,7 +326,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
LoRAPatcher.apply_lora_sidecar_patches(
LoRAPatcher.apply_lora_wrapper_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
@@ -345,6 +360,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
controlnet_extensions=controlnet_extensions,
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
img_cond=img_cond
)
x = unpack(x.float(), self.height, self.width)
@@ -682,7 +698,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return pos_ip_adapter_extensions, neg_ip_adapter_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
loras: list[Union[LoRAField, StructuralLoRAField]] = [*self.transformer.loras]
if self.transformer.structural_lora:
loras.append(self.transformer.structural_lora)
for lora in loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)

View File

@@ -81,8 +81,8 @@ class FluxModelLoaderInvocation(BaseInvocation):
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
transformer=TransformerField(transformer=transformer, loras=[], structural_loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], structural_loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],

View File

@@ -0,0 +1,70 @@
from typing import Optional, Literal
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
from invokeai.app.invocations.model import VAEField, StructuralLoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_structural_lora_loader_output")
class FluxStructuralLoRALoaderOutput(BaseInvocationOutput):
"""Flux Structural LoRA Loader Output"""
transformer: Optional[TransformerField] = OutputField(
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
@invocation(
"flux_structural_lora_loader",
title="Flux Structural LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxStructuralLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.structural_lora_model, title="Structural LoRA", ui_type=UIType.StructuralLoRAModel
)
transformer: TransformerField | None = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
image: ImageField = InputField(
description="The image to encode.",
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
def invoke(self, context: InvocationContext) -> FluxStructuralLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
# Check for existing LoRAs with the same key.
if self.transformer and self.transformer.structural_lora and self.transformer.structural_lora.lora.key == lora_key:
raise ValueError(f'Structural LoRA "{lora_key}" already applied to transformer.')
output = FluxStructuralLoRALoaderOutput()
# Attach LoRA layers to the models.
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.structural_lora = StructuralLoRAField(
lora=self.lora,
img=self.image,
weight=self.weight,
)
return output

View File

@@ -22,6 +22,7 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
@@ -111,10 +112,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
LoRAPatcher.apply_smart_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
)
)

View File

@@ -1,5 +1,5 @@
import copy
from typing import List, Optional
from typing import List, Optional, Literal
from pydantic import BaseModel, Field
@@ -10,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import (
@@ -65,11 +65,6 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
@@ -79,6 +74,13 @@ class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
class StructuralLoRAField(LoRAField):
img: ImageField = Field(description="Image to use in structural conditioning")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
structural_lora: Optional[StructuralLoRAField] = Field(description="Structural LoRAs to apply on model loading", default=None)
@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):

View File

@@ -21,6 +21,7 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
SD3_T5_MAX_SEQ_LEN = 256
@@ -150,10 +151,11 @@ class Sd3TextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
LoRAPatcher.apply_smart_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
)
)

View File

@@ -207,7 +207,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
with (
ExitStack() as exit_stack,
unet_info as unet,
LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
LoRAPatcher.apply_smart_lora_patches(
model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)

View File

@@ -30,6 +30,8 @@ def denoise(
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
# extra img tokens
img_cond: torch.Tensor | None = None,
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -69,9 +71,9 @@ def denoise(
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
pred = model(
img=img,
img=pred_img,
img_ids=img_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,

View File

@@ -0,0 +1,27 @@
import torch
import numpy as np
from PIL import Image
from einops import rearrange
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
def prepare_control(
height: int,
width: int,
seed: int,
ae: AutoEncoder,
cond_image: Image.Image,
) -> torch.Tensor:
# load and encode the conditioning image
img_cond = cond_image.convert("RGB")
img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS)
img_cond = np.array(img_cond)
img_cond = torch.from_numpy(img_cond).float()
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
ae_dtype = next(iter(ae.parameters())).dtype
ae_device = next(iter(ae.parameters())).device
img_cond = img_cond.to(device=ae_device, dtype=ae_dtype)
generator = torch.Generator(device=ae_device).manual_seed(seed)
img_cond = ae.encode(img_cond, sample=True, generator=generator)
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return img_cond

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from typing import Optional
from invokeai.backend.flux.custom_block_processor import (
CustomDoubleStreamBlockProcessor,
@@ -35,6 +36,7 @@ class FluxParams:
theta: int
qkv_bias: bool
guidance_embed: bool
out_channels: Optional[int] = None
class Flux(nn.Module):
@@ -47,7 +49,7 @@ class Flux(nn.Module):
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
self.out_channels = params.out_channels or self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads

View File

@@ -0,0 +1,50 @@
import os
import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from PIL import Image
from safetensors.torch import load_file as load_sft
from torch import nn
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
class DepthImageEncoder:
depth_model_name = "LiheYoung/depth-anything-large-hf"
def __init__(self, device):
self.device = device
self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
def __call__(self, img: torch.Tensor) -> torch.Tensor:
hw = img.shape[-2:]
img = torch.clamp(img, -1.0, 1.0)
img_byte = ((img + 1.0) * 127.5).byte()
img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
depth = self.depth_model(img.to(self.device)).predicted_depth
depth = repeat(depth, "b h w -> b 3 h w")
depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)
depth = depth / 127.5 - 1.0
return depth
class CannyImageEncoder:
def __init__(
self,
device,
min_t: int = 50,
max_t: int = 200,
):
self.device = device
self.min_t = min_t
self.max_t = max_t
def __call__(self, img: torch.Tensor) -> torch.Tensor:
assert img.shape[0] == 1, "Only batch size 1 is supported"
img = rearrange(img[0], "c h w -> h w c")
img = torch.clamp(img, -1.0, 1.0)
img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)
# Apply Canny edge detection
canny = cv2.Canny(img_np, self.min_t, self.max_t)
# Convert back to torch tensor and reshape
canny = torch.from_numpy(canny).float() / 127.5 - 1.0
canny = rearrange(canny, "h w -> 1 1 h w")
canny = repeat(canny, "b 1 ... -> b 3 ...")
return canny.to(self.device)

View File

@@ -0,0 +1,65 @@
import re
import torch
from typing import Any, Dict
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
# Example keys:
# guidance_in.in_layer.lora_B.bias
# single_blocks.0.linear1.lora_A.weight
# double_blocks.0.img_attn.norm.key_norm.scale
FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)"
def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the FLUX Control LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
return all(
re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k)
for k in state_dict.keys()
)
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
# converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
key_props = key.split(".")
# Got it loading using lora_down and lora_up but it didn't seem to match this lora's structure
# Leaving this in since it doesn't hurt anything and may be better
layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
layer_name = ".".join(key_props[:layer_prop_size])
param_name = ".".join(key_props[layer_prop_size:])
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Create LoRA layers.
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
# Convert to a full layer diff
prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
if all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
layers[prefixed_key] = LoRALayer(
layer_state_dict["lora_B.weight"],
None,
layer_state_dict["lora_A.weight"],
None,
layer_state_dict["lora_B.bias"]
)
elif "scale" in layer_state_dict:
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
else:
raise AssertionError(f"{layer_key} not expected")
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)

View File

@@ -7,5 +7,6 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer]

View File

@@ -0,0 +1,34 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class ReshapeWeightLayer(LoRALayerBase):
# TODO: Just everything in this class
def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
super().__init__(alpha=None, bias=bias)
self.weight = torch.nn.Parameter(weight) if weight is not None else None
self.bias = torch.nn.Parameter(bias) if bias is not None else None
self.manual_scale = scale
def scale(self):
return self.manual_scale.float() if self.manual_scale is not None else super().scale()
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return orig_weight
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
if self.weight is not None:
self.weight = self.weight.to(device=device, dtype=dtype)
if self.manual_scale is not None:
self.manual_scale = self.manual_scale.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return super().calc_size() + calc_tensor_size(self.manual_scale)

View File

@@ -0,0 +1,29 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class SetParameterLayer(LoRALayerBase):
def __init__(self, param_name: str, weight: torch.Tensor):
super().__init__(None, None)
self.weight = weight
self.param_name = param_name
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight - orig_weight
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
return {self.param_name: self.get_weight(orig_module.get_parameter(self.param_name))}
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return super().calc_size() + calc_tensor_size(self.weight)

View File

@@ -9,6 +9,7 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:

View File

@@ -0,0 +1,133 @@
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
class LoRASidecarWrapper(torch.nn.Module):
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[AnyLoRALayer], lora_weights: list[float]):
super().__init__()
self._orig_module = orig_module
self._lora_layers = lora_layers
self._lora_weights = lora_weights
@property
def orig_module(self) -> torch.nn.Module:
return self._orig_module
def add_lora_layer(self, lora_layer: AnyLoRALayer, lora_weight: float):
self._lora_layers.append(lora_layer)
self._lora_weights.append(lora_weight)
@torch.no_grad()
def _get_lora_patched_parameters(
self, orig_params: dict[str, torch.Tensor], lora_layers: list[AnyLoRALayer], lora_weights: list[float]
) -> dict[str, torch.Tensor]:
params: dict[str, torch.Tensor] = {}
for lora_layer, lora_weight in zip(lora_layers, lora_weights, strict=True):
layer_params = lora_layer.get_parameters(self._orig_module)
for param_name, param_weight in layer_params.items():
if orig_params[param_name].shape != param_weight.shape:
param_weight = param_weight.reshape(orig_params[param_name].shape)
if param_name not in params:
params[param_name] = param_weight * (lora_layer.scale() * lora_weight)
else:
params[param_name] += param_weight * (lora_layer.scale() * lora_weight)
return params
class LoRALinearWrapper(LoRASidecarWrapper):
def _lora_linear_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a Linear LoRALayer."""
x = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x = torch.nn.functional.linear(x, lora_layer.mid)
x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
x *= lora_weight * lora_layer.scale()
return x
def _concatenated_lora_forward(
self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
x_chunks: list[torch.Tensor] = []
for lora_layer in concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= lora_weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x
def forward(self, input: torch.Tensor) -> torch.Tensor:
# Split the LoRA layers into those that have optimized implementations and those that don't.
optimized_layer_types = (LoRALayer, ConcatenatedLoRALayer)
optimized_layers = [
(layer, weight)
for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
if isinstance(layer, optimized_layer_types)
]
non_optimized_layers = [
(layer, weight)
for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
if not isinstance(layer, optimized_layer_types)
]
# First, calculate the residual for LoRA layers for which there is an optimized implementation.
residual = None
for lora_layer, lora_weight in optimized_layers:
if isinstance(lora_layer, LoRALayer):
added_residual = self._lora_linear_forward(input, lora_layer, lora_weight)
elif isinstance(lora_layer, ConcatenatedLoRALayer):
added_residual = self._concatenated_lora_forward(input, lora_layer, lora_weight)
else:
raise ValueError(f"Unsupported LoRA layer type: {type(lora_layer)}")
if residual is None:
residual = added_residual
else:
residual += added_residual
# Next, calculate the residuals for the LoRA layers for which there is no optimized implementation.
if non_optimized_layers:
unoptimized_layers, unoptimized_weights = zip(*non_optimized_layers, strict=True)
params = self._get_lora_patched_parameters(
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
lora_layers=unoptimized_layers,
lora_weights=unoptimized_weights,
)
added_residual = torch.nn.functional.linear(input, params["weight"], params.get("bias", None))
if residual is None:
residual = added_residual
else:
residual += added_residual
return self.orig_module(input) + residual
class LoRAConv1dWrapper(LoRASidecarWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
params = self._get_lora_patched_parameters(
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
lora_layers=self._lora_layers,
lora_weights=self._lora_weights,
)
return self.orig_module(input) + torch.nn.functional.conv1d(input, params["weight"], params.get("bias", None))
class LoRAConv2dWrapper(LoRASidecarWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
params = self._get_lora_patched_parameters(
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
lora_layers=self._lora_layers,
lora_weights=self._lora_weights,
)
return self.orig_module(input) + torch.nn.functional.conv2d(input, params["weight"], params.get("bias", None))

View File

@@ -4,19 +4,126 @@ from typing import Dict, Iterable, Optional, Tuple
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
ConcatenatedLoRALinearSidecarLayer,
from invokeai.backend.lora.lora_layer_wrappers import (
LoRAConv1dWrapper,
LoRAConv2dWrapper,
LoRALinearWrapper,
LoRASidecarWrapper,
)
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
class LoRAPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
def apply_smart_lora_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
dtype: torch.dtype,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
"""Apply 'smart' LoRA patching that chooses whether to use direct patching or a sidecar wrapper for each module."""
# original_weights are stored for unpatching layers that are directly patched.
original_weights = OriginalWeightsStorage(cached_weights)
# original_modules are stored for unpatching layers that are wrapped in a LoRASidecarWrapper.
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LoRAPatcher._apply_smart_lora_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_weights=original_weights,
original_modules=original_modules,
dtype=dtype,
)
yield
finally:
# Restore directly patched layers.
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
# Restore LoRASidecarWrapper modules.
# Note: This logic assumes no nested modules in original_modules.
for module_key, orig_module in original_modules.items():
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
parent_module = model.get_submodule(module_parent_key)
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
@staticmethod
@torch.no_grad()
def _apply_smart_lora_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
patching or a sidecar wrapper for each module.
"""
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LoRAPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# Decide whether to use direct patching or a sidecar wrapper.
# Direct patching is preferred, because it results in better runtime speed.
# Reasons to use sidecar patching:
# - The module is already wrapped in a LoRASidecarWrapper.
# - The module is quantized.
# - The module is on the CPU (and we don't want to store a second full copy of the original weights on the
# CPU, since this would double the RAM usage)
# NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller
# and that the caller will use the 'apply_lora_wrapper_patches' method if the layer is quantized.
# TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows
# forcing full patching even on the CPU?
if isinstance(module, LoRASidecarWrapper) or LoRAPatcher._is_any_part_of_layer_on_cpu(module):
LoRAPatcher._apply_lora_layer_wrapper_patch(
model=model,
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_modules=original_modules,
dtype=dtype,
)
else:
LoRAPatcher._apply_lora_layer_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_weights=original_weights,
)
@staticmethod
def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
return any(p.device.type == "cpu" for p in layer.parameters())
@staticmethod
@torch.no_grad()
@contextmanager
@@ -40,7 +147,7 @@ class LoRAPatcher:
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
LoRAPatcher.apply_lora_patch(
LoRAPatcher._apply_lora_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -52,11 +159,12 @@ class LoRAPatcher:
yield
finally:
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
cur_param = model.get_parameter(param_key)
cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)
@staticmethod
@torch.no_grad()
def apply_lora_patch(
def _apply_lora_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
@@ -91,48 +199,84 @@ class LoRAPatcher:
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
LoRAPatcher._apply_lora_layer_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_weights=original_weights,
)
layer_scale = layer.scale()
@staticmethod
@torch.no_grad()
def _apply_lora_layer_patch(
module_to_patch: torch.nn.Module,
module_to_patch_key: str,
patch: AnyLoRALayer,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
first_param = next(module_to_patch.parameters())
device = first_param.device
dtype = first_param.dtype
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
layer_scale = patch.scale()
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
patch.to(device=device)
patch.to(dtype=torch.float32)
# Save original weight
original_weights.save(param_key, module_param)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in patch.get_parameters(module_to_patch).items():
param_key = module_to_patch_key + "." + param_name
module_param = module_to_patch.get_parameter(param_name)
if module_param.shape != lora_param_weight.shape:
# Save original weight
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
if module_param.nelement() == lora_param_weight.nelement():
lora_param_weight = lora_param_weight.reshape(module_param.shape)
else:
# This condition was added to handle layers in FLUX control LoRAs.
# TODO(ryand): Move the weight update into the LoRA layer so that the LoRAPatcher doesn't need
# to worry about this?
expanded_weight = torch.zeros_like(
lora_param_weight, dtype=module_param.dtype, device=module_param.device
)
slices = tuple(slice(0, dim) for dim in module_param.shape)
expanded_weight[slices] = module_param
setattr(
module,
param_name,
torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
)
module_param = expanded_weight
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)
patch.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_sidecar_patches(
def apply_lora_wrapper_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
dtype: torch.dtype,
):
"""Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any
quantization format.
"""Apply one or more LoRA wrapper patches to a model within a context manager. Wrapper patches incur some
runtime overhead compared to normal LoRA patching, but they enable:
- LoRA layers to be applied to quantized models
- LoRA layers to be applied to CPU layers without needing to store a full copy of the original weights (i.e.
avoid doubling the memory requirements).
Args:
model (torch.nn.Module): The model to patch.
@@ -140,14 +284,11 @@ class LoRAPatcher:
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
different from their compute dtype.
"""
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LoRAPatcher._apply_lora_sidecar_patch(
LoRAPatcher._apply_lora_wrapper_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -165,7 +306,7 @@ class LoRAPatcher:
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
@staticmethod
def _apply_lora_sidecar_patch(
def _apply_lora_wrapper_patch(
model: torch.nn.Module,
patch: LoRAModelRaw,
patch_weight: float,
@@ -173,7 +314,7 @@ class LoRAPatcher:
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA sidecar patch to a model."""
"""Apply a single LoRA wrapper patch to a model."""
if patch_weight == 0:
return
@@ -194,28 +335,47 @@ class LoRAPatcher:
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# Initialize the LoRA sidecar layer.
lora_sidecar_layer = LoRAPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)
LoRAPatcher._apply_lora_layer_wrapper_patch(
model=model,
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_modules=original_modules,
dtype=dtype,
)
# Replace the original module with a LoRASidecarModule if it has not already been done.
if module_key in original_modules:
# The module has already been patched with a LoRASidecarModule. Append to it.
assert isinstance(module, LoRASidecarModule)
lora_sidecar_module = module
else:
# The module has not yet been patched with a LoRASidecarModule. Create one.
lora_sidecar_module = LoRASidecarModule(module, [])
original_modules[module_key] = module
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
module_parent = model.get_submodule(module_parent_key)
LoRAPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
@staticmethod
@torch.no_grad()
def _apply_lora_layer_wrapper_patch(
model: torch.nn.Module,
module_to_patch: torch.nn.Module,
module_to_patch_key: str,
patch: AnyLoRALayer,
patch_weight: float,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA wrapper patch to a model."""
# Move the LoRA sidecar layer to the same device/dtype as the orig module.
# TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
lora_sidecar_layer.to(device=lora_sidecar_module.orig_module.weight.device, dtype=dtype)
# Replace the original module with a LoRASidecarWrapper if it has not already been done.
if not isinstance(module_to_patch, LoRASidecarWrapper):
lora_wrapper_layer = LoRAPatcher._initialize_lora_wrapper_layer(module_to_patch)
original_modules[module_to_patch_key] = module_to_patch
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_to_patch_key)
module_parent = model.get_submodule(module_parent_key)
LoRAPatcher._set_submodule(module_parent, module_name, lora_wrapper_layer)
orig_module = module_to_patch
else:
assert module_to_patch_key in original_modules
lora_wrapper_layer = module_to_patch
orig_module = module_to_patch.orig_module
# Add the LoRA sidecar layer to the LoRASidecarModule.
lora_sidecar_module.add_lora_layer(lora_sidecar_layer)
# Move the LoRA layer to the same device/dtype as the orig module.
patch.to(device=orig_module.weight.device, dtype=dtype)
# Add the LoRA wrapper layer to the LoRASidecarWrapper.
lora_wrapper_layer.add_lora_layer(patch, patch_weight)
@staticmethod
def _split_parent_key(module_key: str) -> tuple[str, str]:
@@ -236,17 +396,13 @@ class LoRAPatcher:
raise ValueError(f"Invalid module key: {module_key}")
@staticmethod
def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
# TODO(ryand): Add support for more original layer types and LoRA layer types.
if isinstance(orig_layer, torch.nn.Linear) or (
isinstance(orig_layer, LoRASidecarModule) and isinstance(orig_layer.orig_module, torch.nn.Linear)
):
if isinstance(lora_layer, LoRALayer):
return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
elif isinstance(lora_layer, ConcatenatedLoRALayer):
return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
else:
raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
def _initialize_lora_wrapper_layer(orig_layer: torch.nn.Module):
if isinstance(orig_layer, torch.nn.Linear):
return LoRALinearWrapper(orig_layer, [], [])
elif isinstance(orig_layer, torch.nn.Conv1d):
return LoRAConv1dWrapper(orig_layer, [], [])
elif isinstance(orig_layer, torch.nn.Conv2d):
return LoRAConv2dWrapper(orig_layer, [], [])
else:
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")

View File

@@ -1,34 +0,0 @@
import torch
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
class ConcatenatedLoRALinearSidecarLayer(torch.nn.Module):
def __init__(
self,
concatenated_lora_layer: ConcatenatedLoRALayer,
weight: float,
):
super().__init__()
self._concatenated_lora_layer = concatenated_lora_layer
self._weight = weight
def forward(self, input: torch.Tensor) -> torch.Tensor:
x_chunks: list[torch.Tensor] = []
for lora_layer in self._concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= self._weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert self._concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self._concatenated_lora_layer.to(device=device, dtype=dtype)
return self

View File

@@ -1,27 +0,0 @@
import torch
from invokeai.backend.lora.layers.lora_layer import LoRALayer
class LoRALinearSidecarLayer(torch.nn.Module):
def __init__(
self,
lora_layer: LoRALayer,
weight: float,
):
super().__init__()
self._lora_layer = lora_layer
self._weight = weight
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.linear(x, self._lora_layer.down)
if self._lora_layer.mid is not None:
x = torch.nn.functional.linear(x, self._lora_layer.mid)
x = torch.nn.functional.linear(x, self._lora_layer.up, bias=self._lora_layer.bias)
x *= self._weight * self._lora_layer.scale()
return x
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self._lora_layer.to(device=device, dtype=dtype)
return self

View File

@@ -1,24 +0,0 @@
import torch
class LoRASidecarModule(torch.nn.Module):
"""A LoRA sidecar module that wraps an original module and adds LoRA layers to it."""
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
super().__init__()
self.orig_module = orig_module
self._lora_layers = lora_layers
def add_lora_layer(self, lora_layer: torch.nn.Module):
self._lora_layers.append(lora_layer)
def forward(self, input: torch.Tensor) -> torch.Tensor:
x = self.orig_module(input)
for lora_layer in self._lora_layers:
x += lora_layer(input)
return x
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self._orig_module.to(device=device, dtype=dtype)
for lora_layer in self._lora_layers:
lora_layer.to(device=device, dtype=dtype)

View File

@@ -67,6 +67,7 @@ class ModelType(str, Enum):
Main = "main"
VAE = "vae"
LoRA = "lora"
StructuralLoRa = "structural_lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
@@ -273,6 +274,18 @@ class LoRALyCORISConfig(LoRAConfigBase):
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class StructuralLoRALyCORISConfig(ModelConfigBase):
"""Model config for Structural LoRA/Lycoris models."""
type: Literal[ModelType.StructuralLoRa] = ModelType.StructuralLoRa
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.StructuralLoRa.value}.{ModelFormat.LyCORIS.value}")
class LoRADiffusersConfig(LoRAConfigBase):
"""Model config for LoRA/Diffusers models."""
@@ -535,6 +548,7 @@ AnyModelConfig = Annotated[
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[StructuralLoRALyCORISConfig, StructuralLoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],

View File

@@ -13,8 +13,9 @@ from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils impo
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
lora_model_from_flux_kohya_state_dict,
is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control, lora_model_from_flux_control_state_dict
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.model_manager import (
@@ -32,6 +33,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.StructuralLoRa, format=ModelFormat.LyCORIS)
class LoRALoader(ModelLoader):
"""Class to load LoRA models."""
@@ -75,7 +77,10 @@ class LoRALoader(ModelLoader):
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
elif config.format == ModelFormat.LyCORIS:
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
elif is_state_dict_likely_flux_control(state_dict=state_dict):
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:

View File

@@ -18,6 +18,7 @@ from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlab
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
@@ -258,6 +259,18 @@ class ModelProbe(object):
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
ckpt = ckpt.get("state_dict", ckpt)
if isinstance(ckpt, dict) and "img_in.lora_A.weight" in ckpt and "img_in.lora_B.weight" in ckpt:
tensor_a, tensor_b = ckpt["img_in.lora_A.weight"], ckpt["img_in.lora_B.weight"]
if (
tensor_a is not None
and isinstance(tensor_a, torch.Tensor)
and tensor_a.shape[1] == 128
and tensor_b is not None
and isinstance(tensor_b, torch.Tensor)
and tensor_b.shape[0] == 3072
):
return ModelType.StructuralLoRa
for key in [str(k) for k in ckpt.keys()]:
if key.startswith(
(
@@ -624,8 +637,10 @@ class LoRACheckpointProbe(CheckpointProbeBase):
return ModelFormat.LyCORIS
def get_base_type(self) -> BaseModelType:
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
self.checkpoint
if (
is_state_dict_likely_in_flux_kohya_format(self.checkpoint)
or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint)
or is_state_dict_likely_flux_control(self.checkpoint)
):
return BaseModelType.Flux
@@ -1046,6 +1061,7 @@ ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelI
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.StructuralLoRa, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)

View File

@@ -809,6 +809,7 @@
"starterBundleHelpText": "Easily install all models needed to get started with a base model, including a main model, controlnets, IP adapters, and more. Selecting a bundle will skip any models that you already have installed.",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"structuralLora": "Structural LoRA",
"syncModels": "Sync Models",
"textualInversions": "Textual Inversions",
"triggerPhrases": "Trigger Phrases",

View File

@@ -24,6 +24,7 @@ import type {
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterStructuralLoRAModel,
ParameterT5EncoderModel,
ParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
@@ -75,6 +76,7 @@ export type ParamsState = {
clipEmbedModel: ParameterCLIPEmbedModel | null;
clipLEmbedModel: ParameterCLIPLEmbedModel | null;
clipGEmbedModel: ParameterCLIPGEmbedModel | null;
structuralLora: ParameterStructuralLoRAModel | null;
};
const initialState: ParamsState = {
@@ -121,6 +123,7 @@ const initialState: ParamsState = {
clipEmbedModel: null,
clipLEmbedModel: null,
clipGEmbedModel: null,
structuralLora: null,
};
export const paramsSlice = createSlice({
@@ -195,6 +198,9 @@ export const paramsSlice = createSlice({
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
state.t5EncoderModel = action.payload;
},
structuralLoRAModelSelected: (state, action: PayloadAction<ParameterStructuralLoRAModel | null>) => {
state.structuralLora = action.payload;
},
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
state.clipEmbedModel = action.payload;
},

View File

@@ -46,6 +46,7 @@ import type {
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterStructuralLoRAModel,
ParameterVAEModel,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
@@ -80,6 +81,7 @@ import {
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isStructuralLoRAModelConfig,
isT2IAdapterModelConfig,
isVAEModelConfig,
} from 'services/api/types';
@@ -226,6 +228,14 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
return modelIdentifier;
};
const parseStructuralLoRAModel: MetadataParseFunc<ParameterStructuralLoRAModel> = async (metadata) => {
const slora = await getProperty(metadata, 'structural_lora', undefined);
const key = await getModelKey(slora, 'structural_lora');
const sloraModelConfig = await fetchModelConfigWithTypeGuard(key, isStructuralLoRAModelConfig);
const modelIdentifier = zModelIdentifierField.parse(sloraModelConfig);
return modelIdentifier;
};
const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
// Previously, the LoRA model identifier parts were stored in the LoRA metadata: `{key: ..., weight: 0.75}`
const modelV1 = await getProperty(metadataItem, 'lora', undefined);
@@ -671,6 +681,7 @@ export const parsers = {
mainModel: parseMainModel,
refinerModel: parseRefinerModel,
vaeModel: parseVAEModel,
structuralLora: parseStructuralLoRAModel,
lora: parseLoRA,
loras: parseAllLoRAs,
controlNet: parseControlNet,

View File

@@ -18,6 +18,7 @@ import {
useMainModels,
useRefinerModels,
useSpandrelImageToImageModels,
useStructuralLoRAModel,
useT2IAdapterModels,
useT5EncoderModels,
useVAEModels,
@@ -92,6 +93,12 @@ const ModelList = () => {
[t5EncoderModels, searchTerm, filteredModelType]
);
const [structuralLoRAModels, { isLoading: isLoadingStructuralLoRAModels }] = useStructuralLoRAModel();
const filteredStructuralLoRAModels = useMemo(
() => modelsFilter(structuralLoRAModels, searchTerm, filteredModelType),
[structuralLoRAModels, searchTerm, filteredModelType]
);
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels({ excludeSubmodels: true });
const filteredClipEmbedModels = useMemo(
() => modelsFilter(clipEmbedModels, searchTerm, filteredModelType),
@@ -118,7 +125,8 @@ const ModelList = () => {
filteredVAEModels.length +
filteredSpandrelImageToImageModels.length +
t5EncoderModels.length +
clipEmbedModels.length
clipEmbedModels.length +
structuralLoRAModels.length
);
}, [
filteredControlNetModels.length,
@@ -133,6 +141,7 @@ const ModelList = () => {
filteredSpandrelImageToImageModels.length,
t5EncoderModels.length,
clipEmbedModels.length,
structuralLoRAModels.length,
]);
return (
@@ -195,6 +204,15 @@ const ModelList = () => {
{!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && (
<ModelListWrapper title={t('modelManager.t5Encoder')} modelList={filteredT5EncoderModels} key="t5-encoder" />
)}
{/* Structural Lora List */}
{isLoadingStructuralLoRAModels && <FetchingModelsLoader loadingMessage="Loading Structural Loras..." />}
{!isLoadingStructuralLoRAModels && filteredStructuralLoRAModels.length > 0 && (
<ModelListWrapper
title={t('modelManager.structuralLora')}
modelList={filteredStructuralLoRAModels}
key="structural-lora"
/>
)}
{/* Clip Embed List */}
{isLoadingClipEmbedModels && <FetchingModelsLoader loadingMessage="Loading Clip Embed Models..." />}
{!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && (

View File

@@ -24,6 +24,7 @@ export const ModelTypeFilter = memo(() => {
ip_adapter: t('common.ipAdapter'),
clip_vision: 'CLIP Vision',
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
structural_lora: t('modelManager.structuralLora'),
}),
[t]
);

View File

@@ -51,6 +51,8 @@ import {
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isStructuralLoRAModelFieldInputInstance,
isStructuralLoRAModelFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate,
isT5EncoderModelFieldInputInstance,
@@ -81,6 +83,7 @@ import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComp
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import StructuralLoRAModelFieldInputComponent from './inputs/StructuralLoraModelFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
@@ -156,6 +159,15 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (
isStructuralLoRAModelFieldInputInstance(fieldInstance) &&
isStructuralLoRAModelFieldInputTemplate(fieldTemplate)
) {
return (
<StructuralLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />
);
}
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@@ -0,0 +1,65 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldStructuralLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
StructuralLoRAModelFieldInputInstance,
StructuralLoRAModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useStructuralLoRAModel } from 'services/api/hooks/modelsByType';
import { isStructuralLoRAModelConfig, type StructuralLoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<StructuralLoRAModelFieldInputInstance, StructuralLoRAModelFieldInputTemplate>;
const StructuralLoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useStructuralLoRAModel();
const _onChange = useCallback(
(value: StructuralLoRAModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldStructuralLoRAModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isStructuralLoRAModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};
export default memo(StructuralLoRAModelFieldInputComponent);

View File

@@ -28,6 +28,7 @@ import type {
SpandrelImageToImageModelFieldValue,
StatefulFieldValue,
StringFieldValue,
StructuralLoRAModelFieldValue,
T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
VAEModelFieldValue,
@@ -55,6 +56,7 @@ import {
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue,
zStringFieldValue,
zStructuralLoRAModelFieldValue,
zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
zVAEModelFieldValue,
@@ -369,6 +371,9 @@ export const nodesSlice = createSlice({
fieldCLIPGEmbedValueChanged: (state, action: FieldValueAction<CLIPGEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPGEmbedModelFieldValue);
},
fieldStructuralLoRAModelValueChanged: (state, action: FieldValueAction<StructuralLoRAModelFieldValue>) => {
fieldValueReducer(state, action, zStructuralLoRAModelFieldValue);
},
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
},
@@ -438,6 +443,7 @@ export const {
fieldCLIPEmbedValueChanged,
fieldCLIPLEmbedValueChanged,
fieldCLIPGEmbedValueChanged,
fieldStructuralLoRAModelValueChanged,
fieldFluxVAEModelValueChanged,
nodeEditorReset,
nodeIsIntermediateChanged,

View File

@@ -69,6 +69,7 @@ const zModelType = z.enum([
'main',
'vae',
'lora',
'structural_lora',
'controlnet',
't2i_adapter',
'ip_adapter',

View File

@@ -178,6 +178,10 @@ const zCLIPGEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPGEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zStructuralLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('StructuralLoRAModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxVAEModelField'),
originalType: zStatelessFieldType.optional(),
@@ -210,6 +214,7 @@ const zStatefulFieldType = z.union([
zCLIPEmbedModelFieldType,
zCLIPLEmbedModelFieldType,
zCLIPGEmbedModelFieldType,
zStructuralLoRAModelFieldType,
zFluxVAEModelFieldType,
zColorFieldType,
zSchedulerFieldType,
@@ -864,6 +869,29 @@ export const isCLIPGEmbedModelFieldInputTemplate = (val: unknown): val is CLIPGE
// #endregion
// #region StructuralLoRAModelField
export const zStructuralLoRAModelFieldValue = zModelIdentifierField.optional();
const zStructuralLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zStructuralLoRAModelFieldValue,
});
const zStructuralLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zStructuralLoRAModelFieldType,
originalType: zFieldType.optional(),
default: zStructuralLoRAModelFieldValue,
});
export type StructuralLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
export type StructuralLoRAModelFieldInputInstance = z.infer<typeof zStructuralLoRAModelFieldInputInstance>;
export type StructuralLoRAModelFieldInputTemplate = z.infer<typeof zStructuralLoRAModelFieldInputTemplate>;
export const isStructuralLoRAModelFieldInputInstance = (val: unknown): val is StructuralLoRAModelFieldInputInstance =>
zStructuralLoRAModelFieldInputInstance.safeParse(val).success;
export const isStructuralLoRAModelFieldInputTemplate = (val: unknown): val is StructuralLoRAModelFieldInputTemplate =>
zStructuralLoRAModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional();
@@ -959,6 +987,7 @@ export const zStatefulFieldValue = z.union([
zCLIPEmbedModelFieldValue,
zCLIPLEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
zStructuralLoRAModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
]);
@@ -1030,6 +1059,7 @@ const zStatefulFieldInputTemplate = z.union([
zCLIPEmbedModelFieldInputTemplate,
zCLIPLEmbedModelFieldInputTemplate,
zCLIPGEmbedModelFieldInputTemplate,
zStructuralLoRAModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,

View File

@@ -28,6 +28,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
CLIPEmbedModelField: undefined,
CLIPLEmbedModelField: undefined,
CLIPGEmbedModelField: undefined,
StructuralLoRAModelField: undefined,
};
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {

View File

@@ -28,6 +28,7 @@ import type {
StatefulFieldType,
StatelessFieldInputTemplate,
StringFieldInputTemplate,
StructuralLoRAModelFieldInputTemplate,
T2IAdapterModelFieldInputTemplate,
T5EncoderModelFieldInputTemplate,
VAEModelFieldInputTemplate,
@@ -300,6 +301,20 @@ const buildCLIPGEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPGEmb
return template;
};
const buildStructuralLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<StructuralLoRAModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: StructuralLoRAModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -526,6 +541,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
CLIPLEmbedModelField: buildCLIPLEmbedModelFieldInputTemplate,
CLIPGEmbedModelField: buildCLIPGEmbedModelFieldInputTemplate,
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
StructuralLoRAModelField: buildStructuralLoRAModelFieldInputTemplate,
} as const;
export const buildFieldInputTemplate = (

View File

@@ -113,6 +113,11 @@ export const zParameterVAEModel = zModelIdentifierField;
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
// #endregion
// #region Structural Lora Model
export const zParameterStructuralLoRAModel = zModelIdentifierField;
export type ParameterStructuralLoRAModel = z.infer<typeof zParameterStructuralLoRAModel>;
// #endregion
// #region T5Encoder Model
export const zParameterT5EncoderModel = zModelIdentifierField;
export type ParameterT5EncoderModel = z.infer<typeof zParameterT5EncoderModel>;

View File

@@ -23,6 +23,7 @@ import {
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isStructuralLoRAModelConfig,
isT2IAdapterModelConfig,
isT5EncoderModelConfig,
isTIModelConfig,
@@ -58,6 +59,7 @@ export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useStructuralLoRAModel = buildModelsHook(isStructuralLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);

File diff suppressed because one or more lines are too long

View File

@@ -44,6 +44,7 @@ export type BaseModelType = S['BaseModelType'];
// Model Configs
export type StructuralLoRAModelConfig = S['StructuralLoRALyCORISConfig'];
// TODO(MM2): Can we make key required in the pydantic model?
export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
// TODO(MM2): Can we rename this from Vae -> VAE
@@ -63,6 +64,7 @@ export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| StructuralLoRAModelConfig
| LoRAModelConfig
| VAEModelConfig
| ControlNetModelConfig
@@ -114,6 +116,10 @@ export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelCo
return config.type === 'lora';
};
export const isStructuralLoRAModelConfig = (config: AnyModelConfig): config is StructuralLoRAModelConfig => {
return config.type === 'structural_lora';
};
export const isVAEModelConfig = (config: AnyModelConfig, excludeSubmodels?: boolean): config is VAEModelConfig => {
return config.type === 'vae' || (!excludeSubmodels && config.type === 'main' && checkSubmodels(['vae'], config));
};

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,70 @@
import pytest
import torch
from invokeai.backend.lora.conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
lora_model_from_flux_control_state_dict,
)
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from tests.backend.lora.conversions.lora_state_dicts.flux_control_lora_format import (
state_dict_keys as flux_control_lora_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict
@pytest.mark.parametrize("sd_keys", [flux_control_lora_state_dict_keys])
def test_is_state_dict_likely_in_flux_control_format_true(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_flux_control() can identify a state dict in the FLUX Control LoRA format."""
# Construct a state dict that is in the Diffusers FLUX LoRA format.
state_dict = keys_to_mock_state_dict(sd_keys)
assert is_state_dict_likely_flux_control(state_dict)
@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys])
def test_is_state_dict_likely_in_flux_control_format_false(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_flux_control() returns False for a state dict that is in the Diffusers
FLUX LoRA format.
"""
# Construct a state dict that is not in the FLUX Control LoRA format.
state_dict = keys_to_mock_state_dict(sd_keys)
assert not is_state_dict_likely_flux_control(state_dict)
@pytest.mark.parametrize("sd_keys", [flux_control_lora_state_dict_keys])
def test_lora_model_from_flux_control_state_dict(sd_keys: dict[str, list[int]]):
"""Test that lora_model_from_flux_control_state_dict() can load a state dict in the FLUX Control LoRA format."""
# Construct a state dict that is in the FLUX Control LoRA format.
state_dict = keys_to_mock_state_dict(sd_keys)
# Load the state dict into a LoRAModelRaw object.
model = lora_model_from_flux_control_state_dict(state_dict)
# Check that the model has the correct number of LoRA layers.
expected_lora_layers: set[str] = set()
for k in sd_keys:
k = k.replace("lora_A.weight", "")
k = k.replace("lora_B.weight", "")
k = k.replace("lora_B.bias", "")
k = k.replace(".scale", "")
expected_lora_layers.add(k)
# Drop the K/V/proj_mlp weights because these are all concatenated into a single layer in the BFL format (we keep
# the Q weights so that we count these layers once).
assert len(model.layers) == len(expected_lora_layers)
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
def test_lora_model_from_flux_control_state_dict_extra_keys_error():
"""Test that lora_model_from_flux_control_state_dict() raises an error if the input state_dict contains unexpected
keys that we don't handle.
"""
# Construct a state dict that is in the FLUX Control LoRA format.
state_dict = keys_to_mock_state_dict(flux_control_lora_state_dict_keys)
# Add an unexpected key.
state_dict["transformer.single_transformer_blocks.0.unexpected_key.lora_A.weight"] = torch.empty(1)
# Check that an error is raised.
with pytest.raises(AssertionError):
lora_model_from_flux_control_state_dict(state_dict)

View File

@@ -1,49 +0,0 @@
import copy
import torch
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
ConcatenatedLoRALinearSidecarLayer,
)
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
def test_concatenated_lora_linear_sidecar_layer():
"""Test that a ConcatenatedLoRALinearSidecarLayer is equivalent to patching a linear layer with the ConcatenatedLoRA
layer.
"""
# Create a linear layer.
in_features = 5
sub_layer_out_features = [5, 10, 15]
linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))
# Create a ConcatenatedLoRA layer.
rank = 4
sub_layers: list[LoRALayer] = []
for out_features in sub_layer_out_features:
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias = torch.randn(out_features)
sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
# Patch the ConcatenatedLoRA layer into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += (
concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
)
linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()
# Create a ConcatenatedLoRALinearSidecarLayer.
concatenated_lora_linear_sidecar_layer = ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer, weight=1.0)
linear_with_sidecar = LoRASidecarModule(linear, [concatenated_lora_linear_sidecar_layer])
# Run the ConcatenatedLoRA-patched linear layer and the ConcatenatedLoRALinearSidecarLayer and assert they are
# equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_sidecar = linear_with_sidecar(input)
assert torch.allclose(output_patched, output_sidecar, atol=1e-6)

View File

@@ -1,38 +0,0 @@
import copy
import torch
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
@torch.no_grad()
def test_lora_linear_sidecar_layer():
"""Test that a LoRALinearSidecarLayer is equivalent to patching a linear layer with the LoRA layer."""
# Create a linear layer.
in_features = 10
out_features = 20
linear = torch.nn.Linear(in_features, out_features)
# Create a LoRA layer.
rank = 4
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias = torch.randn(out_features)
lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
# Patch the LoRA layer into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
# Create a LoRALinearSidecarLayer.
lora_linear_sidecar_layer = LoRALinearSidecarLayer(lora_layer, weight=1.0)
linear_with_sidecar = LoRASidecarModule(linear, [lora_linear_sidecar_layer])
# Run the LoRA-patched linear layer and the LoRALinearSidecarLayer and assert they are equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_sidecar = linear_with_sidecar(input)
assert torch.allclose(output_patched, output_sidecar, atol=1e-6)

View File

@@ -0,0 +1,69 @@
import copy
import torch
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_layer_wrappers import LoRALinearWrapper
@torch.no_grad()
def test_lora_linear_wrapper():
# Create a linear layer.
in_features = 10
out_features = 20
linear = torch.nn.Linear(in_features, out_features)
# Create a LoRA layer.
rank = 4
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias = torch.randn(out_features)
lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
# Patch the LoRA layer into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
# Create a LoRALinearWrapper.
lora_wrapped = LoRALinearWrapper(linear, [lora_layer], [1.0])
# Run the LoRA-patched linear layer and the LoRALinearWrapper and assert they are equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_wrapped = lora_wrapped(input)
assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
def test_concatenated_lora_linear_wrapper():
# Create a linear layer.
in_features = 5
sub_layer_out_features = [5, 10, 15]
linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))
# Create a ConcatenatedLoRA layer.
rank = 4
sub_layers: list[LoRALayer] = []
for out_features in sub_layer_out_features:
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias = torch.randn(out_features)
sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
# Patch the ConcatenatedLoRA layer into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += (
concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
)
linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()
# Create a LoRALinearWrapper.
lora_wrapped = LoRALinearWrapper(linear, [concatenated_lora_layer], [1.0])
# Run the ConcatenatedLoRA-patched linear layer and the LoRALinearWrapper and assert they are equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_wrapped = lora_wrapped(input)
assert torch.allclose(output_patched, output_wrapped, atol=1e-6)

View File

@@ -2,11 +2,15 @@ import pytest
import torch
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_layer_wrappers import LoRASidecarWrapper
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
class DummyModule(torch.nn.Module):
class DummyModuleWithOneLayer(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
super().__init__()
self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
@@ -15,8 +19,18 @@ class DummyModule(torch.nn.Module):
return self.linear_layer_1(x)
class DummyModuleWithTwoLayers(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
super().__init__()
self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
self.linear_layer_2 = torch.nn.Linear(out_features, out_features, device=device, dtype=dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear_layer_2(self.linear_layer_1(x))
@pytest.mark.parametrize(
["device", "num_layers"],
["device", "num_loras"],
[
("cpu", 1),
pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
@@ -25,7 +39,7 @@ class DummyModule(torch.nn.Module):
],
)
@torch.no_grad()
def test_apply_lora_patches(device: str, num_layers: int):
def test_apply_lora_patches(device: str, num_loras: int):
"""Test the basic behavior of ModelPatcher.apply_lora_patches(...). Check that patching and unpatching produce the
correct result, and that model/LoRA tensors are moved between devices as expected.
"""
@@ -33,12 +47,12 @@ def test_apply_lora_patches(device: str, num_layers: int):
linear_in_features = 4
linear_out_features = 8
lora_rank = 2
model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=torch.float16)
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=torch.float16)
# Initialize num_layers LoRA models with weights of 0.5.
# Initialize num_loras LoRA models with weights of 0.5.
lora_weight = 0.5
lora_models: list[tuple[LoRAModelRaw, float]] = []
for _ in range(num_layers):
for _ in range(num_loras):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
@@ -51,7 +65,7 @@ def test_apply_lora_patches(device: str, num_layers: int):
lora_models.append((lora, lora_weight))
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_layers)
expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_loras)
with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
# After patching, all LoRA layer weights should have been moved back to the cpu.
@@ -79,7 +93,7 @@ def test_apply_lora_patches_change_device():
linear_out_features = 8
lora_dim = 2
# Initialize the model on the CPU.
model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
@@ -110,7 +124,7 @@ def test_apply_lora_patches_change_device():
@pytest.mark.parametrize(
["device", "num_layers"],
["device", "num_loras"],
[
("cpu", 1),
pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
@@ -118,18 +132,18 @@ def test_apply_lora_patches_change_device():
pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
],
)
def test_apply_lora_sidecar_patches(device: str, num_layers: int):
"""Test the basic behavior of ModelPatcher.apply_lora_sidecar_patches(...). Check that unpatching works correctly."""
def test_apply_lora_wrapper_patches(device: str, num_loras: int):
"""Test the basic behavior of ModelPatcher.apply_lora_wrapper_patches(...). Check that unpatching works correctly."""
dtype = torch.float16
linear_in_features = 4
linear_out_features = 8
lora_rank = 2
model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=dtype)
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype)
# Initialize num_layers LoRA models with weights of 0.5.
# Initialize num_loras LoRA models with weights of 0.5.
lora_weight = 0.5
lora_models: list[tuple[LoRAModelRaw, float]] = []
for _ in range(num_layers):
for _ in range(num_loras):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
@@ -146,7 +160,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
output_before_patch = model(input)
# Patch the model and run inference during the patch.
with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_during_patch = model(input)
# Run inference after unpatching.
@@ -159,20 +173,140 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
assert torch.allclose(output_before_patch, output_after_patch)
@pytest.mark.parametrize(
["device", "num_loras"],
[
("cpu", 1),
pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
("cpu", 2),
pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
],
)
@torch.no_grad()
@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
"""Test that apply_lora_sidecar_patches(...) produces the same model outputs as apply_lora_patches(...)."""
def test_apply_smart_lora_patches(device: str, num_loras: int):
"""Test the basic behavior of ModelPatcher.apply_smart_lora_patches(...). Check that unpatching works correctly."""
dtype = torch.float16
linear_in_features = 4
linear_out_features = 8
lora_rank = 2
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype)
# Initialize num_loras LoRA models with weights of 0.5.
lora_weight = 0.5
lora_models: list[tuple[LoRAModelRaw, float]] = []
for _ in range(num_loras):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
)
}
lora = LoRAModelRaw(lora_layers)
lora_models.append((lora, lora_weight))
# Run inference before patching the model.
input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
output_before_patch = model(input)
# Patch the model and run inference during the patch.
with LoRAPatcher.apply_smart_lora_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_during_patch = model(input)
# Run inference after unpatching.
output_after_patch = model(input)
# Check that the output before patching is different from the output during patching.
assert not torch.allclose(output_before_patch, output_during_patch)
# Check that the output before patching is the same as the output after patching.
assert torch.allclose(output_before_patch, output_after_patch)
@pytest.mark.parametrize(["num_loras"], [(1,), (2,)])
@torch.no_grad()
def test_apply_smart_lora_patches_to_partially_loaded_model(num_loras: int):
"""Test the behavior of ModelPatcher.apply_smart_lora_patches(...) when it is applied to a
CachedModelWithPartialLoad that is partially loaded into VRAM.
"""
if not torch.cuda.is_available():
pytest.skip("requires CUDA device")
# Initialize the model on the CPU.
dtype = torch.float16
linear_in_features = 4
linear_out_features = 8
lora_rank = 2
model = DummyModuleWithTwoLayers(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Partially load the model into VRAM.
target_vram_bytes = int(model_total_bytes * 0.6)
_ = cached_model.partial_load_to_vram(target_vram_bytes)
assert cached_model.model.linear_layer_1.weight.device.type == "cuda"
assert cached_model.model.linear_layer_2.weight.device.type == "cpu"
# Initialize num_loras LoRA models with weights of 0.5.
lora_weight = 0.5
lora_models: list[tuple[LoRAModelRaw, float]] = []
for _ in range(num_loras):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
),
"linear_layer_2": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_out_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
),
}
lora = LoRAModelRaw(lora_layers)
lora_models.append((lora, lora_weight))
# Run inference before patching the model.
input = torch.randn(1, linear_in_features, device="cuda", dtype=dtype)
output_before_patch = cached_model.model(input)
# Patch the model and run inference during the patch.
with LoRAPatcher.apply_smart_lora_patches(model=cached_model.model, patches=lora_models, prefix="", dtype=dtype):
# Check that the second layer is wrapped in a LoRASidecarWrapper, but the first layer is not.
assert not isinstance(cached_model.model.linear_layer_1, LoRASidecarWrapper)
assert isinstance(cached_model.model.linear_layer_2, LoRASidecarWrapper)
output_during_patch = cached_model.model(input)
# Run inference after unpatching.
output_after_patch = cached_model.model(input)
# Check that the output before patching is different from the output during patching.
assert not torch.allclose(output_before_patch, output_during_patch)
# Check that the output before patching is the same as the output after patching.
assert torch.allclose(output_before_patch, output_after_patch)
@torch.no_grad()
@pytest.mark.parametrize(["num_loras"], [(1,), (2,)])
def test_all_patching_methods_produce_same_output(num_loras: int):
"""Test that apply_lora_wrapper_patches(...) produces the same model outputs as apply_lora_patches(...)."""
dtype = torch.float32
linear_in_features = 4
linear_out_features = 8
lora_rank = 2
model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
# Initialize num_layers LoRA models with weights of 0.5.
# Initialize num_loras LoRA models with weights of 0.5.
lora_weight = 0.5
lora_models: list[tuple[LoRAModelRaw, float]] = []
for _ in range(num_layers):
for _ in range(num_loras):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
@@ -189,9 +323,13 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
output_lora_patches = model(input)
with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_lora_sidecar_patches = model(input)
with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_lora_wrapper_patches = model(input)
with LoRAPatcher.apply_smart_lora_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_smart_lora_patches = model(input)
# Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
# differences are tolerable and expected due to the difference between sidecar vs. patching.
assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)
assert torch.allclose(output_lora_patches, output_lora_wrapper_patches, atol=1e-5)
assert torch.allclose(output_lora_patches, output_smart_lora_patches, atol=1e-5)