# Patch summary (free-form commentary placed before the first "diff --git" header).
#
# test/test_schedule.py
#   Adds TestSchedule.test_fuse_arange_pad_replicate_mode. The test creates an
#   empty (3,3,3,3) tensor with requires_grad=True, pads it with
#   pad((-1,2,2,-1), mode="replicate"), takes the gradient of the padded sum
#   with respect to the input, schedules that gradient inside
#   Context(FUSE_ARANGE=1) via check_schedule(dx, 3), runs the schedule, and
#   compares dx.numpy() against a fixed expected array with assert_allclose.
#   NOTE(review): the second argument of check_schedule is presumably the
#   expected kernel count -- confirm against its definition.
#
# tinygrad/engine/grouper.py
#   Changes the "FUSE elementwise" rule of pm_fuse. Previously every source of
#   the matched ALU/CAST op was re-viewed with view.arg alone. Now, when a
#   source is itself an Ops.VIEW, x.view() receives x.arg+view.arg; all other
#   sources still receive view.arg.
#   NOTE(review): presumably `+` combines the source's own view arg with the
#   outer one so the inner view is not discarded -- confirm against the arg type.
diff --git a/test/test_schedule.py b/test/test_schedule.py index 027c2c8999..a6205952df 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1453,6 +1453,15 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(d, 2)) np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:, None, :], ((0, 0), (1, 1), (0, 0)))*2) + def test_fuse_arange_pad_replicate_mode(self): + x = Tensor.empty(3,3,3,3, requires_grad=True) + y = x.pad((-1,2,2,-1), mode="replicate") + dx = y.sum().gradient(x)[0] + with Context(FUSE_ARANGE=1): + sched = check_schedule(dx, 3) + run_schedule(sched) + np.testing.assert_allclose(dx.numpy(), [[[[0.,3.,9.],[0,1.,3.],[0.,0.,0.]]]*3]*3) + # TODO like openpilot with imagef @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_base_change_expand_expand(self): diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index 3abb762059..09f3a90de8 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -422,7 +422,7 @@ pm_fuse = PatternMatcher([ # FUSE elementwise. (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST}, name="alu"),), name="view").fuse(), - lambda alu, view: alu.replace(src=tuple(x.view(view.arg).fuse() for x in alu.src))), + lambda alu, view: alu.replace(src=tuple(x.view(x.arg+view.arg if x.op is Ops.VIEW else view.arg).fuse() for x in alu.src))), # push FUSE through to srcs (UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),