mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
TestImageCopy valid images
This commit is contained in:
@@ -11,25 +11,25 @@ IMAGE_SUPPORTED_DEVICES = ("QCOM", "CL")
|
||||
|
||||
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
|
||||
class TestImageCopy(unittest.TestCase):
|
||||
def test_image_copyout_1x1(self, img_type=dtypes.imagef):
|
||||
it = Tensor.arange(4).cast(img_type((1,1,4))).realize()
|
||||
def test_image_copyout_1x8(self, img_type=dtypes.imagef):
|
||||
it = Tensor.arange(32).cast(img_type((1,8,4))).realize()
|
||||
buf = it.uop.buffer
|
||||
out = buf.as_buffer()
|
||||
np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(4))
|
||||
np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(32))
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half, device="PYTHON"), "need half")
|
||||
def test_imageh_copyout_1x1(self): self.test_image_copyout_1x1(img_type=dtypes.imageh)
|
||||
def test_imageh_copyout_1x8(self): self.test_image_copyout_1x8(img_type=dtypes.imageh)
|
||||
|
||||
def test_image_numpy_1x1(self, img_type=dtypes.imagef):
|
||||
it = Tensor.arange(4).cast(img_type((1,1,4))).realize()
|
||||
np.testing.assert_equal(it.numpy(), np.arange(4))
|
||||
def test_imageh_numpy_1x1(self): self.test_image_numpy_1x1(img_type=dtypes.imageh)
|
||||
def test_image_numpy_1x8(self, img_type=dtypes.imagef):
|
||||
it = Tensor.arange(32).cast(img_type((1,8,4))).realize()
|
||||
np.testing.assert_equal(it.numpy(), np.arange(32))
|
||||
def test_imageh_numpy_1x8(self): self.test_image_numpy_1x8(img_type=dtypes.imageh)
|
||||
|
||||
def test_image_copyout_2x3(self):
|
||||
it = Tensor.arange(2*3*4).cast(dtypes.imagef((2,3,4))).realize()
|
||||
def test_image_copyout_2x4(self):
|
||||
it = Tensor.arange(2*4*4).cast(dtypes.imagef((2,4,4))).realize()
|
||||
buf = it.uop.buffer
|
||||
out = buf.as_buffer()
|
||||
np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*3*4))
|
||||
np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*4*4))
|
||||
|
||||
def test_image_roundtrip(self):
|
||||
sz = (4,2,4)
|
||||
|
||||
Reference in New Issue
Block a user