From 4c5ef2cc4f4d67665e5b4f540001dc08922edc55 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 4 Aug 2024 21:09:21 +0800 Subject: [PATCH] setitem with arange fusion 1 (#5898) --- test/test_schedule.py | 16 ++++++++++++++-- tinygrad/engine/schedule.py | 2 +- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 5ac76daf16..aaf2241541 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1385,17 +1385,29 @@ class TestIndexing(unittest.TestCase): self.check_schedule(out, 2) np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum()) - def test_arange_childless(self): + def test_arange_childless_base(self): a = Tensor.arange(4) self.check_schedule(a, 1) np.testing.assert_equal(a.numpy(), np.arange(4)) - def test_arange_group_childless(self): + def test_arange_childless_view(self): + a = Tensor.arange(4).reshape(2, 2) + a[0] = 4 + np.testing.assert_equal(a.numpy(), [[4, 4], [2, 3]]) + + def test_arange_group_childless_base(self): Tensor.manual_seed(0) x = Tensor.randint(4) a = Tensor.arange(4)+x self.check_schedule(a, 1) np.testing.assert_equal(a.numpy(), np.arange(4)+x.numpy()) + def test_arange_group_childless_view(self): + Tensor.manual_seed(0) + x = Tensor.ones(4).contiguous().realize() + a = Tensor.arange(4)+x + a[0] = 6 + np.testing.assert_equal(a.numpy(), [6., 2., 3., 4.]) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 282ed2bd7c..44b1125cdc 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -301,7 +301,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]): for r in reduce_of_const: group = {tr:None for tr,rop in reduce_for_op.items() if rop is r} if DEBUG_ARANGE:=(getenv("DEBUG_ARANGE")): print(f"checking {r} {group=}") - if any(tr.forced_realize or tr in outs for tr in group): continue + if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue if DEBUG_ARANGE: print(colored(f"folding {r}", "green")) for tr in group: del realizes[tr]