Add support for FluxControlLoRALayer in CustomLinear layers and add a unit test for it.

This commit is contained in:
Ryan Dick
2024-12-27 21:00:47 +00:00
parent 5ee7405f97
commit ef970a1cdc
2 changed files with 102 additions and 13 deletions

View File

@@ -10,9 +10,12 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
unwrap_custom_layer,
wrap_custom_layer,
)
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import (
build_linear_8bit_lt_layer,
)
@@ -272,6 +275,7 @@ LayerAndPatchUnderTest = tuple[torch.nn.Module, list[tuple[BaseLayerPatch, float
"linear_single_lora",
"linear_multiple_loras",
"linear_concatenated_lora",
"linear_flux_control_lora",
]
)
def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest:
@@ -338,6 +342,25 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU
input = torch.randn(1, in_features)
return (layer, [(concatenated_lora_layer, 0.7)], input, True)
elif layer_type == "linear_flux_control_lora":
# Create a linear layer.
orig_in_features = 10
out_features = 40
layer = torch.nn.Linear(orig_in_features, out_features)
# Create a FluxControlLoRALayer.
patched_in_features = 20
rank = 4
lora_layer = FluxControlLoRALayer(
up=torch.randn(out_features, rank),
mid=None,
down=torch.randn(rank, patched_in_features),
alpha=1.0,
bias=torch.randn(out_features),
)
input = torch.randn(1, patched_in_features)
return (layer, [(lora_layer, 0.7)], input, True)
else:
raise ValueError(f"Unsupported layer_type: {layer_type}")
@@ -356,18 +379,21 @@ def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchU
# Patch the LoRA layer into the linear layer.
layer_patched = copy.deepcopy(layer)
for patch, weight in patches:
patch.to(torch.device(device))
parameters = patch.get_parameters(layer_patched, weight=weight)
for param_name, param_weight in parameters.items():
module_param = getattr(layer_patched, param_name)
module_param.data += param_weight
LayerPatcher._apply_model_layer_patch(
module_to_patch=layer_patched,
module_to_patch_key="",
patch=patch,
patch_weight=weight,
original_weights=OriginalWeightsStorage(),
)
# Wrap the original layer in a custom layer and add the patch to it as a sidecar.
custom_layer = wrap_single_custom_layer(layer)
for patch, weight in patches:
patch.to(torch.device(device))
custom_layer.add_patch(patch, weight)
# Run inference with the original layer and the patched layer and assert they are equal.
output_patched = layer_patched(input)
output_custom = custom_layer(input)
assert torch.allclose(output_patched, output_custom)
assert torch.allclose(output_patched, output_custom, atol=1e-6)