mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Enable image tests on cloud if clouddev supports image (#9903)
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -6,7 +6,10 @@ from tinygrad.dtype import ImageDType
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.helpers import prod, unwrap
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU")
|
||||
IMAGE_SUPPORTED_DEVICES = ("QCOM", "GPU")
|
||||
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "CLOUD" else Device['CLOUD'].properties['clouddev'])
|
||||
|
||||
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
|
||||
class TestImageCopy(unittest.TestCase):
|
||||
def test_image_copyout_1x1(self, img_type=dtypes.imagef):
|
||||
it = Tensor.arange(4).cast(img_type((1,1,4))).realize()
|
||||
@@ -40,7 +43,7 @@ class TestImageCopy(unittest.TestCase):
|
||||
|
||||
assert (it == it2).sum().item() == prod(sz)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU")
|
||||
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
|
||||
class TestImageDType(unittest.TestCase):
|
||||
def test_image_and_back(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
@@ -138,7 +141,7 @@ class TestImageDType(unittest.TestCase):
|
||||
self.assertEqual(w1.grad.lazydata.base.buffer.dtype, dtypes.float32)
|
||||
self.assertEqual(len(sched), 10)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU")
|
||||
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
|
||||
class TestImageRealization(unittest.TestCase):
|
||||
def test_image_dtype_expand(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
|
||||
Reference in New Issue
Block a user