From 6fd9b0a274aa55ea0a319be7bbc6caa09cabd4d3 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Sun, 29 Dec 2024 17:33:08 +0000
Subject: [PATCH] Delete old sidecar wrapper implementation.

This functionality has moved into the custom layers.
---
 .../patches/sidecar_wrappers/__init__.py      |   0
 .../sidecar_wrappers/base_sidecar_wrapper.py  |  56 ------
 .../conv1d_sidecar_wrapper.py                 |  11 --
 .../conv2d_sidecar_wrapper.py                 |  11 --
 .../flux_rms_norm_sidecar_wrapper.py          |  24 ---
 .../linear_sidecar_wrapper.py                 |  66 ------
 .../backend/patches/sidecar_wrappers/utils.py |  20 --
 .../test_flux_rms_norm_sidecar_wrapper.py     |  23 ---
 .../test_linear_sidecar_wrapper.py            | 182 ------------------
 9 files changed, 393 deletions(-)
 delete mode 100644 invokeai/backend/patches/sidecar_wrappers/__init__.py
 delete mode 100644 invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py
 delete mode 100644 invokeai/backend/patches/sidecar_wrappers/conv1d_sidecar_wrapper.py
 delete mode 100644 invokeai/backend/patches/sidecar_wrappers/conv2d_sidecar_wrapper.py
 delete mode 100644 invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py
 delete mode 100644 invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py
 delete mode 100644 invokeai/backend/patches/sidecar_wrappers/utils.py
 delete mode 100644 tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py
 delete mode 100644 tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py

diff --git a/invokeai/backend/patches/sidecar_wrappers/__init__.py b/invokeai/backend/patches/sidecar_wrappers/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py
deleted file mode 100644
index 46d69bbe91..0000000000
--- a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import torch
-
-from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
-
-
-class BaseSidecarWrapper(torch.nn.Module):
-    """A base class for sidecar wrappers.
-
-    A sidecar wrapper wraps an existing torch.nn.Module and applies a list of
-    patches as 'sidecar' patches, i.e. the patches are applied during forward inference without modifying
-    the original module.
-
-    Sidecar wrappers are typically used instead of regular patches when:
-    - The original module is quantized, so its weights can't be patched in the usual way.
-    - The original module is on the CPU, where patching would require backing up the original weights and
-      doubling the CPU memory usage.
-    """
-
-    def __init__(
-        self, orig_module: torch.nn.Module, patches_and_weights: list[tuple[BaseLayerPatch, float]] | None = None
-    ):
-        super().__init__()
-        self._orig_module = orig_module
-        self._patches_and_weights = [] if patches_and_weights is None else patches_and_weights
-
-    @property
-    def orig_module(self) -> torch.nn.Module:
-        return self._orig_module
-
-    def add_patch(self, patch: BaseLayerPatch, patch_weight: float):
-        """Add a patch to the sidecar wrapper."""
-        self._patches_and_weights.append((patch, patch_weight))
-
-    def _aggregate_patch_parameters(
-        self, patches_and_weights: list[tuple[BaseLayerPatch, float]]
-    ) -> dict[str, torch.Tensor]:
-        """Helper function that aggregates the parameters from all patches into a single dict."""
-        params: dict[str, torch.Tensor] = {}
-
-        for patch, patch_weight in patches_and_weights:
-            # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
-            # parameters, this might fail or return incorrect results.
-            layer_params = patch.get_parameters(
-                dict(self._orig_module.named_parameters(recurse=False)), weight=patch_weight
-            )
-
-            for param_name, param_weight in layer_params.items():
-                if param_name not in params:
-                    params[param_name] = param_weight
-                else:
-                    params[param_name] += param_weight
-
-        return params
-
-    def forward(self, *args, **kwargs):  # type: ignore
-        raise NotImplementedError()
diff --git a/invokeai/backend/patches/sidecar_wrappers/conv1d_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/conv1d_sidecar_wrapper.py
deleted file mode 100644
index 7877aae8c7..0000000000
--- a/invokeai/backend/patches/sidecar_wrappers/conv1d_sidecar_wrapper.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import torch
-
-from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
-
-
-class Conv1dSidecarWrapper(BaseSidecarWrapper):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
-        return self.orig_module(input) + torch.nn.functional.conv1d(
-            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
-        )
diff --git a/invokeai/backend/patches/sidecar_wrappers/conv2d_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/conv2d_sidecar_wrapper.py
deleted file mode 100644
index d9bb713534..0000000000
--- a/invokeai/backend/patches/sidecar_wrappers/conv2d_sidecar_wrapper.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import torch
-
-from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
-
-
-class Conv2dSidecarWrapper(BaseSidecarWrapper):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
-        return self.orig_module(input) + torch.nn.functional.conv2d(
-            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
-        )
diff --git a/invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py
deleted file mode 100644
index 34c3b9b369..0000000000
--- a/invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import torch
-
-from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
-from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
-
-
-class FluxRMSNormSidecarWrapper(BaseSidecarWrapper):
-    """A sidecar wrapper for a FLUX RMSNorm layer.
-
-    This wrapper is a special case. It is added specifically to enable FLUX structural control LoRAs, which overwrite
-    the RMSNorm scale parameters.
-    """
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # Given the narrow focus of this wrapper, we only support a very particular patch configuration:
-        assert len(self._patches_and_weights) == 1
-        patch, _patch_weight = self._patches_and_weights[0]
-        assert isinstance(patch, SetParameterLayer)
-        assert patch.param_name == "scale"
-
-        # Apply the patch.
-        # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
-        # be handled.
-        return torch.nn.functional.rms_norm(input, patch.weight.shape, patch.weight, eps=1e-6)
diff --git a/invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py
deleted file mode 100644
index 98775b9feb..0000000000
--- a/invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import torch
-
-from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
-from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
-from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
-from invokeai.backend.patches.layers.lora_layer import LoRALayer
-from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
-
-
-class LinearSidecarWrapper(BaseSidecarWrapper):
-    def _lora_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
-        """An optimized implementation of the residual calculation for a Linear LoRALayer."""
-        x = torch.nn.functional.linear(input, lora_layer.down)
-        if lora_layer.mid is not None:
-            x = torch.nn.functional.linear(x, lora_layer.mid)
-        x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
-        x *= lora_weight * lora_layer.scale()
-        return x
-
-    def _concatenated_lora_forward(
-        self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
-    ) -> torch.Tensor:
-        """An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
-        x_chunks: list[torch.Tensor] = []
-        for lora_layer in concatenated_lora_layer.lora_layers:
-            x_chunk = torch.nn.functional.linear(input, lora_layer.down)
-            if lora_layer.mid is not None:
-                x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
-            x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
-            x_chunk *= lora_weight * lora_layer.scale()
-            x_chunks.append(x_chunk)
-
-        # TODO(ryand): Generalize to support concat_axis != 0.
-        assert concatenated_lora_layer.concat_axis == 0
-        x = torch.cat(x_chunks, dim=-1)
-        return x
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # First, apply the original linear layer.
-        # NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
-        # change the linear layer's in_features.
-        orig_input = input
-        input = orig_input[..., : self.orig_module.in_features]
-        output = self.orig_module(input)
-
-        # Then, apply layers for which we have optimized implementations.
-        unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
-        for patch, patch_weight in self._patches_and_weights:
-            if isinstance(patch, FluxControlLoRALayer):
-                # Note that we use the original input here, not the sliced input.
-                output += self._lora_forward(orig_input, patch, patch_weight)
-            elif isinstance(patch, LoRALayer):
-                output += self._lora_forward(input, patch, patch_weight)
-            elif isinstance(patch, ConcatenatedLoRALayer):
-                output += self._concatenated_lora_forward(input, patch, patch_weight)
-            else:
-                unprocessed_patches_and_weights.append((patch, patch_weight))
-
-        # Finally, apply any remaining patches.
-        if len(unprocessed_patches_and_weights) > 0:
-            aggregated_param_residuals = self._aggregate_patch_parameters(unprocessed_patches_and_weights)
-            output += torch.nn.functional.linear(
-                input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
-            )
-
-        return output
diff --git a/invokeai/backend/patches/sidecar_wrappers/utils.py b/invokeai/backend/patches/sidecar_wrappers/utils.py
deleted file mode 100644
index 6a71213b09..0000000000
--- a/invokeai/backend/patches/sidecar_wrappers/utils.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import torch
-
-from invokeai.backend.flux.modules.layers import RMSNorm
-from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper
-from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper
-from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
-from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
-
-
-def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.Module:
-    if isinstance(orig_module, torch.nn.Linear):
-        return LinearSidecarWrapper(orig_module)
-    elif isinstance(orig_module, torch.nn.Conv1d):
-        return Conv1dSidecarWrapper(orig_module)
-    elif isinstance(orig_module, torch.nn.Conv2d):
-        return Conv2dSidecarWrapper(orig_module)
-    elif isinstance(orig_module, RMSNorm):
-        return FluxRMSNormSidecarWrapper(orig_module)
-    else:
-        raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}")
diff --git a/tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py b/tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py
deleted file mode 100644
index ee0dce554f..0000000000
--- a/tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import torch
-
-from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
-from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
-
-
-def test_flux_rms_norm_sidecar_wrapper():
-    # Create a RMSNorm layer.
-    dim = 10
-    rms_norm = torch.nn.RMSNorm(dim)
-
-    # Create a SetParameterLayer.
-    new_scale = torch.randn(dim)
-    set_parameter_layer = SetParameterLayer("scale", new_scale)
-
-    # Create a FluxRMSNormSidecarWrapper.
-    rms_norm_wrapped = FluxRMSNormSidecarWrapper(rms_norm, [(set_parameter_layer, 1.0)])
-
-    # Run the FluxRMSNormSidecarWrapper.
-    input = torch.randn(1, dim)
-    expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6)
-    output_wrapped = rms_norm_wrapped(input)
-    assert torch.allclose(output_wrapped, expected_output, atol=1e-6)
diff --git a/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py b/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py
deleted file mode 100644
index 607f364dcd..0000000000
--- a/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py
+++ /dev/null
@@ -1,182 +0,0 @@
-import copy
-
-import torch
-
-from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
-from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
-from invokeai.backend.patches.layers.full_layer import FullLayer
-from invokeai.backend.patches.layers.lora_layer import LoRALayer
-from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
-from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
-
-
-@torch.no_grad()
-def test_linear_sidecar_wrapper_lora():
-    # Create a linear layer.
-    in_features = 10
-    out_features = 20
-    linear = torch.nn.Linear(in_features, out_features)
-
-    # Create a LoRA layer.
-    rank = 4
-    down = torch.randn(rank, in_features)
-    up = torch.randn(out_features, rank)
-    bias = torch.randn(out_features)
-    lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
-
-    # Patch the LoRA layer into the linear layer.
-    linear_patched = copy.deepcopy(linear)
-    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
-    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
-
-    # Create a LinearSidecarWrapper.
-    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])
-
-    # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
-    input = torch.randn(1, in_features)
-    output_patched = linear_patched(input)
-    output_wrapped = lora_wrapped(input)
-    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
-
-
-@torch.no_grad()
-def test_linear_sidecar_wrapper_multiple_loras():
-    # Create a linear layer.
-    in_features = 10
-    out_features = 20
-    linear = torch.nn.Linear(in_features, out_features)
-
-    # Create two LoRA layers.
-    rank = 4
-    lora_layer = LoRALayer(
-        up=torch.randn(out_features, rank),
-        mid=None,
-        down=torch.randn(rank, in_features),
-        alpha=1.0,
-        bias=torch.randn(out_features),
-    )
-    lora_layer_2 = LoRALayer(
-        up=torch.randn(out_features, rank),
-        mid=None,
-        down=torch.randn(rank, in_features),
-        alpha=1.0,
-        bias=torch.randn(out_features),
-    )
-    # We use different weights for the two LoRA layers to verify that per-layer patch weights are applied correctly.
-    lora_weight = 1.0
-    lora_weight_2 = 0.5
-
-    # Patch the LoRA layers into the linear layer.
-    linear_patched = copy.deepcopy(linear)
-    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * (lora_layer.scale() * lora_weight)
-    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * (lora_layer.scale() * lora_weight)
-    linear_patched.weight.data += lora_layer_2.get_weight(linear_patched.weight) * (
-        lora_layer_2.scale() * lora_weight_2
-    )
-    linear_patched.bias.data += lora_layer_2.get_bias(linear_patched.bias) * (lora_layer_2.scale() * lora_weight_2)
-
-    # Create a LinearSidecarWrapper.
-    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, lora_weight), (lora_layer_2, lora_weight_2)])
-
-    # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
-    input = torch.randn(1, in_features)
-    output_patched = linear_patched(input)
-    output_wrapped = lora_wrapped(input)
-    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
-
-
-@torch.no_grad()
-def test_linear_sidecar_wrapper_concatenated_lora():
-    # Create a linear layer.
-    in_features = 5
-    sub_layer_out_features = [5, 10, 15]
-    linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))
-
-    # Create a ConcatenatedLoRA layer.
-    rank = 4
-    sub_layers: list[LoRALayer] = []
-    for out_features in sub_layer_out_features:
-        down = torch.randn(rank, in_features)
-        up = torch.randn(out_features, rank)
-        bias = torch.randn(out_features)
-        sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
-    concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
-
-    # Patch the ConcatenatedLoRA layer into the linear layer.
-    linear_patched = copy.deepcopy(linear)
-    linear_patched.weight.data += (
-        concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
-    )
-    linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()
-
-    # Create a LinearSidecarWrapper.
-    lora_wrapped = LinearSidecarWrapper(linear, [(concatenated_lora_layer, 1.0)])
-
-    # Run the ConcatenatedLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
-    input = torch.randn(1, in_features)
-    output_patched = linear_patched(input)
-    output_wrapped = lora_wrapped(input)
-    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
-
-
-def test_linear_sidecar_wrapper_full_layer():
-    # Create a linear layer.
-    in_features = 10
-    out_features = 20
-    linear = torch.nn.Linear(in_features, out_features)
-
-    # Create a FullLayer.
-    full_layer = FullLayer(weight=torch.randn(out_features, in_features), bias=torch.randn(out_features))
-
-    # Patch the FullLayer into the linear layer.
-    linear_patched = copy.deepcopy(linear)
-    linear_patched.weight.data += full_layer.get_weight(linear_patched.weight)
-    linear_patched.bias.data += full_layer.get_bias(linear_patched.bias)
-
-    # Create a LinearSidecarWrapper.
-    full_wrapped = LinearSidecarWrapper(linear, [(full_layer, 1.0)])
-
-    # Run the FullLayer-patched linear layer and the LinearSidecarWrapper and assert they are equal.
-    input = torch.randn(1, in_features)
-    output_patched = linear_patched(input)
-    output_wrapped = full_wrapped(input)
-    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
-
-
-def test_linear_sidecar_wrapper_flux_control_lora_layer():
-    # Create a linear layer.
-    orig_in_features = 10
-    out_features = 40
-    linear = torch.nn.Linear(orig_in_features, out_features)
-
-    # Create a FluxControlLoRALayer.
-    patched_in_features = 20
-    rank = 4
-    lora_layer = FluxControlLoRALayer(
-        up=torch.randn(out_features, rank),
-        mid=None,
-        down=torch.randn(rank, patched_in_features),
-        alpha=1.0,
-        bias=torch.randn(out_features),
-    )
-
-    # Patch the FluxControlLoRALayer into the linear layer.
-    linear_patched = copy.deepcopy(linear)
-    # Expand the existing weight.
-    expanded_weight = pad_with_zeros(linear_patched.weight, torch.Size([out_features, patched_in_features]))
-    linear_patched.weight = torch.nn.Parameter(expanded_weight, requires_grad=linear_patched.weight.requires_grad)
-    # Expand the existing bias.
-    expanded_bias = pad_with_zeros(linear_patched.bias, torch.Size([out_features]))
-    linear_patched.bias = torch.nn.Parameter(expanded_bias, requires_grad=linear_patched.bias.requires_grad)
-    # Add the residuals.
-    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
-    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
-
-    # Create a LinearSidecarWrapper.
-    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])
-
-    # Run the FluxControlLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
-    input = torch.randn(1, patched_in_features)
-    output_patched = linear_patched(input)
-    output_wrapped = lora_wrapped(input)
-    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
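--
Reviewer note: for anyone auditing this deletion, below is a minimal, self-contained sketch of the sidecar computation that the removed LinearSidecarWrapper performed, written against plain PyTorch only. The variable names and the alpha / rank scaling convention are illustrative assumptions (LoRALayer.scale() is defined elsewhere in the codebase); this is not the new custom-layer API that the commit message says the functionality moved into.

import torch

torch.manual_seed(0)

# Original module: left untouched by sidecar patching.
in_features, out_features, rank = 10, 20, 4
orig = torch.nn.Linear(in_features, out_features)

# A LoRA patch residual: weight delta = up @ down, scaled by alpha / rank.
# (alpha / rank is the common LoRA convention; assumed here.)
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
alpha = 1.0
patch_weight = 0.5
scale = alpha / rank

x = torch.randn(1, in_features)

# Sidecar application (what LinearSidecarWrapper._lora_forward did, minus the
# optional 'mid' matrix and the bias residual): run the original module
# unmodified, then add the low-rank residual on top of its output.
residual = torch.nn.functional.linear(torch.nn.functional.linear(x, down), up)
sidecar_out = orig(x) + residual * (patch_weight * scale)

# Equivalent direct patch, which requires mutating the weights -- exactly the
# step that the sidecar approach avoided for quantized or CPU-offloaded modules.
patched = torch.nn.Linear(in_features, out_features)
patched.load_state_dict(orig.state_dict())
patched.weight.data += (up @ down) * (patch_weight * scale)

assert torch.allclose(sidecar_out, patched(x), atol=1e-5)

The removed wrappers generalized this same pattern to Conv1d/Conv2d (residual applied via the functional conv ops) and to the FLUX RMSNorm special case; per the commit message, that logic now lives in the custom layer implementations.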