mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add sidecar wrapper for FLUX RMSNorm layers to support SetParameterLayers used by FLUX structural control LoRAs.
This commit is contained in:
@@ -32,11 +32,11 @@ def test_set_parameter_layer_to(device: str):
|
||||
layer = SetParameterLayer(param_name="weight", weight=target_weight)
|
||||
|
||||
# SetParameterLayer should be initialized on CPU.
|
||||
assert layer._weight.device.type == "cpu" # type: ignore
|
||||
assert layer.weight.device.type == "cpu" # type: ignore
|
||||
|
||||
# Move to device.
|
||||
layer.to(device=torch.device(device))
|
||||
assert layer._weight.device.type == device # type: ignore
|
||||
assert layer.weight.device.type == device # type: ignore
|
||||
|
||||
|
||||
def test_set_parameter_layer_calc_size():
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
|
||||
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
|
||||
|
||||
|
||||
def test_flux_rms_norm_sidecar_wrapper():
|
||||
# Create a RMSNorm layer.
|
||||
dim = 10
|
||||
rms_norm = torch.nn.RMSNorm(dim)
|
||||
|
||||
# Create a SetParameterLayer.
|
||||
new_scale = torch.randn(dim)
|
||||
set_parameter_layer = SetParameterLayer("scale", new_scale)
|
||||
|
||||
# Create a FluxRMSNormSidecarWrapper.
|
||||
rms_norm_wrapped = FluxRMSNormSidecarWrapper(rms_norm, [(set_parameter_layer, 1.0)])
|
||||
|
||||
# Run the FluxRMSNormSidecarWrapper.
|
||||
input = torch.randn(1, dim)
|
||||
expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6)
|
||||
output_wrapped = rms_norm_wrapped(input)
|
||||
assert torch.allclose(output_wrapped, expected_output, atol=1e-6)
|
||||
Reference in New Issue
Block a user