First pass at making custom layer patches work with weights streamed from the CPU to the GPU.

This commit is contained in:
Ryan Dick
2024-12-29 06:51:30 +00:00
parent 6d49ee839c
commit a8bef59699
6 changed files with 92 additions and 45 deletions

View File

@@ -4,17 +4,26 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
add_nullable_tensors,
)
class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
weight = add_nullable_tensors(self.weight, aggregated_param_residuals.get("weight", None))
bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
return self._conv_forward(input, weight, bias)
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
# Prepare the original parameters for the patch aggregation.
orig_params = {"weight": weight, "bias": bias}
# Filter out None values.
orig_params = {k: v for k, v in orig_params.items() if v is not None}
aggregated_param_residuals = self._aggregate_patch_parameters(
patches_and_weights=self._patches_and_weights,
orig_params=orig_params,
device=input.device,
)
return self._conv_forward(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)

View File

@@ -4,17 +4,26 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
add_nullable_tensors,
)
class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
weight = add_nullable_tensors(self.weight, aggregated_param_residuals.get("weight", None))
bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None))
return self._conv_forward(input, weight, bias)
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
# Prepare the original parameters for the patch aggregation.
orig_params = {"weight": weight, "bias": bias}
# Filter out None values.
orig_params = {k: v for k, v in orig_params.items() if v is not None}
aggregated_param_residuals = self._aggregate_patch_parameters(
patches_and_weights=self._patches_and_weights,
orig_params=orig_params,
device=input.device,
)
return self._conv_forward(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)

View File

@@ -16,10 +16,12 @@ class CustomFluxRMSNorm(RMSNorm, CustomModuleMixin):
assert isinstance(patch, SetParameterLayer)
assert patch.param_name == "scale"
scale = cast_to_device(patch.weight, x.device)
# Apply the patch.
# NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
# be handled.
return torch.nn.functional.rms_norm(x, patch.weight.shape, patch.weight, eps=1e-6)
return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)
def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
scale = cast_to_device(self.scale, x.device)

View File

@@ -1,3 +1,5 @@
import copy
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
@@ -55,6 +57,10 @@ def autocast_linear_forward_sidecar_patches(
# Then, apply layers for which we have optimized implementations.
unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
for patch, patch_weight in patches_and_weights:
# Shallow copy the patch so that we can cast it to the target device without modifying the original patch.
patch = copy.copy(patch)
patch.to(input.device)
if isinstance(patch, FluxControlLoRALayer):
# Note that we use the original input here, not the sliced input.
output += linear_lora_forward(orig_input, patch, patch_weight)
@@ -67,7 +73,14 @@ def autocast_linear_forward_sidecar_patches(
# Finally, apply any remaining patches.
if len(unprocessed_patches_and_weights) > 0:
aggregated_param_residuals = orig_module._aggregate_patch_parameters(unprocessed_patches_and_weights)
# Prepare the original parameters for the patch aggregation.
orig_params = {"weight": orig_module.weight, "bias": orig_module.bias}
# Filter out None values.
orig_params = {k: v for k, v in orig_params.items() if v is not None}
aggregated_param_residuals = orig_module._aggregate_patch_parameters(
unprocessed_patches_and_weights, orig_params=orig_params, device=input.device
)
output += torch.nn.functional.linear(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)

View File

@@ -1,3 +1,5 @@
import copy
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
@@ -34,15 +36,23 @@ class CustomModuleMixin:
return len(self._patches_and_weights)
def _aggregate_patch_parameters(
self, patches_and_weights: list[tuple[BaseLayerPatch, float]]
) -> dict[str, torch.Tensor]:
self,
patches_and_weights: list[tuple[BaseLayerPatch, float]],
orig_params: dict[str, torch.Tensor],
device: torch.device | None = None,
):
"""Helper function that aggregates the parameters from all patches into a single dict."""
params: dict[str, torch.Tensor] = {}
for patch, patch_weight in patches_and_weights:
if device is not None:
# Shallow copy the patch so that we can cast it to the target device without modifying the original patch.
patch = copy.copy(patch)
patch.to(device)
# TODO(ryand): `self` could be a quantized module. Depending on what the patch is doing with the original
# parameters, this might fail or return incorrect results.
layer_params = patch.get_parameters(dict(self.named_parameters(recurse=False)), weight=patch_weight) # type: ignore
layer_params = patch.get_parameters(orig_params, weight=patch_weight)
for param_name, param_weight in layer_params.items():
if param_name not in params: