mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 07:35:16 -05:00
spec for CAST_BEFORE_VIEW=1 [pr] (#8512)
This commit is contained in:
@@ -1437,6 +1437,42 @@ class TestSchedule(unittest.TestCase):
|
||||
def test_late_fusion_post_expand(self):
|
||||
self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2)
|
||||
|
||||
def test_cast_padded_view(self):
|
||||
a = Tensor.arange(4).reshape(1, 4)
|
||||
casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
|
||||
casted_view.realize()
|
||||
self.assertEqual(casted_view.lazydata.base.realized.size, 4)
|
||||
realized_view = casted_view.contiguous().realize()
|
||||
self.assertEqual(realized_view.lazydata.base.realized.size, 8)
|
||||
self.assertListEqual(realized_view.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]])
|
||||
|
||||
# NOTE: we might want to reconsider pushing this cast before the shrink
|
||||
def test_cast_after_shrink(self):
|
||||
a = Tensor.arange(4).reshape(1, 4)
|
||||
casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float)
|
||||
casted_view.realize()
|
||||
self.assertEqual(casted_view.lazydata.base.realized.size, 4)
|
||||
realized_view = casted_view.contiguous().realize()
|
||||
self.assertEqual(realized_view.lazydata.base.realized.size, 2)
|
||||
self.assertListEqual(realized_view.tolist(), [[0, 1]])
|
||||
|
||||
def test_cast_const_view(self):
|
||||
a = Tensor.ones((4, 4), dtype=dtypes.float32)
|
||||
casted_view = a.cast(dtypes.int32)
|
||||
run_schedule(check_schedule(casted_view, 0))
|
||||
self.assertIsNone(casted_view.lazydata.base.realized)
|
||||
realized_const_view = casted_view.contiguous()
|
||||
run_schedule(check_schedule(realized_const_view, 1))
|
||||
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
|
||||
|
||||
def test_cast_padded_const(self):
|
||||
a = Tensor(1, dtype=dtypes.int32).reshape(1, 1).pad(((1, 1), None))
|
||||
casted_view = a.cast(dtypes.float32)
|
||||
run_schedule(check_schedule(casted_view, 0))
|
||||
realized_const_view = casted_view.contiguous()
|
||||
run_schedule(check_schedule(realized_const_view, 1))
|
||||
self.assertListEqual(realized_const_view.tolist(), [[0], [1], [0]])
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
|
||||
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
|
||||
|
||||
Reference in New Issue
Block a user