remove cast before view (#8613)

* remove cast before view

* greener

* indexing

* that passes too

* openpilot too

* ack

---------

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
George Hotz
2025-01-14 12:04:58 -08:00
committed by GitHub
parent 393eec3201
commit bfbe81df71
9 changed files with 22 additions and 25 deletions

View File

@@ -1436,6 +1436,7 @@ class TestSchedule(unittest.TestCase):
def test_late_fusion_post_expand(self):
self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2)
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
def test_cast_padded_view(self):
a = Tensor.arange(4).reshape(1, 4)
casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
@@ -1446,6 +1447,7 @@ class TestSchedule(unittest.TestCase):
self.assertListEqual(realized_view.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]])
# NOTE: we might want to reconsider pushing this cast before the shrink
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
def test_cast_after_shrink(self):
a = Tensor.arange(4).reshape(1, 4)
casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float)
@@ -1455,6 +1457,7 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(realized_view.lazydata.base.realized.size, 2)
self.assertListEqual(realized_view.tolist(), [[0, 1]])
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
def test_cast_const_view(self):
a = Tensor.ones((4, 4), dtype=dtypes.float32)
casted_view = a.cast(dtypes.int32)
@@ -1464,6 +1467,7 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(realized_const_view, 1))
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
def test_cast_padded_const(self):
a = Tensor(1, dtype=dtypes.int32).reshape(1, 1).pad(((1, 1), None))
casted_view = a.cast(dtypes.float32)
@@ -1566,7 +1570,7 @@ class TestIndexing(unittest.TestCase):
x = Tensor.randn(5, 2).realize()
a = Tensor.arange(10)
out = (x + a[2]).sum()
self.check_schedule(out, 1)
self.check_schedule(out, 2)
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
def test_arange_index_contiguous(self):
@@ -1574,7 +1578,7 @@ class TestIndexing(unittest.TestCase):
x = Tensor.randn(5, 2).realize()
a = Tensor.arange(10).contiguous()
out = (x + a[2]).sum()
self.check_schedule(out, 2)
self.check_schedule(out, 3)
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
def test_arange_index_child(self):
@@ -1582,7 +1586,7 @@ class TestIndexing(unittest.TestCase):
x = Tensor.randn(5, 2).realize()
a = Tensor.arange(10)+1
out = (x + a[2]).sum()
self.check_schedule(out, 1)
self.check_schedule(out, 2)
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
def test_arange_index_contiguous_child(self):
@@ -1590,7 +1594,7 @@ class TestIndexing(unittest.TestCase):
x = Tensor.randn(5, 2).realize()
a = (Tensor.arange(10)+1).contiguous()
out = (x + a[2]).sum()
self.check_schedule(out, 2)
self.check_schedule(out, 3)
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
def test_arange_childless_base(self):