mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
some noop hand_coded_optimizations cleanup [pr] (#11188)
This commit is contained in:
@@ -32,11 +32,10 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
k.first_reduce <= 2 and k.first_reduce < k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
|
||||
# TODO: use 1024 if it's allowed in a smarter way
|
||||
for sz in ([256, 16] if prod(k.sts[0].shape[:k.first_reduce]) <= 32 else [16]):
|
||||
if all(st.shape[k.first_reduce] % sz == 0 or st.shape[k.first_reduce] == 1 for st in k.sts):
|
||||
try: # may fail due to excessive smem usage
|
||||
k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
|
||||
break
|
||||
except KernelOptError: pass
|
||||
try: # may fail due to excessive smem usage
|
||||
k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
|
||||
break
|
||||
except KernelOptError: pass
|
||||
|
||||
# upcast float4 images
|
||||
for buf_index,buf in enumerate(k.bufs):
|
||||
@@ -91,8 +90,8 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
else: break
|
||||
|
||||
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
|
||||
if k.first_reduce < k.first_upcast and (prod(k.full_shape[k.first_upcast:]) <= 4 or \
|
||||
not any(x!=y for x,y in zip(k.sts[0].shape[k.first_upcast:], k.full_shape[k.first_upcast:]))) and \
|
||||
if k.first_reduce < k.first_upcast and \
|
||||
(prod(k.full_shape[k.first_upcast:]) <= 4 or (k.sts[0].shape[k.first_upcast:] == k.full_shape[k.first_upcast:])) and \
|
||||
(k.upcasted == 0 or prod(k.full_shape[-k.upcasted:]) < 64):
|
||||
if isinstance(s:=k.full_unupcasted_shape[-1], int) and s <= 32: # NOTE: cannot loop unroll symbolic axis
|
||||
k.apply_opt(Opt(OptOps.UNROLL, len(k.full_unupcasted_shape)-1-k.first_reduce, 0))
|
||||
@@ -113,7 +112,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
# **** local groups ****
|
||||
|
||||
if k.opts.has_local:
|
||||
if NOLOCALS and k.local_dims == 0 and not k.group_for_reduces:
|
||||
if NOLOCALS:
|
||||
k.apply_opt(Opt(OptOps.NOLOCALS))
|
||||
else:
|
||||
# prioritize making expand axes local
|
||||
|
||||
Reference in New Issue
Block a user