Finish consolidating LoRA sidecar wrapper implementations.

This commit is contained in:
Ryan Dick
2024-12-10 02:54:32 +00:00
parent 3d6b93efdd
commit 23f521dc7c
13 changed files with 36 additions and 348 deletions

View File

@@ -325,11 +325,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
LoRAPatcher.apply_lora_sidecar_patches(
LoRAPatcher.apply_lora_wrapper_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
)
)
else:

View File

@@ -3,9 +3,6 @@ from typing import Dict, Iterable, Optional, Tuple
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_layer_wrappers import (
LoRAConv1dWrapper,
LoRAConv2dWrapper,
@@ -13,11 +10,6 @@ from invokeai.backend.lora.lora_layer_wrappers import (
LoRASidecarWrapper,
)
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
ConcatenatedLoRALinearSidecarLayer,
)
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
@@ -154,7 +146,7 @@ class LoRAPatcher:
):
"""Apply one or more LoRA wrapper patches to a model within a context manager. Wrapper patches incur some
runtime overhead compared to normal LoRA patching, but they enable:
- LoRA layers to be applied to quantization format that are quatnized at the tensor level.
- LoRA layers to be applied to quantized models
- LoRA layers to be applied to CPU layers without needing to store a full copy of the original weights (i.e.
avoid doubling the memory requirements).
@@ -231,102 +223,6 @@ class LoRAPatcher:
# Add the LoRA wrapper layer to the LoRASidecarWrapper.
lora_wrapper_layer.add_lora_layer(layer, patch_weight)
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_sidecar_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
dtype: torch.dtype,
):
"""Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any
quantization format.
Args:
model (torch.nn.Module): The model to patch.
patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
different from their compute dtype.
"""
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LoRAPatcher._apply_lora_sidecar_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_modules=original_modules,
dtype=dtype,
)
yield
finally:
# Restore original modules.
# Note: This logic assumes no nested modules in original_modules.
for module_key, orig_module in original_modules.items():
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
parent_module = model.get_submodule(module_parent_key)
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
@staticmethod
def _apply_lora_sidecar_patch(
model: torch.nn.Module,
patch: LoRAModelRaw,
patch_weight: float,
prefix: str,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA sidecar patch to a model."""
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LoRAPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# Initialize the LoRA sidecar layer.
lora_sidecar_layer = LoRAPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)
# Replace the original module with a LoRASidecarModule if it has not already been done.
if module_key in original_modules:
# The module has already been patched with a LoRASidecarModule. Append to it.
assert isinstance(module, LoRASidecarModule)
lora_sidecar_module = module
else:
# The module has not yet been patched with a LoRASidecarModule. Create one.
lora_sidecar_module = LoRASidecarModule(module, [])
original_modules[module_key] = module
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
module_parent = model.get_submodule(module_parent_key)
LoRAPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
# Move the LoRA sidecar layer to the same device/dtype as the orig module.
# TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
lora_sidecar_layer.to(device=lora_sidecar_module.orig_module.weight.device, dtype=dtype)
# Add the LoRA sidecar layer to the LoRASidecarModule.
lora_sidecar_module.add_lora_layer(lora_sidecar_layer)
@staticmethod
def _split_parent_key(module_key: str) -> tuple[str, str]:
"""Split a module key into its parent key and module name.
@@ -356,21 +252,6 @@ class LoRAPatcher:
else:
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")
@staticmethod
def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
# TODO(ryand): Add support for more original layer types and LoRA layer types.
if isinstance(orig_layer, torch.nn.Linear) or (
isinstance(orig_layer, LoRASidecarModule) and isinstance(orig_layer.orig_module, torch.nn.Linear)
):
if isinstance(lora_layer, LoRALayer):
return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
elif isinstance(lora_layer, ConcatenatedLoRALayer):
return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
else:
raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
else:
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")
@staticmethod
def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
try:

View File

@@ -1,34 +0,0 @@
import torch
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
class ConcatenatedLoRALinearSidecarLayer(torch.nn.Module):
def __init__(
self,
concatenated_lora_layer: ConcatenatedLoRALayer,
weight: float,
):
super().__init__()
self._concatenated_lora_layer = concatenated_lora_layer
self._weight = weight
def forward(self, input: torch.Tensor) -> torch.Tensor:
x_chunks: list[torch.Tensor] = []
for lora_layer in self._concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= self._weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert self._concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self._concatenated_lora_layer.to(device=device, dtype=dtype)
return self

View File

@@ -1,27 +0,0 @@
import torch
from invokeai.backend.lora.layers.lora_layer import LoRALayer
class LoRALinearSidecarLayer(torch.nn.Module):
def __init__(
self,
lora_layer: LoRALayer,
weight: float,
):
super().__init__()
self._lora_layer = lora_layer
self._weight = weight
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.linear(x, self._lora_layer.down)
if self._lora_layer.mid is not None:
x = torch.nn.functional.linear(x, self._lora_layer.mid)
x = torch.nn.functional.linear(x, self._lora_layer.up, bias=self._lora_layer.bias)
x *= self._weight * self._lora_layer.scale()
return x
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self._lora_layer.to(device=device, dtype=dtype)
return self

View File

@@ -1,24 +0,0 @@
import torch
class LoRASidecarModule(torch.nn.Module):
"""A LoRA sidecar module that wraps an original module and adds LoRA layers to it."""
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
super().__init__()
self.orig_module = orig_module
self._lora_layers = lora_layers
def add_lora_layer(self, lora_layer: torch.nn.Module):
self._lora_layers.append(lora_layer)
def forward(self, input: torch.Tensor) -> torch.Tensor:
x = self.orig_module(input)
for lora_layer in self._lora_layers:
x += lora_layer(input)
return x
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self._orig_module.to(device=device, dtype=dtype)
for lora_layer in self._lora_layers:
lora_layer.to(device=device, dtype=dtype)

View File

@@ -1,49 +0,0 @@
import copy
import torch
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
ConcatenatedLoRALinearSidecarLayer,
)
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
def test_concatenated_lora_linear_sidecar_layer():
"""Test that a ConcatenatedLoRALinearSidecarLayer is equivalent to patching a linear layer with the ConcatenatedLoRA
layer.
"""
# Create a linear layer.
in_features = 5
sub_layer_out_features = [5, 10, 15]
linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))
# Create a ConcatenatedLoRA layer.
rank = 4
sub_layers: list[LoRALayer] = []
for out_features in sub_layer_out_features:
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias = torch.randn(out_features)
sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
# Patch the ConcatenatedLoRA layer into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += (
concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
)
linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()
# Create a ConcatenatedLoRALinearSidecarLayer.
concatenated_lora_linear_sidecar_layer = ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer, weight=1.0)
linear_with_sidecar = LoRASidecarModule(linear, [concatenated_lora_linear_sidecar_layer])
# Run the ConcatenatedLoRA-patched linear layer and the ConcatenatedLoRALinearSidecarLayer and assert they are
# equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_sidecar = linear_with_sidecar(input)
assert torch.allclose(output_patched, output_sidecar, atol=1e-6)

View File

@@ -1,38 +0,0 @@
import copy
import torch
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
@torch.no_grad()
def test_lora_linear_sidecar_layer():
"""Test that a LoRALinearSidecarLayer is equivalent to patching a linear layer with the LoRA layer."""
# Create a linear layer.
in_features = 10
out_features = 20
linear = torch.nn.Linear(in_features, out_features)
# Create a LoRA layer.
rank = 4
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias = torch.randn(out_features)
lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
# Patch the LoRA layer into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
# Create a LoRALinearSidecarLayer.
lora_linear_sidecar_layer = LoRALinearSidecarLayer(lora_layer, weight=1.0)
linear_with_sidecar = LoRASidecarModule(linear, [lora_linear_sidecar_layer])
# Run the LoRA-patched linear layer and the LoRALinearSidecarLayer and assert they are equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_sidecar = linear_with_sidecar(input)
assert torch.allclose(output_patched, output_sidecar, atol=1e-6)

View File

@@ -2,6 +2,7 @@ import copy
import torch
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_layer_wrappers import LoRALinearWrapper
@@ -33,3 +34,36 @@ def test_lora_linear_wrapper():
output_patched = linear_patched(input)
output_wrapped = lora_wrapped(input)
assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
def test_concatenated_lora_linear_wrapper():
# Create a linear layer.
in_features = 5
sub_layer_out_features = [5, 10, 15]
linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))
# Create a ConcatenatedLoRA layer.
rank = 4
sub_layers: list[LoRALayer] = []
for out_features in sub_layer_out_features:
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
bias = torch.randn(out_features)
sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
# Patch the ConcatenatedLoRA layer into the linear layer.
linear_patched = copy.deepcopy(linear)
linear_patched.weight.data += (
concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
)
linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()
# Create a LoRALinearWrapper.
lora_wrapped = LoRALinearWrapper(linear, [concatenated_lora_layer], [1.0])
# Run the ConcatenatedLoRA-patched linear layer and the LoRALinearWrapper and assert they are equal.
input = torch.randn(1, in_features)
output_patched = linear_patched(input)
output_wrapped = lora_wrapped(input)
assert torch.allclose(output_patched, output_wrapped, atol=1e-6)

View File

@@ -109,56 +109,6 @@ def test_apply_lora_patches_change_device():
torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False)
@pytest.mark.parametrize(
["device", "num_layers"],
[
("cpu", 1),
pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
("cpu", 2),
pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
],
)
def test_apply_lora_sidecar_patches(device: str, num_layers: int):
"""Test the basic behavior of ModelPatcher.apply_lora_sidecar_patches(...). Check that unpatching works correctly."""
dtype = torch.float16
linear_in_features = 4
linear_out_features = 8
lora_rank = 2
model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=dtype)
# Initialize num_layers LoRA models with weights of 0.5.
lora_weight = 0.5
lora_models: list[tuple[LoRAModelRaw, float]] = []
for _ in range(num_layers):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
)
}
lora = LoRAModelRaw(lora_layers)
lora_models.append((lora, lora_weight))
# Run inference before patching the model.
input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
output_before_patch = model(input)
# Patch the model and run inference during the patch.
with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_during_patch = model(input)
# Run inference after unpatching.
output_after_patch = model(input)
# Check that the output before patching is different from the output during patching.
assert not torch.allclose(output_before_patch, output_during_patch)
# Check that the output before patching is the same as the output after patching.
assert torch.allclose(output_before_patch, output_after_patch)
@pytest.mark.parametrize(
["device", "num_layers"],
[
@@ -239,13 +189,9 @@ def test_all_patching_methods_produce_same_output(num_layers: int):
with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
output_lora_patches = model(input)
with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_lora_sidecar_patches = model(input)
with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix=""):
output_lora_wrapper_patches = model(input)
# Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
# differences are tolerable and expected due to the difference between sidecar vs. patching.
assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)
assert torch.allclose(output_lora_patches, output_lora_wrapper_patches, atol=1e-5)