test: regression coverage for DoRA + partial-loading + CPU→device autocast

Adds targeted coverage for the bug fixed in a0a87212 (#8624, PR #9063):

- test_aggregate_patch_parameters_preserves_plain_tensor_with_dora:
  CPU-only unit test that feeds a plain torch.Tensor (as handed in by
  _cast_weight_bias_for_input) into _aggregate_patch_parameters with a
  DoRA patch. Pre-fix, the tensor was replaced by a meta-device dummy,
  tripping DoRA's quantization guard.

- "single_dora" variant in the patch_under_test fixture: exercises the
  full CUDA/MPS autocast hot path via
  test_linear_sidecar_patches_with_autocast_from_cpu_to_device.
This commit is contained in:
Alexander Eichhorn
2026-04-22 00:53:38 +02:00
parent 87d0ca42e4
commit 0faf467506

View File

@@ -14,6 +14,7 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
)
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.dora_layer import DoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
@@ -346,6 +347,7 @@ PatchUnderTest = tuple[list[tuple[BaseLayerPatch, float]], torch.Tensor]
"concatenated_lora",
"flux_control_lora",
"single_lokr",
"single_dora",
]
)
def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
@@ -432,6 +434,20 @@ def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
)
input = torch.randn(1, in_features)
return ([(lokr_layer, 0.7)], input)
elif layer_type == "single_dora":
# Regression coverage for #8624: DoRA + partial-loading + CPU->device autocast.
# Scaled down so the patched weight stays well-conditioned for allclose comparisons.
# dora_scale has shape (1, in_features) to broadcast against direction_norm in
# DoRALayer.get_weight — see dora_layer.py:74-82.
dora_layer = DoRALayer(
up=torch.randn(out_features, rank) * 0.01,
down=torch.randn(rank, in_features) * 0.01,
dora_scale=torch.ones(1, in_features),
alpha=1.0,
bias=torch.randn(out_features) * 0.01,
)
input = torch.randn(1, in_features)
return ([(dora_layer, 0.7)], input)
else:
raise ValueError(f"Unsupported layer_type: {layer_type}")
@@ -676,3 +692,45 @@ def test_conv2d_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):
assert output.dtype == input.dtype
assert output.shape == (2, 16, 3, 3)
@torch.no_grad()
def test_aggregate_patch_parameters_preserves_plain_tensor_with_dora():
"""Regression test for #8624: when partial-loading autocasts a CPU Parameter onto the
compute device, cast_to_device returns a plain torch.Tensor (not a Parameter). The
aggregator must treat that as a real tensor and not substitute a meta-device dummy —
otherwise DoRA's quantization guard falsely triggers on non-quantized base models.
This test is CPU-only and simulates the hand-off by constructing a plain torch.Tensor
directly; the equivalent CUDA/MPS E2E flow is exercised by the "single_dora" variant
of test_linear_sidecar_patches_with_autocast_from_cpu_to_device.
"""
layer = wrap_single_custom_layer(torch.nn.Linear(32, 64))
rank = 4
dora_patch = DoRALayer(
up=torch.randn(64, rank) * 0.01,
down=torch.randn(rank, 32) * 0.01,
dora_scale=torch.ones(1, 32),
alpha=1.0,
bias=None,
)
# Plain torch.Tensor — the shape _cast_weight_bias_for_input hands into
# _aggregate_patch_parameters after autocasting a Parameter across devices.
plain_weight = torch.randn(64, 32)
assert type(plain_weight) is torch.Tensor
orig_params = {"weight": plain_weight}
params = layer._aggregate_patch_parameters(
patches_and_weights=[(dora_patch, 1.0)],
orig_params=orig_params,
device=torch.device("cpu"),
)
# Pre-fix, orig_params["weight"] would have been replaced by a meta-device dummy,
# causing DoRALayer.get_parameters to raise "not compatible with DoRA patches".
assert orig_params["weight"].device.type == "cpu"
assert params["weight"].shape == (64, 32)
assert params["weight"].device.type == "cpu"
assert not torch.isnan(params["weight"]).any()