Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-14 18:57:55 -05:00)
Add unit tests for LinearSidecarWrapper (and fix a bug).
@@ -43,9 +43,10 @@ class BaseSidecarWrapper(torch.nn.Module):
             layer_params = patch.get_parameters(self._orig_module)
 
             for param_name, param_weight in layer_params.items():
+                orig_param = self._orig_module.get_parameter(param_name)
                 # TODO(ryand): Move shape handling down into the patch.
-                if params[param_name].shape != param_weight.shape:
-                    param_weight = param_weight.reshape(self._orig_module.get_parameter(param_name).shape)
+                if orig_param.shape != param_weight.shape:
+                    param_weight = param_weight.reshape(orig_param.shape)
 
                 # TODO(ryand): Move scale handling down into the patch.
                 scale = 1
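The bug fix in the hunk above: the old check compared the incoming patch weight's shape against `params[param_name]` rather than against the wrapped module's own parameter; the new code fetches `orig_param` from `self._orig_module` once and uses it both for the shape comparison and as the reshape target (which also drops the second `get_parameter` lookup inside the reshape call). A minimal standalone sketch of that shape-handling step follows (plain torch, illustrative names, not InvokeAI code):

import torch

# A linear layer standing in for the wrapped module.
linear = torch.nn.Linear(4, 6)

# A patch weight that arrives flattened (hypothetical shape mismatch).
param_name = "weight"
param_weight = torch.randn(24)

# Fetch the wrapped module's parameter once, then compare and reshape
# against it, mirroring the corrected logic in the hunk above.
orig_param = linear.get_parameter(param_name)
if orig_param.shape != param_weight.shape:
    param_weight = param_weight.reshape(orig_param.shape)

assert param_weight.shape == (6, 4)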
@@ -0,0 +1,70 @@
import copy

import torch

from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper


@torch.no_grad()
def test_linear_sidecar_wrapper_lora():
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create a LoRA layer.
    rank = 4
    down = torch.randn(rank, in_features)
    up = torch.randn(out_features, rank)
    bias = torch.randn(out_features)
    lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)

    # Patch the LoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])

    # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
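The first test asserts that two ways of applying a LoRA patch agree: baking the scaled delta into the layer's weight and bias, versus wrapping the untouched layer in a LinearSidecarWrapper. Assuming the usual LoRA convention (delta weight = up @ down, scale = alpha / rank, which the test's LoRALayer appears to follow), the identity is easy to check with plain torch. A minimal sketch, not InvokeAI code:

import torch

torch.manual_seed(0)
in_features, out_features, rank, alpha = 10, 20, 4, 1.0
scale = alpha / rank  # assumed LoRA scaling convention

linear = torch.nn.Linear(in_features, out_features)
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)

x = torch.randn(1, in_features)
with torch.no_grad():
    # (1) Bake the scaled delta into a copy of the weights.
    baked = torch.nn.Linear(in_features, out_features)
    baked.weight.copy_(linear.weight + (up @ down) * scale)
    baked.bias.copy_(linear.bias)

    # (2) Leave the original weights untouched and add the LoRA
    #     contribution to the output at runtime (the "sidecar" view).
    sidecar_out = linear(x) + (x @ down.T @ up.T) * scale

    assert torch.allclose(baked(x), sidecar_out, atol=1e-6)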
@torch.no_grad()
def test_linear_sidecar_wrapper_concatenated_lora():
    # Create a linear layer.
    in_features = 5
    sub_layer_out_features = [5, 10, 15]
    linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))

    # Create a ConcatenatedLoRA layer.
    rank = 4
    sub_layers: list[LoRALayer] = []
    for out_features in sub_layer_out_features:
        down = torch.randn(rank, in_features)
        up = torch.randn(out_features, rank)
        bias = torch.randn(out_features)
        sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
    concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)

    # Patch the ConcatenatedLoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += (
        concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
    )
    linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()

    # Create a LinearSidecarWrapper.
    lora_wrapped = LinearSidecarWrapper(linear, [(concatenated_lora_layer, 1.0)])

    # Run the ConcatenatedLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
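The second test covers ConcatenatedLoRALayer, which the test builds from three sub-layers whose output features (5, 10, 15) sum to the fused linear layer's 30 outputs. Assuming the layer simply concatenates each sub-layer's delta along concat_axis=0, the combined delta lines up block-wise with the fused weight, so the same bake-versus-sidecar comparison applies. A small sketch of that shape bookkeeping (plain torch, illustrative only):

import torch

in_features, rank = 5, 4
sub_out_features = [5, 10, 15]

# One (up, down) pair per sub-layer, as in the test above.
deltas = [
    torch.randn(out, rank) @ torch.randn(rank, in_features)
    for out in sub_out_features
]

# Stacking along dim 0 (concat_axis=0) yields a delta whose rows line up
# block-wise with the fused layer's weight of shape (sum(sub_out), in).
full_delta = torch.cat(deltas, dim=0)
assert full_delta.shape == (sum(sub_out_features), in_features)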