diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f1bb504cbf..de17f66abc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -462,8 +462,8 @@ jobs: - name: Run CLOUD=1 Test run: | CLOUDDEV=CPU CLOUD=1 python3 -m pytest test/test_tiny.py test/test_jit.py - CLOUDDEV=GPU CLOUD=1 python3 -m pytest test/test_tiny.py test/test_jit.py - CLOUDDEV=GPU IMAGE=2 CLOUD=1 python3 test/test_tiny.py + CLOUDDEV=GPU CLOUD=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py test/test_jit.py + CLOUDDEV=GPU IMAGE=2 CLOUD=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py - name: Test Optimization Helpers run: PYTHONPATH="." DEBUG=1 python3 extra/optimization/test_helpers.py - name: Test Action Space diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index 568a17fcf8..574f3d67ff 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -6,7 +6,10 @@ from tinygrad.dtype import ImageDType from tinygrad.engine.realize import lower_schedule from tinygrad.helpers import prod, unwrap -@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU") +IMAGE_SUPPORTED_DEVICES = ("QCOM", "GPU") +REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "CLOUD" else Device['CLOUD'].properties['clouddev']) + +@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported") class TestImageCopy(unittest.TestCase): def test_image_copyout_1x1(self, img_type=dtypes.imagef): it = Tensor.arange(4).cast(img_type((1,1,4))).realize() @@ -40,7 +43,7 @@ class TestImageCopy(unittest.TestCase): assert (it == it2).sum().item() == prod(sz) -@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU") +@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported") class TestImageDType(unittest.TestCase): def test_image_and_back(self): data = Tensor.randn(9*27*4).realize() @@ -138,7 +141,7 @@ class TestImageDType(unittest.TestCase): self.assertEqual(w1.grad.lazydata.base.buffer.dtype, dtypes.float32) self.assertEqual(len(sched), 10) -@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU") +@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported") class TestImageRealization(unittest.TestCase): def test_image_dtype_expand(self): data = Tensor.randn(9*27*4).realize()