mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 07:35:16 -05:00
remove cast before view (#8613)
* remove cast before view
* greener
* indexing
* that passes too
* openpilot too
* ack

---------

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
@@ -1436,6 +1436,7 @@ class TestSchedule(unittest.TestCase):
|
||||
def test_late_fusion_post_expand(self):
|
||||
self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2)
|
||||
|
||||
+@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
|
||||
def test_cast_padded_view(self):
|
||||
a = Tensor.arange(4).reshape(1, 4)
|
||||
casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
|
||||
@@ -1446,6 +1447,7 @@ class TestSchedule(unittest.TestCase):
|
||||
self.assertListEqual(realized_view.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]])
|
||||
|
||||
# NOTE: we might want to reconsider pushing this cast before the shrink
|
||||
+@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
|
||||
def test_cast_after_shrink(self):
|
||||
a = Tensor.arange(4).reshape(1, 4)
|
||||
casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float)
|
||||
@@ -1455,6 +1457,7 @@ class TestSchedule(unittest.TestCase):
|
||||
self.assertEqual(realized_view.lazydata.base.realized.size, 2)
|
||||
self.assertListEqual(realized_view.tolist(), [[0, 1]])
|
||||
|
||||
+@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
|
||||
def test_cast_const_view(self):
|
||||
a = Tensor.ones((4, 4), dtype=dtypes.float32)
|
||||
casted_view = a.cast(dtypes.int32)
|
||||
@@ -1464,6 +1467,7 @@ class TestSchedule(unittest.TestCase):
|
||||
run_schedule(check_schedule(realized_const_view, 1))
|
||||
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
|
||||
|
||||
+@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
|
||||
def test_cast_padded_const(self):
|
||||
a = Tensor(1, dtype=dtypes.int32).reshape(1, 1).pad(((1, 1), None))
|
||||
casted_view = a.cast(dtypes.float32)
|
||||
@@ -1566,7 +1570,7 @@ class TestIndexing(unittest.TestCase):
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
a = Tensor.arange(10)
|
||||
out = (x + a[2]).sum()
|
||||
-self.check_schedule(out, 1)
+self.check_schedule(out, 2)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_arange_index_contiguous(self):
|
||||
@@ -1574,7 +1578,7 @@ class TestIndexing(unittest.TestCase):
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
a = Tensor.arange(10).contiguous()
|
||||
out = (x + a[2]).sum()
|
||||
-self.check_schedule(out, 2)
+self.check_schedule(out, 3)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_arange_index_child(self):
|
||||
@@ -1582,7 +1586,7 @@ class TestIndexing(unittest.TestCase):
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
a = Tensor.arange(10)+1
|
||||
out = (x + a[2]).sum()
|
||||
-self.check_schedule(out, 1)
+self.check_schedule(out, 2)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_arange_index_contiguous_child(self):
|
||||
@@ -1590,7 +1594,7 @@ class TestIndexing(unittest.TestCase):
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
a = (Tensor.arange(10)+1).contiguous()
|
||||
out = (x + a[2]).sum()
|
||||
-self.check_schedule(out, 2)
+self.check_schedule(out, 3)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_arange_childless_base(self):
|
||||
|
||||
Reference in New Issue
Block a user