A unit tests for LinearSidecarWrapper (and fix a bug).

This commit is contained in:
Ryan Dick
2024-12-13 18:47:28 +00:00
parent 443d838fd0
commit e2451ef5ca
2 changed files with 73 additions and 2 deletions

View File

@@ -43,9 +43,10 @@ class BaseSidecarWrapper(torch.nn.Module):
layer_params = patch.get_parameters(self._orig_module)
for param_name, param_weight in layer_params.items():
orig_param = self._orig_module.get_parameter(param_name)
# TODO(ryand): Move shape handling down into the patch.
if params[param_name].shape != param_weight.shape:
param_weight = param_weight.reshape(self._orig_module.get_parameter(param_name).shape)
if orig_param.shape != param_weight.shape:
param_weight = param_weight.reshape(orig_param.shape)
# TODO(ryand): Move scale handling down into the patch.
scale = 1

View File

@@ -0,0 +1,70 @@
import copy
import torch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
@torch.no_grad()
def test_linear_sidecar_wrapper_lora():
# Create a linear layer.
in_features = 10
out_features = 20
linear = torch.nn.Linear(in_features, out_features)
# Create a LoRA layer.
rank = 4
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias = torch.randn(out_features)
lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
# Patch the LoRA layer into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
# Create a LinearSidecarWrapper.
lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])
# Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_wrapped = lora_wrapped(input)
assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
@torch.no_grad()
def test_linear_sidecar_wrapper_concatenated_lora():
# Create a linear layer.
in_features = 5
sub_layer_out_features = [5, 10, 15]
linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))
# Create a ConcatenatedLoRA layer.
rank = 4
sub_layers: list[LoRALayer] = []
for out_features in sub_layer_out_features:
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias = torch.randn(out_features)
sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
# Patch the ConcatenatedLoRA layer into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += (
concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
)
linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()
# Create a LinearSidecarWrapper.
lora_wrapped = LinearSidecarWrapper(linear, [(concatenated_lora_layer, 1.0)])
# Run the ConcatenatedLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_wrapped = lora_wrapped(input)
assert torch.allclose(output_patched, output_wrapped, atol=1e-6)