Add unit tests for LoRALinearSidecarLayer and ConcatenatedLoRALinearSidecarLayer.

This commit is contained in:
Ryan Dick
2024-09-13 14:43:52 +00:00
committed by Kent Keirsey
parent 61d3d566de
commit ba3ba3c23a
4 changed files with 86 additions and 4 deletions

View File

@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional, Sequence
import torch
@@ -13,7 +13,7 @@ class ConcatenatedLoRALayer(LoRALayerBase):
stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
"""
def __init__(self, lora_layers: List[LoRALayerBase], concat_axis: int = 0):
def __init__(self, lora_layers: Sequence[LoRALayerBase], concat_axis: int = 0):
super().__init__(alpha=None, bias=None)
self.lora_layers = torch.nn.ModuleList(lora_layers)

View File

@@ -15,8 +15,6 @@ class ConcatenatedLoRALinearSidecarLayer(torch.nn.Module):
self._weight = weight
def forward(self, input: torch.Tensor) -> torch.Tensor:
assert input.ndim == 3
x_chunks: list[torch.Tensor] = []
for lora_layer in self._concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)

View File

@@ -0,0 +1,47 @@
import copy
import torch
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
ConcatenatedLoRALinearSidecarLayer,
)
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
def test_concatenated_lora_linear_sidecar_layer():
"""Test that a ConcatenatedLoRALinearSidecarLayer is equivalent to patching a linear layer with the ConcatenatedLoRA
layer.
"""
# Create a linear layer.
in_features = 5
sub_layer_out_features = [5, 10, 15]
linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))
# Create a ConcatenatedLoRA layer.
rank = 4
sub_layers: list[LoRALayer] = []
for out_features in sub_layer_out_features:
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=None))
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
# Patch the ConcatenatedLoRA layer into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += (
concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
)
# Create a ConcatenatedLoRALinearSidecarLayer.
concatenated_lora_linear_sidecar_layer = ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer, weight=1.0)
linear_with_sidecar = LoRASidecarModule(linear, [concatenated_lora_linear_sidecar_layer])
# Run the ConcatenatedLoRA-patched linear layer and the ConcatenatedLoRALinearSidecarLayer and assert they are
# equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_sidecar = linear_with_sidecar(input)
assert torch.allclose(output_patched, output_sidecar, atol=1e-6)

View File

@@ -0,0 +1,37 @@
import copy
import torch
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
@torch.no_grad()
def test_lora_linear_sidecar_layer():
"""Test that a LoRALinearSidecarLayer is equivalent to patching a linear layer with the LoRA layer."""
# Create a linear layer.
in_features = 10
out_features = 20
linear = torch.nn.Linear(in_features, out_features)
# Create a LoRA layer.
rank = 4
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=None)
# Patch the LoRA layer into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
# Create a LoRALinearSidecarLayer.
lora_linear_sidecar_layer = LoRALinearSidecarLayer(lora_layer, weight=1.0)
linear_with_sidecar = LoRASidecarModule(linear, [lora_linear_sidecar_layer])
# Run the LoRA-patched linear layer and the LoRALinearSidecarLayer and assert they are equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_sidecar = linear_with_sidecar(input)
assert torch.allclose(output_patched, output_sidecar, atol=1e-6)