Merge branch 'main' into lstein/recall-reference-images

Lincoln Stein
2026-04-20 16:38:36 -04:00
committed by GitHub
4 changed files with 203 additions and 13 deletions

View File

@@ -7,12 +7,25 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
     add_nullable_tensors,
 )
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
+    def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None:
+        tensor = cast_to_device(tensor, input.device)
+        if (
+            tensor is not None
+            and input.is_floating_point()
+            and tensor.is_floating_point()
+            and not isinstance(tensor, GGMLTensor)
+            and tensor.dtype != input.dtype
+        ):
+            tensor = tensor.to(dtype=input.dtype)
+        return tensor
     def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
-        weight = cast_to_device(self.weight, input.device)
-        bias = cast_to_device(self.bias, input.device)
+        weight = self._cast_tensor_for_input(self.weight, input)
+        bias = self._cast_tensor_for_input(self.bias, input)
         # Prepare the original parameters for the patch aggregation.
         orig_params = {"weight": weight, "bias": bias}
@@ -25,13 +38,15 @@ class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
             device=input.device,
         )
-        weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
-        bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
+        residual_weight = self._cast_tensor_for_input(aggregated_param_residuals.get("weight", None), input)
+        residual_bias = self._cast_tensor_for_input(aggregated_param_residuals.get("bias", None), input)
+        weight = add_nullable_tensors(weight, residual_weight)
+        bias = add_nullable_tensors(bias, residual_bias)
         return self._conv_forward(input, weight, bias)
     def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
-        weight = cast_to_device(self.weight, input.device)
-        bias = cast_to_device(self.bias, input.device)
+        weight = self._cast_tensor_for_input(self.weight, input)
+        bias = self._cast_tensor_for_input(self.bias, input)
         return self._conv_forward(input, weight, bias)
     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -39,5 +54,21 @@ class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
             return self._autocast_forward_with_patches(input)
         elif self._device_autocasting_enabled:
             return self._autocast_forward(input)
+        elif input.is_floating_point() and (
+            (
+                self.weight.is_floating_point()
+                and not isinstance(self.weight, GGMLTensor)
+                and self.weight.dtype != input.dtype
+            )
+            or (
+                self.bias is not None
+                and self.bias.is_floating_point()
+                and not isinstance(self.bias, GGMLTensor)
+                and self.bias.dtype != input.dtype
+            )
+        ):
+            weight = self._cast_tensor_for_input(self.weight, input)
+            bias = self._cast_tensor_for_input(self.bias, input)
+            return self._conv_forward(input, weight, bias)
         else:
             return super().forward(input)

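In short: when neither sidecar patches nor device autocasting are active, the new branch casts any floating-point weight/bias whose dtype differs from the input down to the input dtype (GGML-quantized tensors are left alone), so the convolution output follows the input dtype. A rough standalone sketch of the behaviour this enables, in plain PyTorch only (no InvokeAI APIs; the tensor names are illustrative, and low-precision conv kernels may be missing on some CPU builds, which is why the tests in this commit probe support first):

import torch

conv = torch.nn.Conv2d(8, 16, 3)                    # parameters created in float32
x = torch.randn(2, 8, 5, 5, dtype=torch.bfloat16)

# A stock nn.Conv2d rejects the dtype mismatch; the patched CustomConv2d
# forward is roughly equivalent to casting weight/bias to the input dtype
# before calling the conv kernel.
weight = conv.weight.to(x.dtype)
bias = conv.bias.to(x.dtype)
y = torch.nn.functional.conv2d(x, weight, bias)     # output follows the input dtype
assert y.dtype == torch.bfloat16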
View File

@@ -9,6 +9,7 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
@@ -57,28 +58,47 @@ def autocast_linear_forward_sidecar_patches(
     # Finally, apply any remaining patches.
     if len(unprocessed_patches_and_weights) > 0:
+        weight, bias = orig_module._cast_weight_bias_for_input(input)
         # Prepare the original parameters for the patch aggregation.
-        orig_params = {"weight": orig_module.weight, "bias": orig_module.bias}
+        orig_params = {"weight": weight, "bias": bias}
         # Filter out None values.
         orig_params = {k: v for k, v in orig_params.items() if v is not None}
         aggregated_param_residuals = orig_module._aggregate_patch_parameters(
             unprocessed_patches_and_weights, orig_params=orig_params, device=input.device
         )
-        output += torch.nn.functional.linear(
-            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
-        )
+        residual_weight = orig_module._cast_tensor_for_input(aggregated_param_residuals["weight"], input)
+        residual_bias = orig_module._cast_tensor_for_input(aggregated_param_residuals.get("bias", None), input)
+        assert residual_weight is not None
+        output += torch.nn.functional.linear(input, residual_weight, residual_bias)
     return output
 class CustomLinear(torch.nn.Linear, CustomModuleMixin):
+    def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None:
+        tensor = cast_to_device(tensor, input.device)
+        if (
+            tensor is not None
+            and input.is_floating_point()
+            and tensor.is_floating_point()
+            and not isinstance(tensor, GGMLTensor)
+            and tensor.dtype != input.dtype
+        ):
+            tensor = tensor.to(dtype=input.dtype)
+        return tensor
+    def _cast_weight_bias_for_input(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
+        weight = self._cast_tensor_for_input(self.weight, input)
+        bias = self._cast_tensor_for_input(self.bias, input)
+        assert weight is not None
+        return weight, bias
     def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
         return autocast_linear_forward_sidecar_patches(self, input, self._patches_and_weights)
     def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
-        weight = cast_to_device(self.weight, input.device)
-        bias = cast_to_device(self.bias, input.device)
+        weight, bias = self._cast_weight_bias_for_input(input)
         return torch.nn.functional.linear(input, weight, bias)
     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -86,5 +106,16 @@ class CustomLinear(torch.nn.Linear, CustomModuleMixin):
             return self._autocast_forward_with_patches(input)
         elif self._device_autocasting_enabled:
             return self._autocast_forward(input)
+        elif input.is_floating_point() and (
+            (self.weight.is_floating_point() and self.weight.dtype != input.dtype)
+            or (
+                self.bias is not None
+                and self.bias.is_floating_point()
+                and not isinstance(self.bias, GGMLTensor)
+                and self.bias.dtype != input.dtype
+            )
+        ):
+            weight, bias = self._cast_weight_bias_for_input(input)
+            return torch.nn.functional.linear(input, weight, bias)
         else:
             return super().forward(input)

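The linear path gets the same treatment, and the sidecar patch path additionally aligns the aggregated patch residuals with the input dtype before the extra functional linear call. A rough sketch of why that residual cast matters, using plain PyTorch only (residual_w here is an illustrative stand-in for an aggregated patch residual, which typically stays in the patch's own dtype such as float32):

import torch

x = torch.randn(2, 8, dtype=torch.bfloat16)
base = torch.nn.Linear(8, 16)                        # float32 parameters
residual_w = torch.zeros(16, 8)                      # stand-in residual, float32

weight = base.weight.to(x.dtype)
bias = base.bias.to(x.dtype)
output = torch.nn.functional.linear(x, weight, bias)
# The residual is cast to the input dtype too; otherwise this second matmul
# would mix bfloat16 activations with a float32 residual weight and fail.
output += torch.nn.functional.linear(x, residual_w.to(x.dtype))
assert output.dtype == torch.bfloat16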
View File

@@ -49,7 +49,9 @@ class CustomModuleMixin:
         # parameters. But, of course, any sub-layers that need to access the actual values of the parameters will fail.
         for param_name in orig_params.keys():
             param = orig_params[param_name]
-            if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
+            if isinstance(param, torch.nn.Parameter) and type(param.data) is torch.Tensor:
+                pass
+            elif type(param) is torch.Tensor:
                 pass
             elif type(param) is GGMLTensor:
                 # Move to device and dequantize here. Doing it in the patch layer can result in redundant casts /

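For reference, the switch from an exact type() check to isinstance() matters whenever a parameter is an instance of a torch.nn.Parameter subclass: the old check would not match it, while the new one does. A tiny illustration (the subclass here is hypothetical, only to show the difference):

import torch

class TaggedParameter(torch.nn.Parameter):
    """Hypothetical Parameter subclass used only for this illustration."""

p = TaggedParameter(torch.randn(4))

assert type(p) is not torch.nn.Parameter     # exact-type check misses subclasses
assert isinstance(p, torch.nn.Parameter)     # the new check accepts them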
View File

@@ -1,4 +1,5 @@
 import copy
+from collections.abc import Callable
 import gguf
 import pytest
@@ -124,6 +125,67 @@ def unwrap_single_custom_layer(layer: torch.nn.Module):
     return unwrap_custom_layer(layer, orig_layer_type)
+class ZeroParamPatch(BaseLayerPatch):
+    """A minimal parameter patch that exercises the aggregated sidecar patch path."""
+    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
+        return {name: torch.zeros_like(param) for name, param in orig_parameters.items()}
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        return self
+    def calc_size(self) -> int:
+        return 0
+def _cpu_dtype_supported(
+    layer_factory: Callable[[], torch.nn.Module],
+    input_factory: Callable[[torch.dtype], torch.Tensor],
+    dtype: torch.dtype,
+) -> bool:
+    try:
+        layer = layer_factory().to(dtype=dtype)
+        input_tensor = input_factory(dtype)
+        with torch.no_grad():
+            _ = layer(input_tensor)
+        return True
+    except (RuntimeError, TypeError, NotImplementedError):
+        return False
+def _cpu_dtype_param(
+    dtype: torch.dtype,
+    layer_factory: Callable[[], torch.nn.Module],
+    input_factory: Callable[[torch.dtype], torch.Tensor],
+):
+    supported = _cpu_dtype_supported(layer_factory, input_factory, dtype)
+    return pytest.param(
+        dtype,
+        id=str(dtype).removeprefix("torch."),
+        marks=pytest.mark.skipif(not supported, reason=f"CPU {dtype} is not supported for this op"),
+    )
+LINEAR_CPU_MIXED_DTYPE_PARAMS = [
+    _cpu_dtype_param(torch.bfloat16, lambda: torch.nn.Linear(8, 16), lambda dtype: torch.randn(2, 8, dtype=dtype)),
+    _cpu_dtype_param(torch.float16, lambda: torch.nn.Linear(8, 16), lambda dtype: torch.randn(2, 8, dtype=dtype)),
+]
+CONV2D_CPU_MIXED_DTYPE_PARAMS = [
+    _cpu_dtype_param(
+        torch.bfloat16,
+        lambda: torch.nn.Conv2d(8, 16, 3),
+        lambda dtype: torch.randn(2, 8, 5, 5, dtype=dtype),
+    ),
+    _cpu_dtype_param(
+        torch.float16,
+        lambda: torch.nn.Conv2d(8, 16, 3),
+        lambda dtype: torch.randn(2, 8, 5, 5, dtype=dtype),
+    ),
+]
 def test_isinstance(layer_under_test: LayerUnderTest):
     """Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer."""
     orig_layer, _, _ = layer_under_test
@@ -550,3 +612,67 @@ def test_quantized_linear_sidecar_patches_with_autocast_from_cpu_to_device(
     # Assert that the outputs with and without autocasting are the same.
     assert torch.allclose(expected_output, autocast_output, atol=1e-6)
+@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_linear_mixed_dtype_inference_without_patches(dtype: torch.dtype):
+    layer = wrap_single_custom_layer(torch.nn.Linear(8, 16))
+    input = torch.randn(2, 8, dtype=dtype)
+    output = layer(input)
+    assert output.dtype == input.dtype
+    assert output.shape == (2, 16)
+@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_linear_mixed_dtype_inference_without_patches_bias_only_mismatch(dtype: torch.dtype):
+    layer = torch.nn.Linear(8, 16).to(dtype=dtype)
+    layer.bias = torch.nn.Parameter(layer.bias.detach().to(torch.float32))
+    layer = wrap_single_custom_layer(layer)
+    input = torch.randn(2, 8, dtype=dtype)
+    output = layer(input)
+    assert output.dtype == input.dtype
+    assert output.shape == (2, 16)
+@pytest.mark.parametrize("dtype", CONV2D_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_conv2d_mixed_dtype_inference_without_patches(dtype: torch.dtype):
+    layer = wrap_single_custom_layer(torch.nn.Conv2d(8, 16, 3))
+    input = torch.randn(2, 8, 5, 5, dtype=dtype)
+    output = layer(input)
+    assert output.dtype == input.dtype
+    assert output.shape == (2, 16, 3, 3)
+@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_linear_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):
+    layer = wrap_single_custom_layer(torch.nn.Linear(8, 16))
+    layer.add_patch(ZeroParamPatch(), 1.0)
+    input = torch.randn(2, 8, dtype=dtype)
+    output = layer(input)
+    assert output.dtype == input.dtype
+    assert output.shape == (2, 16)
+@pytest.mark.parametrize("dtype", CONV2D_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_conv2d_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):
+    layer = wrap_single_custom_layer(torch.nn.Conv2d(8, 16, 3))
+    layer.add_patch(ZeroParamPatch(), 1.0)
+    input = torch.randn(2, 8, 5, 5, dtype=dtype)
+    output = layer(input)
+    assert output.dtype == input.dtype
+    assert output.shape == (2, 16, 3, 3)
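The _cpu_dtype_param helper above turns CPU dtypes without kernel support into collection-time skips rather than hard failures. A minimal standalone sketch of that probe-and-skip pattern, independent of the InvokeAI test helpers (all names here are illustrative):

import pytest
import torch

def _cpu_supports(dtype: torch.dtype) -> bool:
    # Probe once at collection time; some CPU builds of PyTorch lack
    # low-precision matmul kernels.
    try:
        with torch.no_grad():
            torch.nn.Linear(4, 4).to(dtype=dtype)(torch.randn(1, 4, dtype=dtype))
        return True
    except (RuntimeError, TypeError, NotImplementedError):
        return False

DTYPE_PARAMS = [
    pytest.param(
        dtype,
        id=str(dtype).removeprefix("torch."),
        marks=pytest.mark.skipif(not _cpu_supports(dtype), reason=f"CPU {dtype} unsupported"),
    )
    for dtype in (torch.bfloat16, torch.float16)
]

@pytest.mark.parametrize("dtype", DTYPE_PARAMS)
def test_low_precision_linear_cpu(dtype: torch.dtype) -> None:
    out = torch.nn.Linear(4, 4).to(dtype=dtype)(torch.randn(1, 4, dtype=dtype))
    assert out.dtype == dtype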