group heuristic

This commit is contained in:
George Hotz
2025-12-14 07:28:38 -05:00
parent 55845f7de7
commit 4ec83a47eb

View File

@@ -48,6 +48,17 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
# make a copy so it does not mutate the input
k = k.copy()
# when TC fails for kernels with a large reduction (first reduce axis >= 256), try GROUP by 16, then 8 (whichever first divides the axis), to parallelize the reduction
if k.ren.has_local and k.ren.has_shared and k.reduceop is not None:
reduce_axes = k.axes_of(AxisType.REDUCE)
if reduce_axes and resolve(k.full_shape[reduce_axes[0]] >= 256, False):
for amt in [16, 8]:
if k.full_shape[reduce_axes[0]] % amt == 0:
try:
k.apply_opt(Opt(OptOps.GROUP, 0, amt))
break
except KernelOptError: pass
# upcast float4 images; this must happen early so we don't accidentally add locals before the upcast
for buf_index,buf in enumerate(k.bufs):
if isinstance(buf.src[0].dtype, ImageDType):