Compare commits

...

1 Commits

Author SHA1 Message Date
Ryan Dick
f607ff4461 Add mark_flaky_mps_github_action test decorator. 2025-01-17 15:58:09 -05:00
2 changed files with 18 additions and 7 deletions

View File

@@ -1,5 +1,3 @@
import os
import gguf
import pytest
import torch
@@ -9,6 +7,7 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
remove_custom_layers_from_model,
)
from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor
from tests.mark_flaky_mps_github_action import mark_flaky_mps_github_action_test
try:
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt, quantize_model_llm_int8
@@ -53,14 +52,10 @@ def model(request: pytest.FixtureRequest) -> torch.nn.Module:
raise ValueError(f"Invalid quantization type: {request.param}")
@mark_flaky_mps_github_action_test
@cuda_and_mps
@torch.no_grad()
def test_torch_module_autocast_linear_layer(device: torch.device, model: torch.nn.Module):
# Skip this test with MPS on GitHub Actions. It fails but I haven't taken the tie to figure out why. It passes
# locally on MacOS.
if os.environ.get("GITHUB_ACTIONS") == "true" and device.type == "mps":
pytest.skip("This test is flaky on GitHub Actions")
# Model parameters should start off on the CPU.
assert all(p.device.type == "cpu" for p in model.parameters())

View File

@@ -0,0 +1,16 @@
import os
import pytest
import torch
IS_GITHUB_ACTION = os.environ.get("GITHUB_ACTION") == "true"
HAS_MPS_DEVICE = torch.backends.mps.is_available()
# Some tests that use MPS are flaky on Github Actions.
# Specifically, they fail with `MPS backend out of memory` even though there is plenty of memory available.
# I haven't taken the time to get to the bottom of this yet. The tests pass locally on MPS.
# There are several reports of similar issues
# (e.g. https://discuss.pytorch.org/t/mps-back-end-out-of-memory-on-github-action/189773).
mark_flaky_mps_github_action_test = pytest.mark.xfail(
condition=IS_GITHUB_ACTION and HAS_MPS_DEVICE, reason="This test is flaky on GitHub Actions with MPS.", strict=False
)