search: fix actions space checks to ignore TC axis and amt (#4360)

* search: fix action-space checks to skip the axis and amt filters for TC opts

* add test for number of actions in get_linearizer_actions
This commit is contained in:
Francis Lam
2024-04-30 11:02:22 -07:00
committed by GitHub
parent fdc8fabae5
commit c12bcabb07
2 changed files with 9 additions and 2 deletions

View File

@@ -274,9 +274,16 @@ class TestLinearizer(unittest.TestCase):
prg.exec(real_bufs)
result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
# ensure the results for each choice of axis matches
if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
np.testing.assert_allclose(result, golden_result, atol=0.1, rtol=0.15)
# check that get_linearizer_actions produces all 9 TC options
from tinygrad.features.search import get_linearizer_actions
tc_actions = [k for i, k in get_linearizer_actions(Linearizer(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
assert len(tc_actions) == 9, f"get_linearizer_actions should contain 9 possible TC actions, only got {len(tc_actions)}"
def test_limit_dims_to_max_5d_global(self):
t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]

View File

@@ -85,8 +85,8 @@ def bufs_from_lin(lin:Linearizer, allocate:bool=True) -> List[Buffer]:
def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 256)
for i,a in enumerate(actions):
if a.axis is not None and a.axis >= lin.shape_len: continue
if a.axis is not None and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
if a.axis is not None and a.op is not OptOps.TC and a.axis >= lin.shape_len: continue
if a.axis is not None and a.op is not OptOps.TC and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
lin2 = lin.copy()
try:
lin2.apply_opt(a)