mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
late fusion tests, early merge view GroupOp.Buffer [pr] (#7577)
* test_late_fusion_double_transpose * early merge view buffer ops
This commit is contained in:
@@ -1360,6 +1360,12 @@ class TestSchedule(unittest.TestCase):
|
||||
run_schedule(check_schedule(out, 4))
|
||||
np.testing.assert_allclose(out.numpy(), x.numpy()@y.numpy(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_late_fusion_double_transpose(self):
|
||||
with Context(DEBUG=0): a = Tensor.randn(32, 16, 1).realize()
|
||||
compare = (a.expand(32, 16, 16).sum((2,), keepdim=True).T+2).T.contiguous()
|
||||
run_schedule(check_schedule(compare, 1))
|
||||
np.testing.assert_allclose(compare.numpy(), (np.broadcast_to(a.numpy(), (32, 16, 16)).sum(axis=2, keepdims=True).T+2).T, atol=1e-4, rtol=1e-4)
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
|
||||
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
|
||||
|
||||
@@ -148,9 +148,11 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),),
|
||||
|
||||
# push VIEW to loads
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# view before ALU
|
||||
(UPat(Ops.VIEW, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Buffer), name="e"),), name="v"),
|
||||
# VIEW before elementwise ops
|
||||
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"),
|
||||
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
|
||||
# early merge VIEW buffer ops
|
||||
(UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.arg+v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
|
||||
])
|
||||
|
||||
# push VIEW to stores
|
||||
|
||||
Reference in New Issue
Block a user