Add LoRA wrapper layer.

This commit is contained in:
Ryan Dick
2024-12-09 15:17:50 +00:00
parent 9019026d6d
commit 93f2bc6118
2 changed files with 81 additions and 0 deletions

View File

@@ -0,0 +1,46 @@
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
class LoRAModuleWrapper(torch.nn.Module):
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[AnyLoRALayer], lora_weights: list[float]):
super().__init__()
self._orig_module = orig_module
self._lora_layers = lora_layers
self._lora_weights = lora_weights
@torch.no_grad()
def _get_lora_patched_parameters(self) -> dict[str, torch.Tensor]:
out_params: dict[str, torch.Tensor] = {}
for lora_layer, lora_weight in zip(self._lora_layers, self._lora_weights, strict=True):
layer_params = lora_layer.get_parameters(self._orig_module)
for param_name, param_weight in layer_params.items():
# If the parameter already exists in out_params, use that one. Otherwise, use original parameter.
if param_name not in out_params:
out_params[param_name] = self._orig_module.get_parameter(param_name)
if out_params[param_name].shape != param_weight.shape:
param_weight = param_weight.reshape(out_params[param_name].shape)
out_params[param_name] += param_weight * (lora_layer.scale() * lora_weight)
return out_params
class LoRALinearWrapper(LoRAModuleWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
params = self._get_lora_patched_parameters()
return torch.nn.functional.linear(input, params["weight"], params.get("bias", None))
class LoRAConv1dWrapper(LoRAModuleWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
params = self._get_lora_patched_parameters()
return torch.nn.functional.conv1d(input, params["weight"], params.get("bias", None))
class LoRAConv2dWrapper(LoRAModuleWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
params = self._get_lora_patched_parameters()
return torch.nn.functional.conv2d(input, params["weight"], params.get("bias", None))

View File

@@ -0,0 +1,35 @@
import copy
import torch
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_layer_wrappers import LoRALinearWrapper
@torch.no_grad()
def test_lora_linear_wrapper():
# Create a linear layer.
in_features = 10
out_features = 20
linear = torch.nn.Linear(in_features, out_features)
# Create a LoRA layer.
rank = 4
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias = torch.randn(out_features)
lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
# Patch the LoRA layer into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
# Create a LoRALinearWrapper.
lora_wrapped = LoRALinearWrapper(linear, [lora_layer], [1.0])
# Run the LoRA-patched linear layer and the LoRALinearWrapper and assert they are equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_wrapped = lora_wrapped(input)
assert torch.allclose(output_patched, output_wrapped, atol=1e-6)