From fdcc25e39244be38f24f18d73c0dc4e111d44f7d Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 12 Jul 2025 00:09:23 -0400 Subject: [PATCH] some noop hand_coded_optimizations cleanup [pr] (#11188) --- tinygrad/opt/heuristic.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tinygrad/opt/heuristic.py b/tinygrad/opt/heuristic.py index 5940c86c21..73a8611531 100644 --- a/tinygrad/opt/heuristic.py +++ b/tinygrad/opt/heuristic.py @@ -32,11 +32,10 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]: k.first_reduce <= 2 and k.first_reduce < k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048: # TODO: use 1024 if it's allowed in a smarter way for sz in ([256, 16] if prod(k.sts[0].shape[:k.first_reduce]) <= 32 else [16]): - if all(st.shape[k.first_reduce] % sz == 0 or st.shape[k.first_reduce] == 1 for st in k.sts): - try: # may fail due to excessive smem usage - k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz)) - break - except KernelOptError: pass + try: # may fail due to excessive smem usage + k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz)) + break + except KernelOptError: pass # upcast float4 images for buf_index,buf in enumerate(k.bufs): @@ -91,8 +90,8 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]: else: break # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. - if k.first_reduce < k.first_upcast and (prod(k.full_shape[k.first_upcast:]) <= 4 or \ - not any(x!=y for x,y in zip(k.sts[0].shape[k.first_upcast:], k.full_shape[k.first_upcast:]))) and \ + if k.first_reduce < k.first_upcast and \ + (prod(k.full_shape[k.first_upcast:]) <= 4 or (k.sts[0].shape[k.first_upcast:] == k.full_shape[k.first_upcast:])) and \ (k.upcasted == 0 or prod(k.full_shape[-k.upcasted:]) < 64): if isinstance(s:=k.full_unupcasted_shape[-1], int) and s <= 32: # NOTE: cannot loop unroll symbolic axis k.apply_opt(Opt(OptOps.UNROLL, len(k.full_unupcasted_shape)-1-k.first_reduce, 0)) @@ -113,7 +112,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]: # **** local groups **** if k.opts.has_local: - if NOLOCALS and k.local_dims == 0 and not k.group_for_reduces: + if NOLOCALS: k.apply_opt(Opt(OptOps.NOLOCALS)) else: # prioritize making expand axes local