Mirror of https://github.com/invoke-ai/InvokeAI.git
Add LoRA wrapper layer.
46  invokeai/backend/lora/lora_layer_wrappers.py  Normal file
@@ -0,0 +1,46 @@
import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer


class LoRAModuleWrapper(torch.nn.Module):
    def __init__(self, orig_module: torch.nn.Module, lora_layers: list[AnyLoRALayer], lora_weights: list[float]):
        super().__init__()
        self._orig_module = orig_module
        self._lora_layers = lora_layers
        self._lora_weights = lora_weights

    @torch.no_grad()
    def _get_lora_patched_parameters(self) -> dict[str, torch.Tensor]:
        out_params: dict[str, torch.Tensor] = {}
        for lora_layer, lora_weight in zip(self._lora_layers, self._lora_weights, strict=True):
            layer_params = lora_layer.get_parameters(self._orig_module)
            for param_name, param_weight in layer_params.items():
                # If the parameter already exists in out_params, keep accumulating into it. Otherwise, start
                # from a detached copy of the original parameter so that the in-place accumulation below does
                # not modify the wrapped module's weights.
                if param_name not in out_params:
                    out_params[param_name] = self._orig_module.get_parameter(param_name).detach().clone()

                if out_params[param_name].shape != param_weight.shape:
                    param_weight = param_weight.reshape(out_params[param_name].shape)

                out_params[param_name] += param_weight * (lora_layer.scale() * lora_weight)

        return out_params


class LoRALinearWrapper(LoRAModuleWrapper):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        params = self._get_lora_patched_parameters()
        return torch.nn.functional.linear(input, params["weight"], params.get("bias", None))


class LoRAConv1dWrapper(LoRAModuleWrapper):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        params = self._get_lora_patched_parameters()
        return torch.nn.functional.conv1d(input, params["weight"], params.get("bias", None))


class LoRAConv2dWrapper(LoRAModuleWrapper):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        params = self._get_lora_patched_parameters()
        return torch.nn.functional.conv2d(input, params["weight"], params.get("bias", None))
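For context, a minimal usage sketch (not part of this commit) of how the wrapper might be applied: the original module is passed in untouched, and the LoRA delta is folded into its parameters on each forward() call. The module paths and constructor arguments below mirror the test added in this commit; the layer sizes are arbitrary.

import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_layer_wrappers import LoRALinearWrapper

# Hypothetical standalone usage; sizes are arbitrary.
linear = torch.nn.Linear(10, 20)
lora_layer = LoRALayer(up=torch.randn(20, 4), mid=None, down=torch.randn(4, 10), alpha=1.0, bias=torch.randn(20))

# Wrap the original module with a single LoRA layer at weight 0.5. The wrapper reads the original
# parameters at forward time, so the underlying Linear module is left unmodified.
wrapped = LoRALinearWrapper(linear, [lora_layer], [0.5])
output = wrapped(torch.randn(1, 10))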
35  tests/backend/lora/test_lora_layer_wrappers.py  Normal file
@@ -0,0 +1,35 @@
import copy

import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_layer_wrappers import LoRALinearWrapper


@torch.no_grad()
def test_lora_linear_wrapper():
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create a LoRA layer.
    rank = 4
    down = torch.randn(rank, in_features)
    up = torch.randn(out_features, rank)
    bias = torch.randn(out_features)
    lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)

    # Patch the LoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LoRALinearWrapper.
    lora_wrapped = LoRALinearWrapper(linear, [lora_layer], [1.0])

    # Run the LoRA-patched linear layer and the LoRALinearWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
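Since the wrapper accepts a list of LoRA layers with per-layer weights, the same pattern could be extended to the stacked case. A hypothetical companion test (not part of this commit), assuming it lives in the same file and reuses its imports:

@torch.no_grad()
def test_lora_linear_wrapper_stacked_layers():
    # Hypothetical sketch: applying the same LoRA layer twice with weights 1.0 and 0.5 should match
    # scaling its delta by the summed weight.
    linear = torch.nn.Linear(10, 20)
    lora_layer = LoRALayer(up=torch.randn(20, 4), mid=None, down=torch.randn(4, 10), alpha=1.0, bias=torch.randn(20))

    weight_delta = lora_layer.get_weight(linear.weight) * lora_layer.scale()
    bias_delta = lora_layer.get_bias(linear.bias) * lora_layer.scale()

    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += weight_delta * (1.0 + 0.5)
    linear_patched.bias.data += bias_delta * (1.0 + 0.5)

    lora_wrapped = LoRALinearWrapper(linear, [lora_layer, lora_layer], [1.0, 0.5])

    input = torch.randn(1, 10)
    assert torch.allclose(linear_patched(input), lora_wrapped(input), atol=1e-6)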