diff --git a/test/test_schedule.py b/test/test_schedule.py index 7238bea5d1..4f779e0f86 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1500,6 +1500,7 @@ class TestSchedule(unittest.TestCase): y = x.pad((-1,2,2,-1), mode="replicate") dx = y.sum().gradient(x)[0] sched = check_schedule(dx, 1) + self.assertEqual(sched[0].ast.op_in_backward_slice_with_self(Ops.REDUCE), False) run_schedule(sched) np.testing.assert_allclose(dx.numpy(), [[[[0.,3.,9.],[0,1.,3.],[0.,0.,0.]]]*3]*3) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index f8905c6166..a8c7bcd8e3 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -222,9 +222,9 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([ and (resolve(prod(x.dtype.shape)!=prod(x.shape)) or x.shape[-1]%4!=0) else None), # remove noop buffers. if we look at the next index we can remove even more of these (UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"), remove_noop_bufferize), - # dont bufferize an arange - (UPat.any((r:=UPat(dtype=dtypes.index).cast()).named("src"), r.eq(UPat()).named("src")).f(Ops.BUFFERIZE, - allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize), + # dont bufferize arange like expressions + (UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), lambda src,buf,idx: + remove_bufferize(src, buf, idx) if not src.op_in_backward_slice_with_self(Ops.INDEX, Ops.REDUCE) else None), # no buffers for const (UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg).rtag(b.tag)), # indexing a const is a const