diff --git a/invokeai/backend/patches/layers/flux_control_lora_layer.py b/invokeai/backend/patches/layers/flux_control_lora_layer.py
new file mode 100644
index 0000000000..142336a00a
--- /dev/null
+++ b/invokeai/backend/patches/layers/flux_control_lora_layer.py
@@ -0,0 +1,19 @@
+import torch
+
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
+
+
+class FluxControlLoRALayer(LoRALayer):
+    """A special case of LoRALayer for use with FLUX Control LoRAs that pads the target parameter with zeros if the
+    shapes don't match.
+    """
+
+    def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
+        """This overrides the base class behavior to skip the reshaping step."""
+        scale = self.scale()
+        params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
+        bias = self.get_bias(orig_module.bias)
+        if bias is not None:
+            params["bias"] = bias * (weight * scale)
+
+        return params
diff --git a/invokeai/backend/patches/layers/lora_layer_base.py b/invokeai/backend/patches/layers/lora_layer_base.py
index ca7b2dfcfd..13669ad5d3 100644
--- a/invokeai/backend/patches/layers/lora_layer_base.py
+++ b/invokeai/backend/patches/layers/lora_layer_base.py
@@ -63,6 +63,13 @@ class LoRALayerBase(BaseLayerPatch):
         bias = self.get_bias(orig_module.bias)
         if bias is not None:
             params["bias"] = bias * (weight * scale)
+
+        # Reshape all params to match the original module's shape.
+        for param_name, param_weight in params.items():
+            orig_param = orig_module.get_parameter(param_name)
+            if param_weight.shape != orig_param.shape:
+                params[param_name] = param_weight.reshape(orig_param.shape)
+
         return params
 
     @classmethod
diff --git a/invokeai/backend/patches/lora_patcher.py b/invokeai/backend/patches/lora_patcher.py
index cafda5313e..8d00272cdb 100644
--- a/invokeai/backend/patches/lora_patcher.py
+++ b/invokeai/backend/patches/lora_patcher.py
@@ -4,7 +4,9 @@ from typing import Dict, Iterable, Optional, Tuple
 import torch
 
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
 from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
 from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
 from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper
 from invokeai.backend.util.devices import TorchDevice
@@ -125,24 +127,18 @@ class LoRAPatcher:
                 # Save original weight
                 original_weights.save(param_key, module_param)
 
-                if module_param.shape != param_weight.shape:
-                    if module_param.nelement() == param_weight.nelement():
-                        param_weight = param_weight.reshape(module_param.shape)
-                    else:
-                        # This condition was added to handle layers in FLUX control LoRAs.
-                        # TODO(ryand): Move the weight update into the LoRA layer so that the LoRAPatcher doesn't need
-                        # to worry about this?
-                        expanded_weight = torch.zeros_like(
-                            param_weight, dtype=module_param.dtype, device=module_param.device
-                        )
-                        slices = tuple(slice(0, dim) for dim in module_param.shape)
-                        expanded_weight[slices] = module_param
-                        setattr(
-                            module_to_patch,
-                            param_name,
-                            torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
-                        )
-                        module_param = expanded_weight
+                # HACK(ryand): This condition is only necessary to handle layers in FLUX control LoRAs that change the
+                # shape of the original layer.
+                if module_param.nelement() != param_weight.nelement():
+                    assert isinstance(patch, FluxControlLoRALayer)
+                    expanded_weight = pad_with_zeros(module_param, param_weight.shape)
+                    setattr(
+                        module_to_patch,
+                        param_name,
+                        torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
+                    )
+                    module_param = expanded_weight
+
                 module_param += param_weight.to(dtype=dtype)
 
         patch.to(device=TorchDevice.CPU_DEVICE)
diff --git a/invokeai/backend/patches/pad_with_zeros.py b/invokeai/backend/patches/pad_with_zeros.py
new file mode 100644
index 0000000000..a76b02f0b3
--- /dev/null
+++ b/invokeai/backend/patches/pad_with_zeros.py
@@ -0,0 +1,9 @@
+import torch
+
+
+def pad_with_zeros(orig_weight: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
+    """Pad a weight tensor with zeros to match the target shape."""
+    expanded_weight = torch.zeros(target_shape, dtype=orig_weight.dtype, device=orig_weight.device)
+    slices = tuple(slice(0, dim) for dim in orig_weight.shape)
+    expanded_weight[slices] = orig_weight
+    return expanded_weight
diff --git a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py
index a7dfa8eae3..c22525bc95 100644
--- a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py
+++ b/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py
@@ -43,11 +43,6 @@ class BaseSidecarWrapper(torch.nn.Module):
             layer_params = patch.get_parameters(self._orig_module, weight=patch_weight)
 
             for param_name, param_weight in layer_params.items():
-                orig_param = self._orig_module.get_parameter(param_name)
-                # TODO(ryand): Move shape handling down into the patch.
-                if orig_param.shape != param_weight.shape:
-                    param_weight = param_weight.reshape(orig_param.shape)
-
                 if param_name not in params:
                     params[param_name] = param_weight
                 else:
diff --git a/invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py
index 7b6bfdcab0..186c644c7c 100644
--- a/invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py
+++ b/invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py
@@ -2,6 +2,7 @@ import torch
 
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
+from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
 
@@ -36,12 +37,19 @@ class LinearSidecarWrapper(BaseSidecarWrapper):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # First, apply the original linear layer.
+        # NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
+        # change the linear layer's in_features.
+        orig_input = input
+        input = orig_input[..., : self.orig_module.weight.shape[1]]
         output = self.orig_module(input)
 
         # Then, apply layers for which we have optimized implementations.
         unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
         for patch, patch_weight in self._patches_and_weights:
-            if isinstance(patch, LoRALayer):
+            if isinstance(patch, FluxControlLoRALayer):
+                # Note that we use the original input here, not the sliced input.
+                output += self._lora_forward(orig_input, patch, patch_weight)
+            elif isinstance(patch, LoRALayer):
                 output += self._lora_forward(input, patch, patch_weight)
             elif isinstance(patch, ConcatenatedLoRALayer):
                 output += self._concatenated_lora_forward(input, patch, patch_weight)
diff --git a/tests/backend/patches/layers/test_flux_control_lora_layer.py b/tests/backend/patches/layers/test_flux_control_lora_layer.py
new file mode 100644
index 0000000000..00590c3514
--- /dev/null
+++ b/tests/backend/patches/layers/test_flux_control_lora_layer.py
@@ -0,0 +1,25 @@
+import torch
+
+from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
+
+
+def test_flux_control_lora_layer_get_parameters():
+    """Test getting weight and bias parameters from FluxControlLoRALayer."""
+    small_in_features = 4
+    big_in_features = 8
+    out_features = 16
+    rank = 4
+    alpha = 16.0
+    layer = FluxControlLoRALayer(
+        up=torch.ones(out_features, rank), mid=None, down=torch.ones(rank, big_in_features), alpha=alpha, bias=None
+    )
+
+    # Create mock original module
+    orig_module = torch.nn.Linear(small_in_features, out_features)
+
+    # Test that get_parameters() behaves as expected in spite of the difference in in_features shapes.
+    params = layer.get_parameters(orig_module, weight=1.0)
+    assert "weight" in params
+    assert params["weight"].shape == (out_features, big_in_features)
+    assert params["weight"].allclose(torch.ones(out_features, big_in_features) * alpha)
+    assert "bias" not in params  # No bias in this case
diff --git a/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py b/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py
index 1109f0f2a1..607f364dcd 100644
--- a/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py
+++ b/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py
@@ -3,8 +3,10 @@ import copy
 import torch
 
 from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
+from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
 from invokeai.backend.patches.layers.full_layer import FullLayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
 from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
 
 
@@ -139,3 +141,42 @@ def test_linear_sidecar_wrapper_full_layer():
     output_patched = linear_patched(input)
     output_wrapped = full_wrapped(input)
     assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
+
+
+def test_linear_sidecar_wrapper_flux_control_lora_layer():
+    # Create a linear layer.
+    orig_in_features = 10
+    out_features = 40
+    linear = torch.nn.Linear(orig_in_features, out_features)
+
+    # Create a FluxControlLoRALayer.
+    patched_in_features = 20
+    rank = 4
+    lora_layer = FluxControlLoRALayer(
+        up=torch.randn(out_features, rank),
+        mid=None,
+        down=torch.randn(rank, patched_in_features),
+        alpha=1.0,
+        bias=torch.randn(out_features),
+    )
+
+    # Patch the FluxControlLoRALayer into the linear layer.
+    linear_patched = copy.deepcopy(linear)
+    # Expand the existing weight.
+    expanded_weight = pad_with_zeros(linear_patched.weight, torch.Size([out_features, patched_in_features]))
+    linear_patched.weight = torch.nn.Parameter(expanded_weight, requires_grad=linear_patched.weight.requires_grad)
+    # Expand the existing bias.
+    expanded_bias = pad_with_zeros(linear_patched.bias, torch.Size([out_features]))
+    linear_patched.bias = torch.nn.Parameter(expanded_bias, requires_grad=linear_patched.bias.requires_grad)
+    # Add the residuals.
+    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
+    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
+
+    # Create a LinearSidecarWrapper.
+    lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)])
+
+    # Run the FluxControlLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal.
+    input = torch.randn(1, patched_in_features)
+    output_patched = linear_patched(input)
+    output_wrapped = lora_wrapped(input)
+    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
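
For context, here is a minimal sketch (not part of the patch itself) of why the zero-padding strategy in this diff is safe: `pad_with_zeros` fills the extra `in_features` columns with zeros, so the expanded weight ignores the additional (control) input channels and reproduces the original layer's output until a FLUX control LoRA residual is added on top. The toy shapes below are illustrative only.

```python
import torch
import torch.nn.functional as F

from invokeai.backend.patches.pad_with_zeros import pad_with_zeros

torch.manual_seed(0)
orig = torch.nn.Linear(4, 8)  # original layer: in_features=4, out_features=8
x_small = torch.randn(1, 4)  # input the original layer was built for
x_big = torch.cat([x_small, torch.randn(1, 2)], dim=-1)  # same input plus 2 extra (control) channels

with torch.no_grad():
    # Expand the weight from (8, 4) to (8, 6); the two new columns are zeros.
    expanded_weight = pad_with_zeros(orig.weight, torch.Size([8, 6]))
    # The zero columns contribute nothing for the extra channels, so the expanded layer
    # matches the original layer exactly before any LoRA residual is applied.
    assert torch.allclose(F.linear(x_big, expanded_weight, orig.bias), orig(x_small), atol=1e-6)
```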