shave more hcopt [pr] (#11213)

start to use AxisType for conditions
This commit is contained in:
chenyu
2025-07-13 12:43:58 -04:00
committed by GitHub
parent 4ef6b46b34
commit 9575cf6c6e

View File

@@ -1,5 +1,5 @@
import itertools
from tinygrad.opt.kernel import Kernel, Opt, OptOps, KernelOptError
from tinygrad.opt.kernel import Kernel, Opt, OptOps, KernelOptError, AxisType
from tinygrad.helpers import getenv, DEBUG, all_int, prod, NOLOCALS
from tinygrad.dtype import ImageDType
from tinygrad.uop.ops import Ops, resolve
@@ -72,8 +72,8 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
for axis, upcast_amount in itertools.product(range(k.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
# if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if axis not in upcasted_axis and isinstance(k.full_shape[axis], int) and k.full_shape[axis]%upcast_amount == 0 and \
any(st.views[-1].strides[axis] == 0 and not any(x == 0 for x in k.sts[buf_index].real_strides()[k.first_upcast:]) \
for buf_index, st in enumerate(k.sts)):
any(st.views[-1].strides[axis] == 0 and not any(x == 0 for x in st.real_strides()[k.first_upcast:]) \
for st in k.sts):
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts),
sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
if xb_choices:
@@ -83,10 +83,9 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
upcasted_axis.add(xb_choices[0][2])
else: break
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
# if last dim is small(ish) and it's a reduce dim, loop unroll the reduce
if k.first_reduce < k.first_upcast and \
(prod(k.full_shape[k.first_upcast:]) <= 4 or (k.sts[0].shape[k.first_upcast:] == k.full_shape[k.first_upcast:])) and \
(k.upcasted == 0 or prod(k.full_shape[-k.upcasted:]) < 64):
(prod(k.full_shape[k.first_upcast:]) <= 4 or (AxisType.UNROLL not in k.axis_types)) and (prod(k.full_shape[k.first_upcast:]) < 64):
if isinstance(s:=k.full_unupcasted_shape[-1], int) and s <= 32: # NOTE: cannot loop unroll symbolic axis
k.apply_opt(Opt(OptOps.UNROLL, len(k.full_unupcasted_shape)-1-k.first_reduce, 0))
# if it's small, upcast a second reduce dimension too