mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Merge branch 'main' into lstein/recall-reference-images
This commit is contained in:
@@ -7,12 +7,25 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
|
||||
add_nullable_tensors,
|
||||
)
|
||||
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
||||
|
||||
|
||||
class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
|
||||
def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None:
    """Align ``tensor`` with ``input``'s device and, when applicable, its dtype.

    The tensor is first passed through ``cast_to_device`` for ``input.device``.
    Its dtype is then changed to ``input.dtype`` only when all of the following
    hold: the tensor is not ``None``, both tensors are floating point, the
    tensor is not a ``GGMLTensor``, and the two dtypes differ.

    Args:
        tensor: The parameter or residual to align; may be ``None``.
        input: The tensor the layer is about to be applied to.

    Returns:
        The aligned tensor (``None`` is expected to pass through unchanged,
        depending on ``cast_to_device`` - confirm against that helper).
    """
    tensor = cast_to_device(tensor, input.device)
    if (
        tensor is not None
        and input.is_floating_point()
        and tensor.is_floating_point()
        # GGMLTensor instances are never dtype-cast here.
        and not isinstance(tensor, GGMLTensor)
        and tensor.dtype != input.dtype
    ):
        tensor = tensor.to(dtype=input.dtype)
    return tensor
|
||||
|
||||
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
|
||||
weight = cast_to_device(self.weight, input.device)
|
||||
bias = cast_to_device(self.bias, input.device)
|
||||
weight = self._cast_tensor_for_input(self.weight, input)
|
||||
bias = self._cast_tensor_for_input(self.bias, input)
|
||||
|
||||
# Prepare the original parameters for the patch aggregation.
|
||||
orig_params = {"weight": weight, "bias": bias}
|
||||
@@ -25,13 +38,15 @@ class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
|
||||
device=input.device,
|
||||
)
|
||||
|
||||
weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
|
||||
bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
|
||||
residual_weight = self._cast_tensor_for_input(aggregated_param_residuals.get("weight", None), input)
|
||||
residual_bias = self._cast_tensor_for_input(aggregated_param_residuals.get("bias", None), input)
|
||||
weight = add_nullable_tensors(weight, residual_weight)
|
||||
bias = add_nullable_tensors(bias, residual_bias)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
weight = cast_to_device(self.weight, input.device)
|
||||
bias = cast_to_device(self.bias, input.device)
|
||||
weight = self._cast_tensor_for_input(self.weight, input)
|
||||
bias = self._cast_tensor_for_input(self.bias, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
@@ -39,5 +54,21 @@ class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
|
||||
return self._autocast_forward_with_patches(input)
|
||||
elif self._device_autocasting_enabled:
|
||||
return self._autocast_forward(input)
|
||||
elif input.is_floating_point() and (
|
||||
(
|
||||
self.weight.is_floating_point()
|
||||
and not isinstance(self.weight, GGMLTensor)
|
||||
and self.weight.dtype != input.dtype
|
||||
)
|
||||
or (
|
||||
self.bias is not None
|
||||
and self.bias.is_floating_point()
|
||||
and not isinstance(self.bias, GGMLTensor)
|
||||
and self.bias.dtype != input.dtype
|
||||
)
|
||||
):
|
||||
weight = self._cast_tensor_for_input(self.weight, input)
|
||||
bias = self._cast_tensor_for_input(self.bias, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
else:
|
||||
return super().forward(input)
|
||||
|
||||
@@ -9,6 +9,7 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
|
||||
from invokeai.backend.patches.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
||||
|
||||
|
||||
def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
|
||||
@@ -57,28 +58,47 @@ def autocast_linear_forward_sidecar_patches(
|
||||
|
||||
# Finally, apply any remaining patches.
|
||||
if len(unprocessed_patches_and_weights) > 0:
|
||||
weight, bias = orig_module._cast_weight_bias_for_input(input)
|
||||
# Prepare the original parameters for the patch aggregation.
|
||||
orig_params = {"weight": orig_module.weight, "bias": orig_module.bias}
|
||||
orig_params = {"weight": weight, "bias": bias}
|
||||
# Filter out None values.
|
||||
orig_params = {k: v for k, v in orig_params.items() if v is not None}
|
||||
|
||||
aggregated_param_residuals = orig_module._aggregate_patch_parameters(
|
||||
unprocessed_patches_and_weights, orig_params=orig_params, device=input.device
|
||||
)
|
||||
output += torch.nn.functional.linear(
|
||||
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
|
||||
)
|
||||
residual_weight = orig_module._cast_tensor_for_input(aggregated_param_residuals["weight"], input)
|
||||
residual_bias = orig_module._cast_tensor_for_input(aggregated_param_residuals.get("bias", None), input)
|
||||
assert residual_weight is not None
|
||||
output += torch.nn.functional.linear(input, residual_weight, residual_bias)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class CustomLinear(torch.nn.Linear, CustomModuleMixin):
|
||||
def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None:
    """Align ``tensor`` with ``input``'s device and, when applicable, its dtype.

    The tensor is first passed through ``cast_to_device`` for ``input.device``.
    Its dtype is then changed to ``input.dtype`` only when all of the following
    hold: the tensor is not ``None``, both tensors are floating point, the
    tensor is not a ``GGMLTensor``, and the two dtypes differ.

    Args:
        tensor: The parameter or residual to align; may be ``None``.
        input: The tensor the layer is about to be applied to.

    Returns:
        The aligned tensor (``None`` is expected to pass through unchanged,
        depending on ``cast_to_device`` - confirm against that helper).
    """
    tensor = cast_to_device(tensor, input.device)
    if (
        tensor is not None
        and input.is_floating_point()
        and tensor.is_floating_point()
        # GGMLTensor instances are never dtype-cast here.
        and not isinstance(tensor, GGMLTensor)
        and tensor.dtype != input.dtype
    ):
        tensor = tensor.to(dtype=input.dtype)
    return tensor
|
||||
|
||||
def _cast_weight_bias_for_input(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
weight = self._cast_tensor_for_input(self.weight, input)
|
||||
bias = self._cast_tensor_for_input(self.bias, input)
|
||||
assert weight is not None
|
||||
return weight, bias
|
||||
|
||||
def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
    """Delegate to ``autocast_linear_forward_sidecar_patches`` with this layer's patches.

    Args:
        input: The tensor to apply the layer to.

    Returns:
        Whatever ``autocast_linear_forward_sidecar_patches`` returns for this
        module, ``input`` and ``self._patches_and_weights``.
    """
    return autocast_linear_forward_sidecar_patches(self, input, self._patches_and_weights)
|
||||
|
||||
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
weight = cast_to_device(self.weight, input.device)
|
||||
bias = cast_to_device(self.bias, input.device)
|
||||
weight, bias = self._cast_weight_bias_for_input(input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
@@ -86,5 +106,16 @@ class CustomLinear(torch.nn.Linear, CustomModuleMixin):
|
||||
return self._autocast_forward_with_patches(input)
|
||||
elif self._device_autocasting_enabled:
|
||||
return self._autocast_forward(input)
|
||||
elif input.is_floating_point() and (
|
||||
(self.weight.is_floating_point() and self.weight.dtype != input.dtype)
|
||||
or (
|
||||
self.bias is not None
|
||||
and self.bias.is_floating_point()
|
||||
and not isinstance(self.bias, GGMLTensor)
|
||||
and self.bias.dtype != input.dtype
|
||||
)
|
||||
):
|
||||
weight, bias = self._cast_weight_bias_for_input(input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
else:
|
||||
return super().forward(input)
|
||||
|
||||
@@ -49,7 +49,9 @@ class CustomModuleMixin:
|
||||
# parameters. But, of course, any sub-layers that need to access the actual values of the parameters will fail.
|
||||
for param_name in orig_params.keys():
|
||||
param = orig_params[param_name]
|
||||
if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
|
||||
if isinstance(param, torch.nn.Parameter) and type(param.data) is torch.Tensor:
|
||||
pass
|
||||
elif type(param) is torch.Tensor:
|
||||
pass
|
||||
elif type(param) is GGMLTensor:
|
||||
# Move to device and dequantize here. Doing it in the patch layer can result in redundant casts /
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
from collections.abc import Callable
|
||||
|
||||
import gguf
|
||||
import pytest
|
||||
@@ -124,6 +125,67 @@ def unwrap_single_custom_layer(layer: torch.nn.Module):
|
||||
return unwrap_custom_layer(layer, orig_layer_type)
|
||||
|
||||
|
||||
class ZeroParamPatch(BaseLayerPatch):
    """A minimal parameter patch that exercises the aggregated sidecar patch path."""

    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        """Return an all-zero tensor shaped like each original parameter.

        ``weight`` is accepted but not used.
        """
        return {name: torch.zeros_like(param) for name, param in orig_parameters.items()}

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None) -> "ZeroParamPatch":
        """Return ``self`` unchanged; both arguments are ignored."""
        return self

    def calc_size(self) -> int:
        """Report a size of zero."""
        return 0
|
||||
|
||||
|
||||
def _cpu_dtype_supported(
|
||||
layer_factory: Callable[[], torch.nn.Module],
|
||||
input_factory: Callable[[torch.dtype], torch.Tensor],
|
||||
dtype: torch.dtype,
|
||||
) -> bool:
|
||||
try:
|
||||
layer = layer_factory().to(dtype=dtype)
|
||||
input_tensor = input_factory(dtype)
|
||||
with torch.no_grad():
|
||||
_ = layer(input_tensor)
|
||||
return True
|
||||
except (RuntimeError, TypeError, NotImplementedError):
|
||||
return False
|
||||
|
||||
|
||||
def _cpu_dtype_param(
    dtype: torch.dtype,
    layer_factory: Callable[[], torch.nn.Module],
    input_factory: Callable[[torch.dtype], torch.Tensor],
):
    """Build a ``pytest.param`` for ``dtype`` that is skipped when the dtype probe fails.

    The probe (``_cpu_dtype_supported``) runs immediately, when this function
    is called, not when the parametrized test executes.

    Args:
        dtype: The dtype to parametrize over.
        layer_factory: Zero-argument callable that builds the layer to probe.
        input_factory: Callable that builds an input tensor for a given dtype.

    Returns:
        A ``pytest.param`` carrying ``dtype``, with an id equal to the dtype
        name without its ``torch.`` prefix and a ``skipif`` mark that is
        active when the probe reported the dtype as unsupported.
    """
    supported = _cpu_dtype_supported(layer_factory, input_factory, dtype)
    return pytest.param(
        dtype,
        id=str(dtype).removeprefix("torch."),
        marks=pytest.mark.skipif(not supported, reason=f"CPU {dtype} is not supported for this op"),
    )
|
||||
|
||||
|
||||
# Reduced-precision dtypes to run the Linear mixed-dtype tests with. Each entry
# is probed when this list is built; dtypes whose probe fails are marked as skipped.
LINEAR_CPU_MIXED_DTYPE_PARAMS = [
    _cpu_dtype_param(
        probe_dtype,
        lambda: torch.nn.Linear(8, 16),
        lambda input_dtype: torch.randn(2, 8, dtype=input_dtype),
    )
    for probe_dtype in (torch.bfloat16, torch.float16)
]
|
||||
|
||||
|
||||
# Reduced-precision dtypes to run the Conv2d mixed-dtype tests with. Each entry
# is probed when this list is built; dtypes whose probe fails are marked as skipped.
CONV2D_CPU_MIXED_DTYPE_PARAMS = [
    _cpu_dtype_param(
        probe_dtype,
        lambda: torch.nn.Conv2d(8, 16, 3),
        lambda input_dtype: torch.randn(2, 8, 5, 5, dtype=input_dtype),
    )
    for probe_dtype in (torch.bfloat16, torch.float16)
]
|
||||
|
||||
|
||||
def test_isinstance(layer_under_test: LayerUnderTest):
|
||||
"""Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer."""
|
||||
orig_layer, _, _ = layer_under_test
|
||||
@@ -550,3 +612,67 @@ def test_quantized_linear_sidecar_patches_with_autocast_from_cpu_to_device(
|
||||
|
||||
# Assert that the outputs with and without autocasting are the same.
|
||||
assert torch.allclose(expected_output, autocast_output, atol=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
@torch.no_grad()
def test_linear_mixed_dtype_inference_without_patches(dtype: torch.dtype):
    """A wrapped Linear built with default parameters returns the input's dtype for a ``dtype`` input."""
    layer = wrap_single_custom_layer(torch.nn.Linear(8, 16))
    input_tensor = torch.randn(2, 8, dtype=dtype)

    output = layer(input_tensor)

    assert output.dtype == input_tensor.dtype
    assert output.shape == (2, 16)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
@torch.no_grad()
def test_linear_mixed_dtype_inference_without_patches_bias_only_mismatch(dtype: torch.dtype):
    """A wrapped Linear whose weight matches the input dtype but whose bias is float32 still runs."""
    layer = torch.nn.Linear(8, 16).to(dtype=dtype)
    # Put only the bias back in float32 so the weight matches the input and the bias does not.
    layer.bias = torch.nn.Parameter(layer.bias.detach().to(torch.float32))
    layer = wrap_single_custom_layer(layer)
    input_tensor = torch.randn(2, 8, dtype=dtype)

    output = layer(input_tensor)

    assert output.dtype == input_tensor.dtype
    assert output.shape == (2, 16)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", CONV2D_CPU_MIXED_DTYPE_PARAMS)
@torch.no_grad()
def test_conv2d_mixed_dtype_inference_without_patches(dtype: torch.dtype):
    """A wrapped Conv2d built with default parameters returns the input's dtype for a ``dtype`` input."""
    layer = wrap_single_custom_layer(torch.nn.Conv2d(8, 16, 3))
    input_tensor = torch.randn(2, 8, 5, 5, dtype=dtype)

    output = layer(input_tensor)

    assert output.dtype == input_tensor.dtype
    assert output.shape == (2, 16, 3, 3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
@torch.no_grad()
def test_linear_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):
    """A wrapped Linear with a ``ZeroParamPatch`` attached returns the input's dtype for a ``dtype`` input."""
    layer = wrap_single_custom_layer(torch.nn.Linear(8, 16))
    layer.add_patch(ZeroParamPatch(), 1.0)
    input_tensor = torch.randn(2, 8, dtype=dtype)

    output = layer(input_tensor)

    assert output.dtype == input_tensor.dtype
    assert output.shape == (2, 16)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", CONV2D_CPU_MIXED_DTYPE_PARAMS)
@torch.no_grad()
def test_conv2d_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):
    """A wrapped Conv2d with a ``ZeroParamPatch`` attached returns the input's dtype for a ``dtype`` input."""
    layer = wrap_single_custom_layer(torch.nn.Conv2d(8, 16, 3))
    layer.add_patch(ZeroParamPatch(), 1.0)
    input_tensor = torch.randn(2, 8, 5, 5, dtype=dtype)

    output = layer(input_tensor)

    assert output.dtype == input_tensor.dtype
    assert output.shape == (2, 16, 3, 3)
|
||||
|
||||
Reference in New Issue
Block a user