diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index 443d22f0e9..b22f63d459 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -102,7 +102,8 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]: # get dictionary of all possible actions def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]: - acted_lins, max_up, max_lcl, kernel_actions = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024), actions + acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024) + kernel_actions = actions.copy() if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first for i, action in enumerate(kernel_actions): @@ -112,7 +113,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]: for i,a in enumerate(kernel_actions): if a.axis is not None and a.op is not OptOps.TC: - if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in actions): continue + if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue lin2 = lin.copy() try: lin2.apply_opt(a)