mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
search: fix actions space checks to ignore TC axis and amt (#4360)
* search: fix actions space checks to ignore TC axis and amt * add test for number of actions in get_linearizer_actions
This commit is contained in:
@@ -274,9 +274,16 @@ class TestLinearizer(unittest.TestCase):
|
||||
prg.exec(real_bufs)
|
||||
result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
|
||||
|
||||
# ensure the results for each choice of axis matches
|
||||
if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
|
||||
np.testing.assert_allclose(result, golden_result, atol=0.1, rtol=0.15)
|
||||
|
||||
# check that get_linearizer_actions produces all 9 options
|
||||
from tinygrad.features.search import get_linearizer_actions
|
||||
tc_actions = [k for i, k in get_linearizer_actions(Linearizer(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
|
||||
assert len(tc_actions) == 9, f"get_linearizer_actions should contain 9 possible TC actions, only got {len(tc_actions)}"
|
||||
|
||||
|
||||
def test_limit_dims_to_max_5d_global(self):
|
||||
t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
|
||||
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
|
||||
|
||||
@@ -85,8 +85,8 @@ def bufs_from_lin(lin:Linearizer, allocate:bool=True) -> List[Buffer]:
|
||||
def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
|
||||
acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 256)
|
||||
for i,a in enumerate(actions):
|
||||
if a.axis is not None and a.axis >= lin.shape_len: continue
|
||||
if a.axis is not None and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
|
||||
if a.axis is not None and a.op is not OptOps.TC and a.axis >= lin.shape_len: continue
|
||||
if a.axis is not None and a.op is not OptOps.TC and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
|
||||
lin2 = lin.copy()
|
||||
try:
|
||||
lin2.apply_opt(a)
|
||||
|
||||
Reference in New Issue
Block a user