diff --git a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py
index e2d9f423a1..40c77f1f96 100644
--- a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py
+++ b/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py
@@ -43,9 +43,10 @@ class BaseSidecarWrapper(torch.nn.Module):
             layer_params = patch.get_parameters(self._orig_module)
 
             for param_name, param_weight in layer_params.items():
+                orig_param = self._orig_module.get_parameter(param_name)
                 # TODO(ryand): Move shape handling down into the patch.
-                if params[param_name].shape != param_weight.shape:
-                    param_weight = param_weight.reshape(self._orig_module.get_parameter(param_name).shape)
+                if orig_param.shape != param_weight.shape:
+                    param_weight = param_weight.reshape(orig_param.shape)
 
                 # TODO(ryand): Move scale handling down into the patch.
                 scale = 1
diff --git a/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py b/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py
new file mode 100644
index 0000000000..5df728bdc6
--- /dev/null
+++ b/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py
@@ -0,0 +1,70 @@
+import copy
+
+import torch
+
+from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
+
+
+@torch.no_grad()
+def test_linear_sidecar_wrapper_lora():
+    # Create a linear layer.
+    in_features = 10
+    out_features = 20
+    linear = torch.nn.Linear(in_features, out_features)
+
+    # Create a LoRA layer.
+    rank = 4
+    down = torch.randn(rank, in_features)
+    up = torch.randn(out_features, rank)
+    bias = torch.randn(out_features)
+    lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
+
+    # Patch the LoRA layer into the linear layer.
+    linear_patched = copy.deepcopy(linear)
+    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
+    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
+
+    # Create a LinearSidecarWrapper.
+    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])
+
+    # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
+    input = torch.randn(1, in_features)
+    output_patched = linear_patched(input)
+    output_wrapped = lora_wrapped(input)
+    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
+
+
+@torch.no_grad()
+def test_linear_sidecar_wrapper_concatenated_lora():
+    # Create a linear layer.
+    in_features = 5
+    sub_layer_out_features = [5, 10, 15]
+    linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))
+
+    # Create a ConcatenatedLoRA layer.
+    rank = 4
+    sub_layers: list[LoRALayer] = []
+    for out_features in sub_layer_out_features:
+        down = torch.randn(rank, in_features)
+        up = torch.randn(out_features, rank)
+        bias = torch.randn(out_features)
+        sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
+    concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
+
+    # Patch the ConcatenatedLoRA layer into the linear layer.
+    linear_patched = copy.deepcopy(linear)
+    linear_patched.weight.data += (
+        concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
+    )
+    linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()
+
+    # Create a LinearSidecarWrapper.
+    lora_wrapped = LinearSidecarWrapper(linear, [(concatenated_lora_layer, 1.0)])
+
+    # Run the ConcatenatedLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
+    input = torch.randn(1, in_features)
+    output_patched = linear_patched(input)
+    output_wrapped = lora_wrapped(input)
+    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
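Note: the equivalence these tests assert can be reproduced with plain torch, without the InvokeAI classes. The sketch below is not part of the PR; it assumes the usual low-rank LoRA formulation (delta = up @ down) and simply checks that baking the delta into the weights matches adding a separate low-rank path at inference time, which is the property the sidecar wrapper relies on.

import copy

import torch

torch.manual_seed(0)

in_features, out_features, rank = 10, 20, 4
linear = torch.nn.Linear(in_features, out_features)

# Low-rank delta, analogous to a LoRA layer (assumed up @ down formulation).
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)

x = torch.randn(1, in_features)

with torch.no_grad():
    # "Patched" path: bake the delta into the weights, as the tests do.
    patched = copy.deepcopy(linear)
    patched.weight.data += up @ down

    # "Sidecar" path: leave the base layer untouched and add the low-rank
    # contribution to the output at inference time.
    sidecar_out = linear(x) + torch.nn.functional.linear(x, up @ down)

    assert torch.allclose(patched(x), sidecar_out, atol=1e-6)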