Begin to consolidate the LoRA sidecar and LoRA layer wrapper implementations.

Ryan Dick
2024-12-10 01:16:01 +00:00
parent 55dc762a91
commit 0c4a368555
2 changed files with 102 additions and 27 deletions


@@ -1,9 +1,11 @@
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
class LoRAModuleWrapper(torch.nn.Module):
class LoRASidecarWrapper(torch.nn.Module):
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[AnyLoRALayer], lora_weights: list[float]):
super().__init__()
self._orig_module = orig_module
@@ -19,40 +21,113 @@ class LoRAModuleWrapper(torch.nn.Module):
self._lora_weights.append(lora_weight)
@torch.no_grad()
def _get_lora_patched_parameters(self, params: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
for lora_layer, lora_weight in zip(self._lora_layers, self._lora_weights, strict=True):
def _get_lora_patched_parameters(
self, orig_params: dict[str, torch.Tensor], lora_layers: list[AnyLoRALayer], lora_weights: list[float]
) -> dict[str, torch.Tensor]:
params: dict[str, torch.Tensor] = {}
for lora_layer, lora_weight in zip(lora_layers, lora_weights, strict=True):
layer_params = lora_layer.get_parameters(self._orig_module)
for param_name, param_weight in layer_params.items():
if params[param_name].shape != param_weight.shape:
param_weight = param_weight.reshape(params[param_name].shape)
if orig_params[param_name].shape != param_weight.shape:
param_weight = param_weight.reshape(orig_params[param_name].shape)
# NOTE: It is important that params[param_name] is not modified in-place, because we initialize it
# with the original parameter - which we don't want to modify. In other words,
# `out_params[param_name] += ...` would not work.
params[param_name] = params[param_name] + param_weight * (lora_layer.scale() * lora_weight)
if param_name not in params:
params[param_name] = param_weight * (lora_layer.scale() * lora_weight)
else:
params[param_name] += param_weight * (lora_layer.scale() * lora_weight)
return params
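For reference, a minimal sketch of what the reworked _get_lora_patched_parameters computes, written outside the wrapper class: it accumulates only the scaled LoRA deltas (the original parameters are consulted just for reshaping), and the caller then adds the untouched module's output to an op run with those deltas. The toy layer below is purely illustrative; only the accumulation pattern mirrors the code above.

import torch

class _ToyLoRADelta:
    """Illustrative stand-in for AnyLoRALayer: exposes get_parameters() and scale()."""

    def __init__(self, delta_weight: torch.Tensor, alpha: float = 1.0):
        self._delta_weight = delta_weight
        self._alpha = alpha

    def get_parameters(self, orig_module: torch.nn.Module) -> dict[str, torch.Tensor]:
        return {"weight": self._delta_weight}

    def scale(self) -> float:
        return self._alpha

orig = torch.nn.Linear(16, 8)
layers = [_ToyLoRADelta(torch.randn(8, 16)), _ToyLoRADelta(torch.randn(8, 16), alpha=0.5)]
weights = [1.0, 0.75]

# Accumulate the scaled deltas per parameter name, never touching the original weights.
deltas: dict[str, torch.Tensor] = {}
for layer, weight in zip(layers, weights, strict=True):
    for name, delta in layer.get_parameters(orig).items():
        contribution = delta * (layer.scale() * weight)
        deltas[name] = contribution if name not in deltas else deltas[name] + contribution

# Adding a second linear over the accumulated delta to the untouched module's output is
# equivalent to patching the weight in place.
x = torch.randn(2, 16)
patched_out = torch.nn.functional.linear(x, orig.weight + deltas["weight"], orig.bias)
assert torch.allclose(orig(x) + torch.nn.functional.linear(x, deltas["weight"]), patched_out, atol=1e-5)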
class LoRALinearWrapper(LoRAModuleWrapper):
class LoRALinearWrapper(LoRASidecarWrapper):
def _lora_linear_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a Linear LoRALayer."""
x = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x = torch.nn.functional.linear(x, lora_layer.mid)
x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
x *= lora_weight * lora_layer.scale()
return x
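The shortcut above never materializes the dense delta weight up @ down; it pushes the activations through down, then optionally mid, then up, which is cheaper whenever the LoRA rank is much smaller than the feature dimensions. A quick illustrative check of the equivalence it relies on (shapes are arbitrary):

import torch

x = torch.randn(2, 16)     # (batch, in_features)
down = torch.randn(4, 16)  # (rank, in_features)
up = torch.randn(8, 4)     # (out_features, rank)

# Factored form, as in _lora_linear_forward: two small matmuls on the activations.
factored = torch.nn.functional.linear(torch.nn.functional.linear(x, down), up)
# Dense form: materialize the full (out_features, in_features) delta weight first.
dense = torch.nn.functional.linear(x, up @ down)
assert torch.allclose(factored, dense, atol=1e-5)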
def _concatenated_lora_forward(
self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
x_chunks: list[torch.Tensor] = []
for lora_layer in concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= lora_weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x
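A ConcatenatedLoRALayer bundles several LoRA sub-layers whose deltas are stacked along the weight's output axis (hence the concat_axis == 0 assertion), e.g. for fused QKV projections. Concatenating each sub-layer's residual along the last activation dimension, as the loop above does, matches applying the stacked delta weight in one go; an illustrative check:

import torch

x = torch.randn(2, 16)
delta_q = torch.randn(8, 16)  # per-chunk delta weights, shape (out_features, in_features)
delta_k = torch.randn(8, 16)

# Per-chunk residuals concatenated on the last activation dim (what the loop above does)...
chunked = torch.cat(
    [torch.nn.functional.linear(x, delta_q), torch.nn.functional.linear(x, delta_k)], dim=-1
)
# ...equal one linear whose delta weight was concatenated along axis 0.
fused = torch.nn.functional.linear(x, torch.cat([delta_q, delta_k], dim=0))
assert torch.allclose(chunked, fused, atol=1e-5)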
def forward(self, input: torch.Tensor) -> torch.Tensor:
# Split the LoRA layers into those that have optimized implementations and those that don't.
optimized_layer_types = (LoRALayer, ConcatenatedLoRALayer)
optimized_layers = [
(layer, weight)
for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
if isinstance(layer, optimized_layer_types)
]
non_optimized_layers = [
(layer, weight)
for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
if not isinstance(layer, optimized_layer_types)
]
# First, calculate the residual for LoRA layers for which there is an optimized implementation.
residual = None
for lora_layer, lora_weight in optimized_layers:
if isinstance(lora_layer, LoRALayer):
added_residual = self._lora_linear_forward(input, lora_layer, lora_weight)
elif isinstance(lora_layer, ConcatenatedLoRALayer):
added_residual = self._concatenated_lora_forward(input, lora_layer, lora_weight)
else:
raise ValueError(f"Unsupported LoRA layer type: {type(lora_layer)}")
if residual is None:
residual = added_residual
else:
residual += added_residual
# Next, calculate the residuals for the LoRA layers for which there is no optimized implementation.
if non_optimized_layers:
unoptimized_layers, unoptimized_weights = zip(*non_optimized_layers, strict=True)
params = self._get_lora_patched_parameters(
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
lora_layers=unoptimized_layers,
lora_weights=unoptimized_weights,
)
added_residual = torch.nn.functional.linear(input, params["weight"], params.get("bias", None))
if residual is None:
residual = added_residual
else:
residual += added_residual
return self.orig_module(input) + residual
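A hypothetical wiring sketch for the wrapper above. The toy model and the direct index assignment are placeholders for the patcher logic in the second file; only LoRALinearWrapper and add_lora_layer come from this commit, and as written forward expects at least one LoRA layer to have been attached before it is called.

import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 8))
orig_linear = model[0]

wrapper = LoRALinearWrapper(orig_linear, lora_layers=[], lora_weights=[])
model[0] = wrapper  # calls now route through the sidecar; orig_linear itself is never modified
# wrapper.add_lora_layer(some_lora_layer, patch_weight)  # attach real LoRA layers before inference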
class LoRAConv1dWrapper(LoRASidecarWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
params = self._get_lora_patched_parameters(
params={"weight": self._orig_module.weight, "bias": self._orig_module.bias}
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
lora_layers=self._lora_layers,
lora_weights=self._lora_weights,
)
return torch.nn.functional.linear(input, params["weight"], params["bias"])
return self.orig_module(input) + torch.nn.functional.conv1d(input, params["weight"], params.get("bias", None))
class LoRAConv1dWrapper(LoRAModuleWrapper):
class LoRAConv2dWrapper(LoRASidecarWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
params = self._get_lora_patched_parameters(
params={"weight": self._orig_module.weight, "bias": self._orig_module.bias}
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
lora_layers=self._lora_layers,
lora_weights=self._lora_weights,
)
return torch.nn.functional.conv1d(input, params["weight"], params["bias"])
class LoRAConv2dWrapper(LoRAModuleWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
params = self._get_lora_patched_parameters(
params={"weight": self._orig_module.weight, "bias": self._orig_module.bias}
)
return torch.nn.functional.conv2d(input, params["weight"], params["bias"])
return self.orig_module(input) + torch.nn.functional.conv2d(input, params["weight"], params.get("bias", None))
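The conv wrappers reuse the delta-accumulation path and add a convolution over the accumulated deltas to the untouched module's output. The bare conv1d/conv2d calls above use default stride, padding, dilation and groups, so this illustrative check also assumes a conv with default settings:

import torch

base = torch.nn.Conv2d(3, 8, kernel_size=3)  # defaults: stride=1, padding=0, groups=1
delta_weight = torch.randn_like(base.weight) * 0.01

x = torch.randn(1, 3, 16, 16)
residual_form = base(x) + torch.nn.functional.conv2d(x, delta_weight)
patched_form = torch.nn.functional.conv2d(x, base.weight + delta_weight, base.bias)
assert torch.allclose(residual_form, patched_form, atol=1e-5)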


@@ -10,7 +10,7 @@ from invokeai.backend.lora.lora_layer_wrappers import (
LoRAConv1dWrapper,
LoRAConv2dWrapper,
LoRALinearWrapper,
LoRAModuleWrapper,
LoRASidecarWrapper,
)
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
@@ -196,8 +196,8 @@ class LoRAPatcher:
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# Replace the original module with a LoRAModuleWrapper if it has not already been done.
if not isinstance(module, LoRAModuleWrapper):
# Replace the original module with a LoRASidecarWrapper if it has not already been done.
if not isinstance(module, LoRASidecarWrapper):
lora_wrapper_layer = LoRAPatcher._initialize_lora_wrapper_layer(module)
original_modules[module_key] = module
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
@@ -211,7 +211,7 @@ class LoRAPatcher:
# Move the LoRA layer to the same device/dtype as the orig module.
layer.to(device=orig_module.weight.device, dtype=orig_module.weight.dtype)
# Add the LoRA wrapper layer to the LoRAModuleWrapper.
# Add the LoRA wrapper layer to the LoRASidecarWrapper.
lora_wrapper_layer.add_lora_layer(layer, patch_weight)
@staticmethod
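For illustration, a condensed sketch of the wrap-or-reuse flow in the hunk above. The helper and its arguments are placeholders rather than part of LoRAPatcher; only LoRASidecarWrapper, LoRALinearWrapper, add_lora_layer, and the device/dtype move mirror this commit.

import torch

def _attach_lora(parent: torch.nn.Module, child_name: str, lora_layer, patch_weight: float) -> None:
    module = getattr(parent, child_name)
    if isinstance(module, LoRASidecarWrapper):
        # Already wrapped for an earlier LoRA: reuse the existing sidecar.
        wrapper, orig = module, module._orig_module
    else:
        wrapper, orig = LoRALinearWrapper(module, lora_layers=[], lora_weights=[]), module
        setattr(parent, child_name, wrapper)  # the sidecar replaces the child in its parent
    # Move the LoRA layer to the original module's device/dtype, as the patcher does.
    lora_layer.to(device=orig.weight.device, dtype=orig.weight.dtype)
    wrapper.add_lora_layer(lora_layer, patch_weight)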