mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
Enable image tests on cloud if clouddev supports image (#9903)
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -462,8 +462,8 @@ jobs:
|
||||
- name: Run CLOUD=1 Test
|
||||
run: |
|
||||
CLOUDDEV=CPU CLOUD=1 python3 -m pytest test/test_tiny.py test/test_jit.py
|
||||
CLOUDDEV=GPU CLOUD=1 python3 -m pytest test/test_tiny.py test/test_jit.py
|
||||
CLOUDDEV=GPU IMAGE=2 CLOUD=1 python3 test/test_tiny.py
|
||||
CLOUDDEV=GPU CLOUD=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py test/test_jit.py
|
||||
CLOUDDEV=GPU IMAGE=2 CLOUD=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py
|
||||
- name: Test Optimization Helpers
|
||||
run: PYTHONPATH="." DEBUG=1 python3 extra/optimization/test_helpers.py
|
||||
- name: Test Action Space
|
||||
|
||||
@@ -6,7 +6,10 @@ from tinygrad.dtype import ImageDType
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.helpers import prod, unwrap
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU")
|
||||
IMAGE_SUPPORTED_DEVICES = ("QCOM", "GPU")
|
||||
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "CLOUD" else Device['CLOUD'].properties['clouddev'])
|
||||
|
||||
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
|
||||
class TestImageCopy(unittest.TestCase):
|
||||
def test_image_copyout_1x1(self, img_type=dtypes.imagef):
|
||||
it = Tensor.arange(4).cast(img_type((1,1,4))).realize()
|
||||
@@ -40,7 +43,7 @@ class TestImageCopy(unittest.TestCase):
|
||||
|
||||
assert (it == it2).sum().item() == prod(sz)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU")
|
||||
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
|
||||
class TestImageDType(unittest.TestCase):
|
||||
def test_image_and_back(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
@@ -138,7 +141,7 @@ class TestImageDType(unittest.TestCase):
|
||||
self.assertEqual(w1.grad.lazydata.base.buffer.dtype, dtypes.float32)
|
||||
self.assertEqual(len(sched), 10)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU")
|
||||
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
|
||||
class TestImageRealization(unittest.TestCase):
|
||||
def test_image_dtype_expand(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
|
||||
Reference in New Issue
Block a user