Add sidecar wrapper for FLUX RMSNorm layers to support SetParameterLayers used by FLUX structural control LoRAs.

Ryan Dick
2024-12-13 21:23:43 +00:00
parent c76a448846
commit 606d58d7db
5 changed files with 59 additions and 8 deletions


@@ -11,15 +11,15 @@ class SetParameterLayer(BaseLayerPatch):
     def __init__(self, param_name: str, weight: torch.Tensor):
         super().__init__()
-        self._weight = weight
-        self._param_name = param_name
+        self.weight = weight
+        self.param_name = param_name
 
     def get_parameters(self, orig_module: torch.nn.Module) -> dict[str, torch.Tensor]:
-        diff = self._weight - orig_module.get_parameter(self._param_name)
-        return {self._param_name: diff}
+        diff = self.weight - orig_module.get_parameter(self.param_name)
+        return {self.param_name: diff}
 
     def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
-        self._weight = self._weight.to(device=device, dtype=dtype)
+        self.weight = self.weight.to(device=device, dtype=dtype)
 
     def calc_size(self) -> int:
-        return calc_tensor_size(self._weight)
+        return calc_tensor_size(self.weight)
 
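The rename from `self._weight`/`self._param_name` to public attributes lets the new sidecar wrapper read the patch's target tensor directly. A minimal sketch of the `get_parameters()` contract, runnable against this diff's `SetParameterLayer` (the `torch.nn.RMSNorm` here is just a stand-in module with a "weight" parameter, and assumes torch >= 2.4):

import torch

from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer

# Stand-in module with a "weight" parameter.
module = torch.nn.RMSNorm(10)
target = torch.randn(10)
layer = SetParameterLayer(param_name="weight", weight=target)

# get_parameters() returns a delta: adding it to the module's current
# parameter reproduces the stored target value exactly.
delta = layer.get_parameters(module)["weight"]
assert torch.allclose(module.weight + delta, target)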


@@ -0,0 +1,24 @@
+import torch
+
+from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
+from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
+
+
+class FluxRMSNormSidecarWrapper(BaseSidecarWrapper):
+    """A sidecar wrapper for a FLUX RMSNorm layer.
+
+    This wrapper is a special case. It is added specifically to enable FLUX structural control LoRAs, which overwrite
+    the RMSNorm scale parameters.
+    """
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # Given the narrow focus of this wrapper, we only support a very particular patch configuration:
+        assert len(self._patches_and_weights) == 1
+        patch, _patch_weight = self._patches_and_weights[0]
+        assert isinstance(patch, SetParameterLayer)
+        assert patch.param_name == "scale"
+
+        # Apply the patch.
+        # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
+        # be handled.
+        return torch.nn.functional.rms_norm(input, patch.weight.shape, patch.weight, eps=1e-6)
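For reference, `torch.nn.functional.rms_norm(x, shape, w, eps)` computes `x / sqrt(mean(x**2) + eps) * w`, so the wrapper's forward is equivalent to running the original RMSNorm with its scale replaced by the patch's tensor. A quick sanity-check sketch (plain PyTorch, no InvokeAI imports; `F.rms_norm` requires torch >= 2.4):

import torch

x = torch.randn(2, 10)
scale = torch.randn(10)
eps = 1e-6

# Manual RMSNorm: normalize by the root-mean-square over the last dim, then scale.
manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * scale
fused = torch.nn.functional.rms_norm(x, scale.shape, scale, eps=eps)
assert torch.allclose(fused, manual, atol=1e-6)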


@@ -1,7 +1,9 @@
 import torch
 
+from invokeai.backend.flux.modules.layers import RMSNorm
 from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper
 from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper
+from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
 from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
 
 
@@ -12,5 +14,7 @@ def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.Module:
         return Conv1dSidecarWrapper(orig_module)
     elif isinstance(orig_module, torch.nn.Conv2d):
         return Conv2dSidecarWrapper(orig_module)
+    elif isinstance(orig_module, RMSNorm):
+        return FluxRMSNormSidecarWrapper(orig_module)
     else:
         raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}")
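With the registry branch in place, wrapping a FLUX `RMSNorm` dispatches to the new wrapper. A usage sketch; note that the import path of `wrap_module_with_sidecar_wrapper` is not shown in this diff, so the `utils` module below is an assumption:

import torch

from invokeai.backend.flux.modules.layers import RMSNorm
# Assumed module path; the diff does not name the file that defines this helper.
from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper

# Dispatch is by module type, so FLUX's RMSNorm gets the special-case wrapper.
wrapped = wrap_module_with_sidecar_wrapper(RMSNorm(10))
assert type(wrapped).__name__ == "FluxRMSNormSidecarWrapper"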


@@ -32,11 +32,11 @@ def test_set_parameter_layer_to(device: str):
     layer = SetParameterLayer(param_name="weight", weight=target_weight)
 
     # SetParameterLayer should be initialized on CPU.
-    assert layer._weight.device.type == "cpu"  # type: ignore
+    assert layer.weight.device.type == "cpu"  # type: ignore
 
     # Move to device.
     layer.to(device=torch.device(device))
-    assert layer._weight.device.type == device  # type: ignore
+    assert layer.weight.device.type == device  # type: ignore
 
 
 def test_set_parameter_layer_calc_size():


@@ -0,0 +1,23 @@
+import torch
+
+from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
+from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
+
+
+def test_flux_rms_norm_sidecar_wrapper():
+    # Create a RMSNorm layer.
+    dim = 10
+    rms_norm = torch.nn.RMSNorm(dim)
+
+    # Create a SetParameterLayer.
+    new_scale = torch.randn(dim)
+    set_parameter_layer = SetParameterLayer("scale", new_scale)
+
+    # Create a FluxRMSNormSidecarWrapper.
+    rms_norm_wrapped = FluxRMSNormSidecarWrapper(rms_norm, [(set_parameter_layer, 1.0)])
+
+    # Run the FluxRMSNormSidecarWrapper.
+    input = torch.randn(1, dim)
+    expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6)
+    output_wrapped = rms_norm_wrapped(input)
+    assert torch.allclose(output_wrapped, expected_output, atol=1e-6)
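The test pins the wrapper to the functional form. Equivalently, its output should match an RMSNorm module whose scale was overwritten in place, which is what a structural control LoRA's SetParameterLayer achieves. A minimal cross-check sketch (uses `torch.nn.RMSNorm`, available in torch >= 2.4, as a stand-in for FLUX's own RMSNorm):

import torch

dim = 10
new_scale = torch.randn(dim)
norm = torch.nn.RMSNorm(dim, eps=1e-6)
with torch.no_grad():
    norm.weight.copy_(new_scale)  # overwrite the scale in place

x = torch.randn(1, dim)
assert torch.allclose(
    norm(x),
    torch.nn.functional.rms_norm(x, new_scale.shape, new_scale, eps=1e-6),
    atol=1e-6,
)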