mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add sidecar wrapper for FLUX RMSNorm layers to support SetParameterLayers used by FLUX structural control LoRAs.
This commit is contained in:
@@ -11,15 +11,15 @@ class SetParameterLayer(BaseLayerPatch):
|
||||
|
||||
def __init__(self, param_name: str, weight: torch.Tensor):
|
||||
super().__init__()
|
||||
self._weight = weight
|
||||
self._param_name = param_name
|
||||
self.weight = weight
|
||||
self.param_name = param_name
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module) -> dict[str, torch.Tensor]:
|
||||
diff = self._weight - orig_module.get_parameter(self._param_name)
|
||||
return {self._param_name: diff}
|
||||
diff = self.weight - orig_module.get_parameter(self.param_name)
|
||||
return {self.param_name: diff}
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self._weight = self._weight.to(device=device, dtype=dtype)
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
return calc_tensor_size(self._weight)
|
||||
return calc_tensor_size(self.weight)
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
|
||||
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
|
||||
|
||||
|
||||
class FluxRMSNormSidecarWrapper(BaseSidecarWrapper):
|
||||
"""A sidecar wrapper for a FLUX RMSNorm layer.
|
||||
|
||||
This wrapper is a special case. It is added specifically to enable FLUX structural control LoRAs, which overwrite
|
||||
the RMSNorm scale parameters.
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# Given the narrow focus of this wrapper, we only support a very particular patch configuration:
|
||||
assert len(self._patches_and_weights) == 1
|
||||
patch, _patch_weight = self._patches_and_weights[0]
|
||||
assert isinstance(patch, SetParameterLayer)
|
||||
assert patch.param_name == "scale"
|
||||
|
||||
# Apply the patch.
|
||||
# NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
|
||||
# be handled.
|
||||
return torch.nn.functional.rms_norm(input, patch.weight.shape, patch.weight, eps=1e-6)
|
||||
@@ -1,7 +1,9 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.modules.layers import RMSNorm
|
||||
from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper
|
||||
from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper
|
||||
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
|
||||
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
|
||||
|
||||
|
||||
@@ -12,5 +14,7 @@ def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.M
|
||||
return Conv1dSidecarWrapper(orig_module)
|
||||
elif isinstance(orig_module, torch.nn.Conv2d):
|
||||
return Conv2dSidecarWrapper(orig_module)
|
||||
elif isinstance(orig_module, RMSNorm):
|
||||
return FluxRMSNormSidecarWrapper(orig_module)
|
||||
else:
|
||||
raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}")
|
||||
|
||||
@@ -32,11 +32,11 @@ def test_set_parameter_layer_to(device: str):
|
||||
layer = SetParameterLayer(param_name="weight", weight=target_weight)
|
||||
|
||||
# SetParameterLayer should be initialized on CPU.
|
||||
assert layer._weight.device.type == "cpu" # type: ignore
|
||||
assert layer.weight.device.type == "cpu" # type: ignore
|
||||
|
||||
# Move to device.
|
||||
layer.to(device=torch.device(device))
|
||||
assert layer._weight.device.type == device # type: ignore
|
||||
assert layer.weight.device.type == device # type: ignore
|
||||
|
||||
|
||||
def test_set_parameter_layer_calc_size():
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
|
||||
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
|
||||
|
||||
|
||||
def test_flux_rms_norm_sidecar_wrapper():
|
||||
# Create a RMSNorm layer.
|
||||
dim = 10
|
||||
rms_norm = torch.nn.RMSNorm(dim)
|
||||
|
||||
# Create a SetParameterLayer.
|
||||
new_scale = torch.randn(dim)
|
||||
set_parameter_layer = SetParameterLayer("scale", new_scale)
|
||||
|
||||
# Create a FluxRMSNormSidecarWrapper.
|
||||
rms_norm_wrapped = FluxRMSNormSidecarWrapper(rms_norm, [(set_parameter_layer, 1.0)])
|
||||
|
||||
# Run the FluxRMSNormSidecarWrapper.
|
||||
input = torch.randn(1, dim)
|
||||
expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6)
|
||||
output_wrapped = rms_norm_wrapped(input)
|
||||
assert torch.allclose(output_wrapped, expected_output, atol=1e-6)
|
||||
Reference in New Issue
Block a user