From 0f6109ec007518167d0be64f46d09fe4103d6cff Mon Sep 17 00:00:00 2001 From: Ignacio Sica Date: Wed, 5 Feb 2025 15:10:05 -0300 Subject: [PATCH] hotfix bug in `get_kernel_actions` after `TC_SEARCH_OVER_SHAPE` was introduced (#8904) * hotfix search bug * copy actions --- tinygrad/engine/search.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index 443d22f0e9..b22f63d459 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -102,7 +102,8 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]: # get dictionary of all possible actions def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]: - acted_lins, max_up, max_lcl, kernel_actions = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024), actions + acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024) + kernel_actions = actions.copy() if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first for i, action in enumerate(kernel_actions): @@ -112,7 +113,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]: for i,a in enumerate(kernel_actions): if a.axis is not None and a.op is not OptOps.TC: - if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in actions): continue + if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue lin2 = lin.copy() try: lin2.apply_opt(a)