setitem with arange fusion 1 (#5898)

This commit is contained in:
qazal
2024-08-04 21:09:21 +08:00
committed by GitHub
parent 59315ffc78
commit 4c5ef2cc4f
2 changed files with 15 additions and 3 deletions

View File

@@ -1385,17 +1385,29 @@ class TestIndexing(unittest.TestCase):
self.check_schedule(out, 2)
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum())
def test_arange_childless(self):
def test_arange_childless_base(self):
# A bare arange with nothing consuming it: checks the schedule and the realized values.
a = Tensor.arange(4)
# second argument is presumably the expected number of scheduled kernels — confirm against check_schedule
self.check_schedule(a, 1)
# values must match numpy's arange of the same length
np.testing.assert_equal(a.numpy(), np.arange(4))
def test_arange_group_childless(self):
def test_arange_childless_view(self):
# Assigns into a reshaped arange and checks the resulting values.
# Unlike the _base variant, this test makes no check_schedule assertion.
# arange(4) reshaped to (2, 2) starts out as [[0, 1], [2, 3]]
a = Tensor.arange(4).reshape(2, 2)
# setitem on the first row: both of its elements become 4
a[0] = 4
# the second row must keep the original arange values
np.testing.assert_equal(a.numpy(), [[4, 4], [2, 3]])
def test_arange_group_childless_base(self):
# An arange added to a second tensor: checks the schedule and compares the sum against numpy.
# seed is fixed before the randint draw below
Tensor.manual_seed(0)
x = Tensor.randint(4)
a = Tensor.arange(4)+x
# second argument is presumably the expected number of scheduled kernels — confirm against check_schedule
self.check_schedule(a, 1)
# reference result is built from numpy's arange plus the realized values of x
np.testing.assert_equal(a.numpy(), np.arange(4)+x.numpy())
def test_arange_group_childless_view(self):
# Assigns into the result of arange + a realized tensor and checks the final values.
# This test makes no check_schedule assertion.
# NOTE(review): no random op is visible in this test, so the seed call looks unnecessary — confirm
Tensor.manual_seed(0)
# x is made contiguous and realized before it is used in the add
x = Tensor.ones(4).contiguous().realize()
# arange(4) + ones(4) is [1, 2, 3, 4] before the assignment
a = Tensor.arange(4)+x
# setitem on element 0 only; the remaining elements must be untouched
a[0] = 6
np.testing.assert_equal(a.numpy(), [6., 2., 3., 4.])
if __name__ == '__main__':
# when executed directly as a script, run the unittest runner with verbosity level 2
unittest.main(verbosity=2)

View File

@@ -301,7 +301,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
for r in reduce_of_const:
group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
if DEBUG_ARANGE:=(getenv("DEBUG_ARANGE")): print(f"checking {r} {group=}")
if any(tr.forced_realize or tr in outs for tr in group): continue
if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue
if DEBUG_ARANGE: print(colored(f"folding {r}", "green"))
for tr in group: del realizes[tr]