mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
image buffer realization spec [pr] (#8420)
* image buffer realization spec [pr] * redo the spec * work
This commit is contained in:
@@ -134,5 +134,39 @@ class TestImageDType(unittest.TestCase):
|
||||
print(lst)
|
||||
assert not np.any(np.isnan(lst))
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU")
|
||||
class TestImageRealization(unittest.TestCase):
|
||||
def test_image_dtype_expand(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
|
||||
self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
|
||||
it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous().realize()
|
||||
self.assertEqual(it_expanded.dtype, dtypes.float32)
|
||||
|
||||
def test_image_dtype_expand_and_back(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
|
||||
self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
|
||||
it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4))
|
||||
it2 = it_expanded.sum(3).realize()
|
||||
self.assertEqual(it2.dtype, dtypes.imagef((9,27,4)))
|
||||
|
||||
def test_image_alu_children(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
|
||||
self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
|
||||
it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous()
|
||||
alu1 = it_expanded+1
|
||||
alu2 = it_expanded.sum(3)
|
||||
it_expanded.realize()
|
||||
# NOTE: the parent becomes float, but the alu child will stay image until its output cannot fit the image
|
||||
self.assertEqual(alu1.dtype, dtypes.imagef((9,27,4)))
|
||||
alu1.realize()
|
||||
self.assertEqual(alu1.dtype, dtypes.float32)
|
||||
# alu2 is back in image because it fits the dtype again
|
||||
self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4)))
|
||||
alu2.realize()
|
||||
self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user