grouper: merge views in fuse elementwise (#10325)

* grouper: merge views in fuse elementwise

* with gradient api
This commit is contained in:
qazal
2025-05-15 13:17:09 +03:00
committed by GitHub
parent 89d8d5b25e
commit 0a45cd0cbe
2 changed files with 10 additions and 1 deletions

View File

@@ -1453,6 +1453,15 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(d, 2))
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:, None, :], ((0, 0), (1, 1), (0, 0)))*2)
# Regression test: gradient of a replicate-mode pad, scheduled with FUSE_ARANGE=1.
# NOTE(review): indentation was lost in this view, so which of the statements
# below are inside the `with Context(...)` body cannot be determined from here.
def test_fuse_arange_pad_replicate_mode(self):
# 4-D input of shape (3,3,3,3) that participates in autograd.
x = Tensor.empty(3,3,3,3, requires_grad=True)
# Replicate-mode pad with mixed-sign amounts; presumably the negative entries
# crop instead of pad -- TODO confirm against Tensor.pad.
y = x.pad((-1,2,2,-1), mode="replicate")
# Gradient of the scalar sum of y with respect to x (first entry of the returned sequence).
dx = y.sum().gradient(x)[0]
with Context(FUSE_ARANGE=1):
# presumably asserts the schedule for dx has exactly 3 items -- verify against check_schedule
sched = check_schedule(dx, 3)
run_schedule(sched)
# Expected literal has shape (3,3,3,3): one 3x3 pattern repeated over both leading axes.
np.testing.assert_allclose(dx.numpy(), [[[[0.,3.,9.],[0,1.,3.],[0.,0.,0.]]]*3]*3)
# TODO like openpilot with imagef
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_base_change_expand_expand(self):

View File

@@ -422,7 +422,7 @@ pm_fuse = PatternMatcher([
# FUSE elementwise.
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST}, name="alu"),), name="view").fuse(),
lambda alu, view: alu.replace(src=tuple(x.view(view.arg).fuse() for x in alu.src))),
lambda alu, view: alu.replace(src=tuple(x.view(x.arg+view.arg if x.op is Ops.VIEW else view.arg).fuse() for x in alu.src))),
# push FUSE through to srcs
(UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),