Push LoRA layer reshaping down into the patch layers and add a new FluxControlLoRALayer type.

This commit is contained in:
Ryan Dick
2024-12-14 01:00:22 +00:00
parent fe09f2d27a
commit 37e3089457
8 changed files with 124 additions and 24 deletions

View File

@@ -0,0 +1,19 @@
import torch
from invokeai.backend.patches.layers.lora_layer import LoRALayer
class FluxControlLoRALayer(LoRALayer):
"""A special case of LoRALayer for use with FLUX Control LoRAs that pads the target parameter with zeros if the
shapes don't match.
"""
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
"""This overrides the base class behavior to skip the reshaping step."""
scale = self.scale()
params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias * (weight * scale)
return params

View File

@@ -63,6 +63,13 @@ class LoRALayerBase(BaseLayerPatch):
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias * (weight * scale)
# Reshape all params to match the original module's shape.
for param_name, param_weight in params.items():
orig_param = orig_module.get_parameter(param_name)
if param_weight.shape != orig_param.shape:
params[param_name] = param_weight.reshape(orig_param.shape)
return params
@classmethod

View File

@@ -4,7 +4,9 @@ from typing import Dict, Iterable, Optional, Tuple
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper
from invokeai.backend.util.devices import TorchDevice
@@ -125,24 +127,18 @@ class LoRAPatcher:
# Save original weight
original_weights.save(param_key, module_param)
if module_param.shape != param_weight.shape:
if module_param.nelement() == param_weight.nelement():
param_weight = param_weight.reshape(module_param.shape)
else:
# This condition was added to handle layers in FLUX control LoRAs.
# TODO(ryand): Move the weight update into the LoRA layer so that the LoRAPatcher doesn't need
# to worry about this?
expanded_weight = torch.zeros_like(
param_weight, dtype=module_param.dtype, device=module_param.device
)
slices = tuple(slice(0, dim) for dim in module_param.shape)
expanded_weight[slices] = module_param
setattr(
module_to_patch,
param_name,
torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
)
module_param = expanded_weight
# HACK(ryand): This condition is only necessary to handle layers in FLUX control LoRAs that change the
# shape of the original layer.
if module_param.nelement() != param_weight.nelement():
assert isinstance(patch, FluxControlLoRALayer)
expanded_weight = pad_with_zeros(module_param, param_weight.shape)
setattr(
module_to_patch,
param_name,
torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
)
module_param = expanded_weight
module_param += param_weight.to(dtype=dtype)
patch.to(device=TorchDevice.CPU_DEVICE)

View File

@@ -0,0 +1,9 @@
import torch
def pad_with_zeros(orig_weight: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
"""Pad a weight tensor with zeros to match the target shape."""
expanded_weight = torch.zeros(target_shape, dtype=orig_weight.dtype, device=orig_weight.device)
slices = tuple(slice(0, dim) for dim in orig_weight.shape)
expanded_weight[slices] = orig_weight
return expanded_weight

View File

@@ -43,11 +43,6 @@ class BaseSidecarWrapper(torch.nn.Module):
layer_params = patch.get_parameters(self._orig_module, weight=patch_weight)
for param_name, param_weight in layer_params.items():
orig_param = self._orig_module.get_parameter(param_name)
# TODO(ryand): Move shape handling down into the patch.
if orig_param.shape != param_weight.shape:
param_weight = param_weight.reshape(orig_param.shape)
if param_name not in params:
params[param_name] = param_weight
else:

View File

@@ -2,6 +2,7 @@ import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
@@ -36,12 +37,19 @@ class LinearSidecarWrapper(BaseSidecarWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
# First, apply the original linear layer.
# NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
# change the linear layer's in_features.
orig_input = input
input = orig_input[..., : self.orig_module.weight.shape[1]]
output = self.orig_module(input)
# Then, apply layers for which we have optimized implementations.
unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
for patch, patch_weight in self._patches_and_weights:
if isinstance(patch, LoRALayer):
if isinstance(patch, FluxControlLoRALayer):
# Note that we use the original input here, not the sliced input.
output += self._lora_forward(orig_input, patch, patch_weight)
elif isinstance(patch, LoRALayer):
output += self._lora_forward(input, patch, patch_weight)
elif isinstance(patch, ConcatenatedLoRALayer):
output += self._concatenated_lora_forward(input, patch, patch_weight)