late fusion tests, early merge view GroupOp.Buffer [pr] (#7577)

* test_late_fusion_double_transpose

* early merge view buffer ops
This commit is contained in:
qazal
2024-11-07 14:04:57 +02:00
committed by GitHub
parent f0fc34e594
commit 1f5ea1e412
2 changed files with 10 additions and 2 deletions

View File

@@ -1360,6 +1360,12 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 4))
np.testing.assert_allclose(out.numpy(), x.numpy()@y.numpy(), atol=1e-4, rtol=1e-4)
def test_late_fusion_double_transpose(self):
with Context(DEBUG=0): a = Tensor.randn(32, 16, 1).realize()
compare = (a.expand(32, 16, 16).sum((2,), keepdim=True).T+2).T.contiguous()
run_schedule(check_schedule(compare, 1))
np.testing.assert_allclose(compare.numpy(), (np.broadcast_to(a.numpy(), (32, 16, 16)).sum(axis=2, keepdims=True).T+2).T, atol=1e-4, rtol=1e-4)
class TestIndexing(unittest.TestCase):
def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):

View File

@@ -148,9 +148,11 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),),
# push VIEW to loads
view_left = merge_views+PatternMatcher([
# view before ALU
(UPat(Ops.VIEW, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Buffer), name="e"),), name="v"),
# VIEW before elementwise ops
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"),
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
# early merge VIEW buffer ops
(UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.arg+v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
])
# push VIEW to stores