Add sidecar wrapper for FLUX RMSNorm layers to support SetParameterLayers used by FLUX structural control LoRAs.

This commit is contained in:
Ryan Dick
2024-12-13 21:23:43 +00:00
parent c76a448846
commit 606d58d7db
5 changed files with 59 additions and 8 deletions

View File

@@ -32,11 +32,11 @@ def test_set_parameter_layer_to(device: str):
layer = SetParameterLayer(param_name="weight", weight=target_weight)
# SetParameterLayer should be initialized on CPU.
assert layer._weight.device.type == "cpu" # type: ignore
assert layer.weight.device.type == "cpu" # type: ignore
# Move to device.
layer.to(device=torch.device(device))
assert layer._weight.device.type == device # type: ignore
assert layer.weight.device.type == device # type: ignore
def test_set_parameter_layer_calc_size():

View File

@@ -0,0 +1,23 @@
import torch
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
def test_flux_rms_norm_sidecar_wrapper():
# Create a RMSNorm layer.
dim = 10
rms_norm = torch.nn.RMSNorm(dim)
# Create a SetParameterLayer.
new_scale = torch.randn(dim)
set_parameter_layer = SetParameterLayer("scale", new_scale)
# Create a FluxRMSNormSidecarWrapper.
rms_norm_wrapped = FluxRMSNormSidecarWrapper(rms_norm, [(set_parameter_layer, 1.0)])
# Run the FluxRMSNormSidecarWrapper.
input = torch.randn(1, dim)
expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6)
output_wrapped = rms_norm_wrapped(input)
assert torch.allclose(output_wrapped, expected_output, atol=1e-6)