hotfix bug in get_kernel_actions after TC_SEARCH_OVER_SHAPE was introduced (#8904)

* hotfix search bug

* copy actions
This commit is contained in:
Ignacio Sica
2025-02-05 15:10:05 -03:00
committed by GitHub
parent 15f94ac964
commit 0f6109ec00

View File

@@ -102,7 +102,8 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
# get dictionary of all possible actions
def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
acted_lins, max_up, max_lcl, kernel_actions = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024), actions
acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
kernel_actions = actions.copy()
if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
for i, action in enumerate(kernel_actions):
@@ -112,7 +113,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
for i,a in enumerate(kernel_actions):
if a.axis is not None and a.op is not OptOps.TC:
if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in actions): continue
if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
lin2 = lin.copy()
try:
lin2.apply_opt(a)