some noop hand_coded_optimizations cleanup [pr] (#11188)

This commit is contained in:
chenyu
2025-07-12 00:09:23 -04:00
committed by GitHub
parent 1ad852a892
commit fdcc25e392

View File

@@ -32,11 +32,10 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
k.first_reduce <= 2 and k.first_reduce < k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
# TODO: use 1024 if it's allowed in a smarter way
for sz in ([256, 16] if prod(k.sts[0].shape[:k.first_reduce]) <= 32 else [16]):
if all(st.shape[k.first_reduce] % sz == 0 or st.shape[k.first_reduce] == 1 for st in k.sts):
try: # may fail due to excessive smem usage
k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
break
except KernelOptError: pass
try: # may fail due to excessive smem usage
k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
break
except KernelOptError: pass
# upcast float4 images
for buf_index,buf in enumerate(k.bufs):
@@ -91,8 +90,8 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
else: break
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
if k.first_reduce < k.first_upcast and (prod(k.full_shape[k.first_upcast:]) <= 4 or \
not any(x!=y for x,y in zip(k.sts[0].shape[k.first_upcast:], k.full_shape[k.first_upcast:]))) and \
if k.first_reduce < k.first_upcast and \
(prod(k.full_shape[k.first_upcast:]) <= 4 or (k.sts[0].shape[k.first_upcast:] == k.full_shape[k.first_upcast:])) and \
(k.upcasted == 0 or prod(k.full_shape[-k.upcasted:]) < 64):
if isinstance(s:=k.full_unupcasted_shape[-1], int) and s <= 32: # NOTE: cannot loop unroll symbolic axis
k.apply_opt(Opt(OptOps.UNROLL, len(k.full_unupcasted_shape)-1-k.first_reduce, 0))
@@ -113,7 +112,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
# **** local groups ****
if k.opts.has_local:
if NOLOCALS and k.local_dims == 0 and not k.group_for_reduces:
if NOLOCALS:
k.apply_opt(Opt(OptOps.NOLOCALS))
else:
# prioritize making expand axes local