Push LoRA layer reshaping down into the patch layers and add a new FluxControlLoRALayer type.

Ryan Dick
2024-12-14 01:00:22 +00:00
parent fe09f2d27a
commit 37e3089457
8 changed files with 124 additions and 24 deletions
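The tests below exercise the new layer type. For orientation, here is a minimal, self-contained sketch of the behaviour they assume — a LoRA-style patch layer that owns its reshaping, so get_parameters() reports tensors in the patch's own (possibly larger) in_features shape rather than the original module's. The class name and details are illustrative; this is not the InvokeAI source.

# Illustrative sketch only (assumed names and behaviour, not the InvokeAI implementation).
from typing import Optional

import torch


class SketchFluxControlLoRALayer:
    """LoRA-style layer whose parameters are reported in the patch's own in_features shape."""

    def __init__(
        self,
        up: torch.Tensor,
        mid: Optional[torch.Tensor],
        down: torch.Tensor,
        alpha: float,
        bias: Optional[torch.Tensor],
    ):
        self.up = up
        self.mid = mid
        self.down = down
        self.alpha = alpha
        self.bias = bias

    def rank(self) -> int:
        return self.down.shape[0]

    def scale(self) -> float:
        return self.alpha / self.rank()

    def get_weight(self, _orig_weight: torch.Tensor) -> torch.Tensor:
        # Standard LoRA weight delta; shape is (out_features, patched_in_features).
        return self.up @ self.mid @ self.down if self.mid is not None else self.up @ self.down

    def get_bias(self, _orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
        return self.bias

    def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
        # The returned weight keeps the patch's in_features even when orig_module.in_features is
        # smaller; callers are expected to zero-pad the original parameters before adding the delta.
        params = {"weight": self.get_weight(orig_module.weight) * (weight * self.scale())}
        if self.bias is not None:
            params["bias"] = self.bias * (weight * self.scale())
        return params

On this reading, the first test's expected value follows directly: with all-ones up and down factors, up @ down equals rank * ones, and multiplying by scale() = alpha / rank leaves alpha * ones.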


@@ -0,0 +1,25 @@
import torch

from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer


def test_flux_control_lora_layer_get_parameters():
    """Test getting weight and bias parameters from FluxControlLoRALayer."""
    small_in_features = 4
    big_in_features = 8
    out_features = 16
    rank = 4
    alpha = 16.0
    layer = FluxControlLoRALayer(
        up=torch.ones(out_features, rank), mid=None, down=torch.ones(rank, big_in_features), alpha=alpha, bias=None
    )

    # Create a mock original module.
    orig_module = torch.nn.Linear(small_in_features, out_features)

    # Test that get_parameters() behaves as expected in spite of the difference in in_features shapes.
    params = layer.get_parameters(orig_module, weight=1.0)
    assert "weight" in params
    assert params["weight"].shape == (out_features, big_in_features)
    assert params["weight"].allclose(torch.ones(out_features, big_in_features) * alpha)
    assert "bias" not in params  # No bias in this case.


@@ -3,8 +3,10 @@ import copy
import torch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.full_layer import FullLayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
@@ -139,3 +141,42 @@ def test_linear_sidecar_wrapper_full_layer():
    output_patched = linear_patched(input)
    output_wrapped = full_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


def test_linear_sidecar_wrapper_flux_control_lora_layer():
    # Create a linear layer.
    orig_in_features = 10
    out_features = 40
    linear = torch.nn.Linear(orig_in_features, out_features)

    # Create a FluxControlLoRALayer.
    patched_in_features = 20
    rank = 4
    lora_layer = FluxControlLoRALayer(
        up=torch.randn(out_features, rank),
        mid=None,
        down=torch.randn(rank, patched_in_features),
        alpha=1.0,
        bias=torch.randn(out_features),
    )

    # Patch the FluxControlLoRALayer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    # Expand the existing weight.
    expanded_weight = pad_with_zeros(linear_patched.weight, torch.Size([out_features, patched_in_features]))
    linear_patched.weight = torch.nn.Parameter(expanded_weight, requires_grad=linear_patched.weight.requires_grad)
    # Expand the existing bias.
    expanded_bias = pad_with_zeros(linear_patched.bias, torch.Size([out_features]))
    linear_patched.bias = torch.nn.Parameter(expanded_bias, requires_grad=linear_patched.bias.requires_grad)
    # Add the residuals.
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])

    # Run the FluxControlLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, patched_in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
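A side note on why the two outputs agree: zero-padding the base weight adds columns that contribute nothing, so the padded-and-patched linear equals the original linear applied to the first orig_in_features inputs plus the (scaled) LoRA residual applied to the full expanded input. Below is a standalone check of that identity using plain torch and illustrative names; it does not claim to mirror LinearSidecarWrapper's internals.

# Standalone check of the identity behind the test above (pure torch, illustrative names).
import torch

orig_in_features, patched_in_features, out_features, rank = 10, 20, 40, 4
scale = 1.0 / rank  # alpha / rank, matching alpha=1.0 above.

linear = torch.nn.Linear(orig_in_features, out_features)
up = torch.randn(out_features, rank)
down = torch.randn(rank, patched_in_features)
lora_bias = torch.randn(out_features)

x = torch.randn(1, patched_in_features)

# Path 1: zero-pad the base weight to the patched shape, then add the scaled residuals.
padded_weight = torch.zeros(out_features, patched_in_features)
padded_weight[:, :orig_in_features] = linear.weight.detach()
patched = x @ (padded_weight + up @ down * scale).T + (linear.bias.detach() + lora_bias * scale)

# Path 2: the original linear on the first orig_in_features columns, plus the residual on the full input.
sidecar = linear(x[:, :orig_in_features]) + x @ (up @ down * scale).T + lora_bias * scale

assert torch.allclose(patched, sidecar, atol=1e-6)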