search: fix counting of upcasts to ignore TC upcasts (#4045)

TC upcasts don't impact the size or complexity of the kernel code
This commit is contained in:
Francis Lam
2024-04-02 16:52:05 -07:00
committed by GitHub
parent ccf3c16d6a
commit 88dcdae485

View File

@@ -87,11 +87,11 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
lin2 = lin.copy()
try:
lin2.apply_opt(a)
up, lcl = 1, 1
up, lcl, tc_up = 1, 1, prod(tc.dims)//prod([x[1] for x in tc.threads]) if (tc:=lin2.tensor_core) else 1
for s,c in zip(lin2.full_shape, lin2.colors()):
if c in {"magenta", "yellow"}: up *= s
elif c in {"cyan", "green", "white"}: lcl *= s
if up > max_up or lcl > max_lcl: continue
if up//tc_up > max_up or lcl > max_lcl: continue
acted_lins[i+1] = lin2
except KernelOptError: pass
return acted_lins