From c12bcabb071c774074b58342457afd799ab59059 Mon Sep 17 00:00:00 2001 From: Francis Lam Date: Tue, 30 Apr 2024 11:02:22 -0700 Subject: [PATCH] search: fix actions space checks to ignore TC axis and amt (#4360) * search: fix actions space checks to ignore TC axis and amt * add test for number of actions in get_linearizer_actions --- test/test_linearizer.py | 7 +++++++ tinygrad/features/search.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 43909eff00..2ee4c54f68 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -274,9 +274,16 @@ class TestLinearizer(unittest.TestCase): prg.exec(real_bufs) result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np) + # ensure the results for each choice of axis matches if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np) np.testing.assert_allclose(result, golden_result, atol=0.1, rtol=0.15) + # check that get_linearizer_actions produces all 9 options + from tinygrad.features.search import get_linearizer_actions + tc_actions = [k for i, k in get_linearizer_actions(Linearizer(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC] + assert len(tc_actions) == 9, f"get_linearizer_actions should contain 9 possible TC actions, only got {len(tc_actions)}" + + def test_limit_dims_to_max_5d_global(self): t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1 sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps] diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index daf75b5f59..e74e3e9877 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -85,8 +85,8 @@ def bufs_from_lin(lin:Linearizer, allocate:bool=True) -> List[Buffer]: def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]: acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 256) for i,a in enumerate(actions): - if a.axis is not None and a.axis >= lin.shape_len: continue - if a.axis is not None and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue + if a.axis is not None and a.op is not OptOps.TC and a.axis >= lin.shape_len: continue + if a.axis is not None and a.op is not OptOps.TC and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue lin2 = lin.copy() try: lin2.apply_opt(a)