From 606d58d7db28c4d58162454afc6629fde9d2f27d Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 13 Dec 2024 21:23:43 +0000 Subject: [PATCH] Add sidecar wrapper for FLUX RMSNorm layers to support SetParameterLayers used by FLUX structural control LoRAs. --- .../patches/layers/set_parameter_layer.py | 12 +++++----- .../flux_rms_norm_sidecar_wrapper.py | 24 +++++++++++++++++++ .../backend/patches/sidecar_wrappers/utils.py | 4 ++++ .../layers/test_set_parameter_layer.py | 4 ++-- .../test_flux_rms_norm_sidecar_wrapper.py | 23 ++++++++++++++++++ 5 files changed, 59 insertions(+), 8 deletions(-) create mode 100644 invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py create mode 100644 tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py diff --git a/invokeai/backend/patches/layers/set_parameter_layer.py b/invokeai/backend/patches/layers/set_parameter_layer.py index 97a54bcec2..c55cd68d20 100644 --- a/invokeai/backend/patches/layers/set_parameter_layer.py +++ b/invokeai/backend/patches/layers/set_parameter_layer.py @@ -11,15 +11,15 @@ class SetParameterLayer(BaseLayerPatch): def __init__(self, param_name: str, weight: torch.Tensor): super().__init__() - self._weight = weight - self._param_name = param_name + self.weight = weight + self.param_name = param_name def get_parameters(self, orig_module: torch.nn.Module) -> dict[str, torch.Tensor]: - diff = self._weight - orig_module.get_parameter(self._param_name) - return {self._param_name: diff} + diff = self.weight - orig_module.get_parameter(self.param_name) + return {self.param_name: diff} def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None): - self._weight = self._weight.to(device=device, dtype=dtype) + self.weight = self.weight.to(device=device, dtype=dtype) def calc_size(self) -> int: - return calc_tensor_size(self._weight) + return calc_tensor_size(self.weight) diff --git a/invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py new file mode 100644 index 0000000000..34c3b9b369 --- /dev/null +++ b/invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py @@ -0,0 +1,24 @@ +import torch + +from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer +from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper + + +class FluxRMSNormSidecarWrapper(BaseSidecarWrapper): + """A sidecar wrapper for a FLUX RMSNorm layer. + + This wrapper is a special case. It is added specifically to enable FLUX structural control LoRAs, which overwrite + the RMSNorm scale parameters. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # Given the narrow focus of this wrapper, we only support a very particular patch configuration: + assert len(self._patches_and_weights) == 1 + patch, _patch_weight = self._patches_and_weights[0] + assert isinstance(patch, SetParameterLayer) + assert patch.param_name == "scale" + + # Apply the patch. + # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should + # be handled. + return torch.nn.functional.rms_norm(input, patch.weight.shape, patch.weight, eps=1e-6) diff --git a/invokeai/backend/patches/sidecar_wrappers/utils.py b/invokeai/backend/patches/sidecar_wrappers/utils.py index 188216efba..6a71213b09 100644 --- a/invokeai/backend/patches/sidecar_wrappers/utils.py +++ b/invokeai/backend/patches/sidecar_wrappers/utils.py @@ -1,7 +1,9 @@ import torch +from invokeai.backend.flux.modules.layers import RMSNorm from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper +from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper @@ -12,5 +14,7 @@ def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.M return Conv1dSidecarWrapper(orig_module) elif isinstance(orig_module, torch.nn.Conv2d): return Conv2dSidecarWrapper(orig_module) + elif isinstance(orig_module, RMSNorm): + return FluxRMSNormSidecarWrapper(orig_module) else: raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}") diff --git a/tests/backend/patches/layers/test_set_parameter_layer.py b/tests/backend/patches/layers/test_set_parameter_layer.py index 075f5a62ed..1e0a3f46f1 100644 --- a/tests/backend/patches/layers/test_set_parameter_layer.py +++ b/tests/backend/patches/layers/test_set_parameter_layer.py @@ -32,11 +32,11 @@ def test_set_parameter_layer_to(device: str): layer = SetParameterLayer(param_name="weight", weight=target_weight) # SetParameterLayer should be initialized on CPU. - assert layer._weight.device.type == "cpu" # type: ignore + assert layer.weight.device.type == "cpu" # type: ignore # Move to device. layer.to(device=torch.device(device)) - assert layer._weight.device.type == device # type: ignore + assert layer.weight.device.type == device # type: ignore def test_set_parameter_layer_calc_size(): diff --git a/tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py b/tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py new file mode 100644 index 0000000000..ee0dce554f --- /dev/null +++ b/tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py @@ -0,0 +1,23 @@ +import torch + +from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer +from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper + + +def test_flux_rms_norm_sidecar_wrapper(): + # Create a RMSNorm layer. + dim = 10 + rms_norm = torch.nn.RMSNorm(dim) + + # Create a SetParameterLayer. + new_scale = torch.randn(dim) + set_parameter_layer = SetParameterLayer("scale", new_scale) + + # Create a FluxRMSNormSidecarWrapper. + rms_norm_wrapped = FluxRMSNormSidecarWrapper(rms_norm, [(set_parameter_layer, 1.0)]) + + # Run the FluxRMSNormSidecarWrapper. + input = torch.randn(1, dim) + expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6) + output_wrapped = rms_norm_wrapped(input) + assert torch.allclose(output_wrapped, expected_output, atol=1e-6)