Mirror of https://github.com/invoke-ai/InvokeAI.git
Delete old sidecar wrapper implementation. This functionality has moved into the custom layers.
invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py
@@ -1,56 +0,0 @@
import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch


class BaseSidecarWrapper(torch.nn.Module):
    """A base class for sidecar wrappers.

    A sidecar wrapper is a wrapper for an existing torch.nn.Module that applies a list of patches as 'sidecar'
    patches. I.e. it applies the sidecar patches during forward inference without modifying the original module.

    Sidecar wrappers are typically used over regular patches when:
    - The original module is quantized and so the weights can't be patched in the usual way.
    - The original module is on the CPU and modifying the weights would require backing up the original weights and
      doubling the CPU memory usage.
    """

    def __init__(
        self, orig_module: torch.nn.Module, patches_and_weights: list[tuple[BaseLayerPatch, float]] | None = None
    ):
        super().__init__()
        self._orig_module = orig_module
        self._patches_and_weights = [] if patches_and_weights is None else patches_and_weights

    @property
    def orig_module(self) -> torch.nn.Module:
        return self._orig_module

    def add_patch(self, patch: BaseLayerPatch, patch_weight: float):
        """Add a patch to the sidecar wrapper."""
        self._patches_and_weights.append((patch, patch_weight))

    def _aggregate_patch_parameters(
        self, patches_and_weights: list[tuple[BaseLayerPatch, float]]
    ) -> dict[str, torch.Tensor]:
        """Helper function that aggregates the parameters from all patches into a single dict."""
        params: dict[str, torch.Tensor] = {}

        for patch, patch_weight in patches_and_weights:
            # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
            # parameters, this might fail or return incorrect results.
            layer_params = patch.get_parameters(
                dict(self._orig_module.named_parameters(recurse=False)), weight=patch_weight
            )

            for param_name, param_weight in layer_params.items():
                if param_name not in params:
                    params[param_name] = param_weight
                else:
                    params[param_name] += param_weight

        return params

    def forward(self, *args, **kwargs):  # type: ignore
        raise NotImplementedError()
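The base class above aggregates per-parameter residuals from all attached patches, summing residuals that target the same parameter name. The following is a minimal sketch of that behavior, assuming only that a patch exposes get_parameters(orig_parameters, weight); FakeResidualPatch and PassthroughWrapper are hypothetical stand-ins used purely for illustration and are not part of the real patch hierarchy.

import torch

from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper


class FakeResidualPatch:
    """Hypothetical stand-in patch: it only implements get_parameters(), the one method the helper calls."""

    def __init__(self, residual: torch.Tensor):
        self._residual = residual

    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        return {"weight": self._residual * weight}


class PassthroughWrapper(BaseSidecarWrapper):
    """Trivial subclass used only to exercise the base-class helpers."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.orig_module(input)


wrapper = PassthroughWrapper(torch.nn.Linear(4, 4))
wrapper.add_patch(FakeResidualPatch(torch.full((4, 4), 0.1)), patch_weight=0.5)
wrapper.add_patch(FakeResidualPatch(torch.full((4, 4), 0.2)), patch_weight=1.0)

# Residuals for the same parameter name are summed: 0.1 * 0.5 + 0.2 * 1.0 = 0.25 per entry.
params = wrapper._aggregate_patch_parameters(wrapper._patches_and_weights)
assert torch.allclose(params["weight"], torch.full((4, 4), 0.25))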
invokeai/backend/patches/sidecar_wrappers/conv1d_sidecar_wrapper.py
@@ -1,11 +0,0 @@
import torch

from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper


class Conv1dSidecarWrapper(BaseSidecarWrapper):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
        return self.orig_module(input) + torch.nn.functional.conv1d(
            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
        )
invokeai/backend/patches/sidecar_wrappers/conv2d_sidecar_wrapper.py
@@ -1,11 +0,0 @@
import torch

from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper


class Conv2dSidecarWrapper(BaseSidecarWrapper):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
        # Apply the aggregated residual with conv2d to match the wrapped Conv2d module.
        return self.orig_module(input) + torch.nn.functional.conv2d(
            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
        )
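For both conv wrappers, adding a sidecar convolution of the input with the aggregated weight residual is equivalent, by linearity of convolution, to convolving with an in-place-patched weight. A quick sanity sketch with assumed shapes chosen only for illustration:

import torch
import torch.nn.functional as F

# Assumed shapes, chosen only for illustration.
x = torch.randn(1, 3, 16, 16)
w = torch.randn(8, 3, 3, 3)
dw = torch.randn(8, 3, 3, 3) * 0.01  # aggregated weight residual

sidecar = F.conv2d(x, w) + F.conv2d(x, dw)   # what the sidecar wrapper computes
patched = F.conv2d(x, w + dw)                # what patching the weight in place would compute
assert torch.allclose(sidecar, patched, atol=1e-4)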
invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py
@@ -1,24 +0,0 @@
import torch

from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper


class FluxRMSNormSidecarWrapper(BaseSidecarWrapper):
    """A sidecar wrapper for a FLUX RMSNorm layer.

    This wrapper is a special case. It is added specifically to enable FLUX structural control LoRAs, which overwrite
    the RMSNorm scale parameters.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Given the narrow focus of this wrapper, we only support a very particular patch configuration:
        assert len(self._patches_and_weights) == 1
        patch, _patch_weight = self._patches_and_weights[0]
        assert isinstance(patch, SetParameterLayer)
        assert patch.param_name == "scale"

        # Apply the patch.
        # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
        # be handled.
        return torch.nn.functional.rms_norm(input, patch.weight.shape, patch.weight, eps=1e-6)
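For reference, torch.nn.functional.rms_norm with a replacement scale (as used above) normalizes by the root mean square of the input and multiplies by the new scale. A small sketch with assumed shapes:

import torch

x = torch.randn(2, 8)
scale = torch.randn(8)  # stand-in for the patch's replacement "scale" parameter
eps = 1e-6

out = torch.nn.functional.rms_norm(x, scale.shape, scale, eps=eps)
manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * scale
assert torch.allclose(out, manual, atol=1e-5)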
invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py
@@ -1,66 +0,0 @@
import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper


class LinearSidecarWrapper(BaseSidecarWrapper):
    def _lora_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
        """An optimized implementation of the residual calculation for a Linear LoRALayer."""
        x = torch.nn.functional.linear(input, lora_layer.down)
        if lora_layer.mid is not None:
            x = torch.nn.functional.linear(x, lora_layer.mid)
        x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
        x *= lora_weight * lora_layer.scale()
        return x

    def _concatenated_lora_forward(
        self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
    ) -> torch.Tensor:
        """An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
        x_chunks: list[torch.Tensor] = []
        for lora_layer in concatenated_lora_layer.lora_layers:
            x_chunk = torch.nn.functional.linear(input, lora_layer.down)
            if lora_layer.mid is not None:
                x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
            x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
            x_chunk *= lora_weight * lora_layer.scale()
            x_chunks.append(x_chunk)

        # TODO(ryand): Generalize to support concat_axis != 0.
        assert concatenated_lora_layer.concat_axis == 0
        x = torch.cat(x_chunks, dim=-1)
        return x

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # First, apply the original linear layer.
        # NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
        # change the linear layer's in_features.
        orig_input = input
        input = orig_input[..., : self.orig_module.in_features]
        output = self.orig_module(input)

        # Then, apply layers for which we have optimized implementations.
        unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
        for patch, patch_weight in self._patches_and_weights:
            if isinstance(patch, FluxControlLoRALayer):
                # Note that we use the original input here, not the sliced input.
                output += self._lora_forward(orig_input, patch, patch_weight)
            elif isinstance(patch, LoRALayer):
                output += self._lora_forward(input, patch, patch_weight)
            elif isinstance(patch, ConcatenatedLoRALayer):
                output += self._concatenated_lora_forward(input, patch, patch_weight)
            else:
                unprocessed_patches_and_weights.append((patch, patch_weight))

        # Finally, apply any remaining patches.
        if len(unprocessed_patches_and_weights) > 0:
            aggregated_param_residuals = self._aggregate_patch_parameters(unprocessed_patches_and_weights)
            output += torch.nn.functional.linear(
                input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
            )

        return output
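_lora_forward computes the LoRA residual with two small matmuls instead of materializing the full-rank update. Under the usual LoRA layout, where down has shape (rank, in_features) and up has shape (out_features, rank) (an assumption about LoRALayer not stated in this file), the result matches applying the materialized update directly. A sketch of that equivalence:

import torch
import torch.nn.functional as F

# Assumed LoRA layout: down is (rank, in_features), up is (out_features, rank).
in_features, rank, out_features = 8, 2, 4
x = torch.randn(1, in_features)
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
lora_weight, scale = 0.75, 1.0

# Sidecar-style residual: two small matmuls, as in _lora_forward (no mid matrix, no bias).
sidecar = F.linear(F.linear(x, down), up) * lora_weight * scale

# Equivalent residual from the materialized low-rank update (what in-place patching would add to W).
materialized = F.linear(x, (up @ down) * lora_weight * scale)

assert torch.allclose(sidecar, materialized, atol=1e-5)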
@@ -1,20 +0,0 @@
import torch

from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper


def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.Module:
    if isinstance(orig_module, torch.nn.Linear):
        return LinearSidecarWrapper(orig_module)
    elif isinstance(orig_module, torch.nn.Conv1d):
        return Conv1dSidecarWrapper(orig_module)
    elif isinstance(orig_module, torch.nn.Conv2d):
        return Conv2dSidecarWrapper(orig_module)
    elif isinstance(orig_module, RMSNorm):
        return FluxRMSNormSidecarWrapper(orig_module)
    else:
        raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}")
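A usage sketch of the factory above as it worked before removal: wrap a module, attach patches, then call it like the original. With no patches attached, the wrapper reproduces the original module's output. The function is assumed to be in scope here, since its module path is not shown in this diff.

import torch

# Assumes wrap_module_with_sidecar_wrapper (defined above) is importable; its module path is not shown here.
linear = torch.nn.Linear(16, 16)
wrapped = wrap_module_with_sidecar_wrapper(linear)  # returns a LinearSidecarWrapper

x = torch.randn(2, 16)
# With no patches attached, the wrapper simply defers to the original module.
assert torch.allclose(wrapped(x), linear(x))

# Patches would then be attached via add_patch, e.g. (hypothetical LoRA layer):
# wrapped.add_patch(some_lora_layer, patch_weight=0.8)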