Push LoRA layer reshaping down into the patch layers and add a new FluxControlLoRALayer type.
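In rough outline, the patching path after this change looks like the sketch below. This is a condensed, illustrative rewrite of the LoRAPatcher hunk further down (loop setup, device handling, and bookkeeping omitted), not code from the commit itself:

# Condensed sketch (assumes the names from the diff below: patch, module_to_patch, patch_weight, dtype).
params = patch.get_parameters(module_to_patch, weight=patch_weight)
for param_name, param_weight in params.items():
    module_param = module_to_patch.get_parameter(param_name)
    if module_param.nelement() != param_weight.nelement():
        # Only FluxControlLoRALayer patches take this branch: the original parameter is zero-padded
        # up to the patch's shape before the residual is added in place.
        assert isinstance(patch, FluxControlLoRALayer)
        expanded = pad_with_zeros(module_param, param_weight.shape)
        setattr(module_to_patch, param_name, torch.nn.Parameter(expanded, requires_grad=module_param.requires_grad))
        module_param = expanded
    module_param += param_weight.to(dtype=dtype)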
invokeai/backend/patches/layers/flux_control_lora_layer.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import torch

from invokeai.backend.patches.layers.lora_layer import LoRALayer


class FluxControlLoRALayer(LoRALayer):
    """A special case of LoRALayer for use with FLUX Control LoRAs that pads the target parameter with zeros if the
    shapes don't match.
    """

    def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
        """This overrides the base class behavior to skip the reshaping step."""
        scale = self.scale()
        params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
        bias = self.get_bias(orig_module.bias)
        if bias is not None:
            params["bias"] = bias * (weight * scale)

        return params
@@ -63,6 +63,13 @@ class LoRALayerBase(BaseLayerPatch):
        bias = self.get_bias(orig_module.bias)
        if bias is not None:
            params["bias"] = bias * (weight * scale)

        # Reshape all params to match the original module's shape.
        for param_name, param_weight in params.items():
            orig_param = orig_module.get_parameter(param_name)
            if param_weight.shape != orig_param.shape:
                params[param_name] = param_weight.reshape(orig_param.shape)

        return params

    @classmethod
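For context, the reshape that now lives in LoRALayerBase.get_parameters handles patches whose computed delta has the same number of elements as the target parameter but a different shape. A minimal illustration of that situation, using a 1x1 Conv2d as an assumed example (not taken from the commit):

import torch

out_channels, in_channels, rank = 8, 4, 2
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)  # weight shape: (8, 4, 1, 1)

up = torch.randn(out_channels, rank)
down = torch.randn(rank, in_channels)
delta = up @ down  # shape: (8, 4) -- same element count as conv.weight, different shape

# Mirrors the reshape step added above: match the original parameter's shape before it is applied.
delta = delta.reshape(conv.weight.shape)
assert delta.shape == conv.weight.shape == (out_channels, in_channels, 1, 1)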
@@ -4,7 +4,9 @@ from typing import Dict, Iterable, Optional, Tuple
import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper
from invokeai.backend.util.devices import TorchDevice
@@ -125,24 +127,18 @@ class LoRAPatcher:
                # Save original weight
                original_weights.save(param_key, module_param)

                if module_param.shape != param_weight.shape:
                    if module_param.nelement() == param_weight.nelement():
                        param_weight = param_weight.reshape(module_param.shape)
                    else:
                        # This condition was added to handle layers in FLUX control LoRAs.
                        # TODO(ryand): Move the weight update into the LoRA layer so that the LoRAPatcher doesn't need
                        # to worry about this?
                        expanded_weight = torch.zeros_like(
                            param_weight, dtype=module_param.dtype, device=module_param.device
                        )
                        slices = tuple(slice(0, dim) for dim in module_param.shape)
                        expanded_weight[slices] = module_param
                        setattr(
                            module_to_patch,
                            param_name,
                            torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
                        )
                        module_param = expanded_weight
                # HACK(ryand): This condition is only necessary to handle layers in FLUX control LoRAs that change the
                # shape of the original layer.
                if module_param.nelement() != param_weight.nelement():
                    assert isinstance(patch, FluxControlLoRALayer)
                    expanded_weight = pad_with_zeros(module_param, param_weight.shape)
                    setattr(
                        module_to_patch,
                        param_name,
                        torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
                    )
                    module_param = expanded_weight

                module_param += param_weight.to(dtype=dtype)

            patch.to(device=TorchDevice.CPU_DEVICE)
invokeai/backend/patches/pad_with_zeros.py (new file, 9 lines)
@@ -0,0 +1,9 @@
import torch


def pad_with_zeros(orig_weight: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
    """Pad a weight tensor with zeros to match the target shape."""
    expanded_weight = torch.zeros(target_shape, dtype=orig_weight.dtype, device=orig_weight.device)
    slices = tuple(slice(0, dim) for dim in orig_weight.shape)
    expanded_weight[slices] = orig_weight
    return expanded_weight
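A quick usage sketch for the new helper (illustration only, not part of the commit): the original tensor lands in the leading corner of the target shape and every other entry is zero.

import torch

from invokeai.backend.patches.pad_with_zeros import pad_with_zeros

w = torch.ones(2, 3)
padded = pad_with_zeros(w, torch.Size([4, 5]))

assert padded.shape == (4, 5)
assert torch.equal(padded[:2, :3], w)  # original values preserved in the leading slice
assert padded.sum() == w.sum()  # everything outside that slice is zero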
@@ -43,11 +43,6 @@ class BaseSidecarWrapper(torch.nn.Module):
            layer_params = patch.get_parameters(self._orig_module, weight=patch_weight)

            for param_name, param_weight in layer_params.items():
                orig_param = self._orig_module.get_parameter(param_name)
                # TODO(ryand): Move shape handling down into the patch.
                if orig_param.shape != param_weight.shape:
                    param_weight = param_weight.reshape(orig_param.shape)

                if param_name not in params:
                    params[param_name] = param_weight
                else:
@@ -2,6 +2,7 @@ import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
@@ -36,12 +37,19 @@ class LinearSidecarWrapper(BaseSidecarWrapper):

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # First, apply the original linear layer.
        # NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
        # change the linear layer's in_features.
        orig_input = input
        input = orig_input[..., : self.orig_module.weight.shape[1]]
        output = self.orig_module(input)

        # Then, apply layers for which we have optimized implementations.
        unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
        for patch, patch_weight in self._patches_and_weights:
            if isinstance(patch, LoRALayer):
            if isinstance(patch, FluxControlLoRALayer):
                # Note that we use the original input here, not the sliced input.
                output += self._lora_forward(orig_input, patch, patch_weight)
            elif isinstance(patch, LoRALayer):
                output += self._lora_forward(input, patch, patch_weight)
            elif isinstance(patch, ConcatenatedLoRALayer):
                output += self._concatenated_lora_forward(input, patch, patch_weight)
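To see why the input slicing above works, take the shapes used in the test added at the end of this commit (orig_in_features=10, patched_in_features=20, out_features=40, rank=4): the wrapped original Linear still has a (40, 10) weight and therefore only consumes the first 10 input features, while the FLUX control LoRA's down projection is (4, 20) and needs the full input. A standalone sketch of that shape arithmetic (illustration only; the LoRA scale and bias are omitted):

import torch

orig_in_features, patched_in_features, out_features, rank = 10, 20, 40, 4
linear = torch.nn.Linear(orig_in_features, out_features)  # weight: (40, 10)
down = torch.randn(rank, patched_in_features)  # (4, 20)
up = torch.randn(out_features, rank)  # (40, 4)

x = torch.randn(1, patched_in_features)  # full 20-feature input

base = linear(x[..., : linear.weight.shape[1]])  # the original layer sees only the first 10 features
lora = x @ down.t() @ up.t()  # the control LoRA consumes all 20 features
output = base + lora

assert output.shape == (1, out_features)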
tests/backend/patches/layers/test_flux_control_lora_layer.py (new file, 25 lines)
@@ -0,0 +1,25 @@
import torch

from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer


def test_flux_control_lora_layer_get_parameters():
    """Test getting weight and bias parameters from FluxControlLoRALayer."""
    small_in_features = 4
    big_in_features = 8
    out_features = 16
    rank = 4
    alpha = 16.0
    layer = FluxControlLoRALayer(
        up=torch.ones(out_features, rank), mid=None, down=torch.ones(rank, big_in_features), alpha=alpha, bias=None
    )

    # Create mock original module
    orig_module = torch.nn.Linear(small_in_features, out_features)

    # Test that get_parameters() behaves as expected in spite of the difference in in_features shapes.
    params = layer.get_parameters(orig_module, weight=1.0)
    assert "weight" in params
    assert params["weight"].shape == (out_features, big_in_features)
    assert params["weight"].allclose(torch.ones(out_features, big_in_features) * alpha)
    assert "bias" not in params  # No bias in this case
@@ -3,8 +3,10 @@ import copy
import torch

from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.full_layer import FullLayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
@@ -139,3 +141,42 @@ def test_linear_sidecar_wrapper_full_layer():
    output_patched = linear_patched(input)
    output_wrapped = full_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


def test_linear_sidecar_wrapper_flux_control_lora_layer():
    # Create a linear layer.
    orig_in_features = 10
    out_features = 40
    linear = torch.nn.Linear(orig_in_features, out_features)

    # Create a FluxControlLoRALayer.
    patched_in_features = 20
    rank = 4
    lora_layer = FluxControlLoRALayer(
        up=torch.randn(out_features, rank),
        mid=None,
        down=torch.randn(rank, patched_in_features),
        alpha=1.0,
        bias=torch.randn(out_features),
    )

    # Patch the FluxControlLoRALayer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    # Expand the existing weight.
    expanded_weight = pad_with_zeros(linear_patched.weight, torch.Size([out_features, patched_in_features]))
    linear_patched.weight = torch.nn.Parameter(expanded_weight, requires_grad=linear_patched.weight.requires_grad)
    # Expand the existing bias.
    expanded_bias = pad_with_zeros(linear_patched.bias, torch.Size([out_features]))
    linear_patched.bias = torch.nn.Parameter(expanded_bias, requires_grad=linear_patched.bias.requires_grad)
    # Add the residuals.
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])

    # Run the FluxControlLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, patched_in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)