mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
grouper: merge views in fuse elementwise (#10325)
* grouper: merge views in fuse elementwise * with gradient api
This commit is contained in:
@@ -1453,6 +1453,15 @@ class TestSchedule(unittest.TestCase):
|
||||
run_schedule(check_schedule(d, 2))
|
||||
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:, None, :], ((0, 0), (1, 1), (0, 0)))*2)
|
||||
|
||||
def test_fuse_arange_pad_replicate_mode(self):
|
||||
x = Tensor.empty(3,3,3,3, requires_grad=True)
|
||||
y = x.pad((-1,2,2,-1), mode="replicate")
|
||||
dx = y.sum().gradient(x)[0]
|
||||
with Context(FUSE_ARANGE=1):
|
||||
sched = check_schedule(dx, 3)
|
||||
run_schedule(sched)
|
||||
np.testing.assert_allclose(dx.numpy(), [[[[0.,3.,9.],[0,1.,3.],[0.,0.,0.]]]*3]*3)
|
||||
|
||||
# TODO like openpilot with imagef
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_base_change_expand_expand(self):
|
||||
|
||||
@@ -422,7 +422,7 @@ pm_fuse = PatternMatcher([
|
||||
|
||||
# FUSE elementwise.
|
||||
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST}, name="alu"),), name="view").fuse(),
|
||||
lambda alu, view: alu.replace(src=tuple(x.view(view.arg).fuse() for x in alu.src))),
|
||||
lambda alu, view: alu.replace(src=tuple(x.view(x.arg+view.arg if x.op is Ops.VIEW else view.arg).fuse() for x in alu.src))),
|
||||
|
||||
# push FUSE through to srcs
|
||||
(UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),
|
||||
|
||||
Reference in New Issue
Block a user