Add support for FluxControlLoRALayer in CustomLinear layers and add a unit test for it.

This commit is contained in:
Ryan Dick
2024-12-27 21:00:47 +00:00
parent 5ee7405f97
commit ef970a1cdc
2 changed files with 102 additions and 13 deletions

View File

@@ -4,17 +4,80 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
add_nullable_tensors,
)
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
x = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x = torch.nn.functional.linear(x, lora_layer.mid)
x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
x *= lora_weight * lora_layer.scale()
return x
def concatenated_lora_forward(
input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a sidecar ConcatenatedLoRALayer."""
x_chunks: list[torch.Tensor] = []
for lora_layer in concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= lora_weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x
def autocast_linear_forward_sidecar_patches(
orig_module: torch.nn.Linear, input: torch.Tensor, patches_and_weights: list[tuple[BaseLayerPatch, float]]
) -> torch.Tensor:
"""A function that runs a linear layer (quantized or non-quantized) with sidecar patches for a linear layer.
Compatible with both quantized and non-quantized Linear layers.
"""
# First, apply the original linear layer.
# NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
# change the linear layer's in_features.
orig_input = input
input = orig_input[..., : orig_module.in_features]
output = orig_module._autocast_forward(input)
# Then, apply layers for which we have optimized implementations.
unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
for patch, patch_weight in patches_and_weights:
if isinstance(patch, FluxControlLoRALayer):
# Note that we use the original input here, not the sliced input.
output += linear_lora_forward(orig_input, patch, patch_weight)
elif isinstance(patch, LoRALayer):
output += linear_lora_forward(input, patch, patch_weight)
elif isinstance(patch, ConcatenatedLoRALayer):
output += concatenated_lora_forward(input, patch, patch_weight)
else:
unprocessed_patches_and_weights.append((patch, patch_weight))
# Finally, apply any remaining patches.
if len(unprocessed_patches_and_weights) > 0:
aggregated_param_residuals = orig_module._aggregate_patch_parameters(unprocessed_patches_and_weights)
output += torch.nn.functional.linear(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)
return output
class CustomLinear(torch.nn.Linear, CustomModuleMixin):
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"])
bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
return torch.nn.functional.linear(input, weight, bias)
return autocast_linear_forward_sidecar_patches(self, input, self._patches_and_weights)
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)