remove Kernel.first_reduce [pr] (#11269)

This commit is contained in:
chenyu
2025-07-16 18:30:14 -04:00
committed by GitHub
parent 522dc72f08
commit 60ffe00172
2 changed files with 8 additions and 10 deletions

View File

@@ -16,11 +16,12 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
st0, st1 = k.sts[k.bufs.index(mulop.src[0])], k.sts[k.bufs.index(mulop.src[1])]
strides0, strides1 = st0.real_strides(), st1.real_strides()
def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
if strides0[k.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
if strides0[first_reduce:=(k.axes_of(AxisType.REDUCE)[0])] == 1 and \
not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
for global_idx in k.axes_of(AxisType.GLOBAL):
if k.full_shape[k.first_reduce]%MV_THREADS_PER_ROW == 0 and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
if k.full_shape[first_reduce]%MV_THREADS_PER_ROW == 0 and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
if DEBUG >= 3:
print(f"MATVEC: {k.full_shape=} {k.first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
print(f"MATVEC: {k.full_shape=} {first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
@@ -41,7 +42,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims:
k.apply_opt(Opt(OptOps.UPCAST, axis, 4))
elif axis in k.unrollable_dims:
k.apply_opt(Opt(OptOps.UNROLL, axis-k.first_reduce, 4))
k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims.index(axis), 4))
# no more opt if we are grouping
if k.group_for_reduces: return k.applied_opts
@@ -82,14 +83,14 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
upcast_size = prod(k.full_shape[a] for a in k.axes_of(AxisType.UPCAST, AxisType.UNROLL))
if k.unrollable_dims and (upcast_size <= 4 or not k.axes_of(AxisType.UNROLL)) and (upcast_size < 64):
if (s:=k.full_shape[k.unrollable_dims[-1]]) <= 32:
k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims[-1]-k.first_reduce, 0))
k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
# if it's small, upcast a second reduce dimension too
if k.unrollable_dims and s <= 3 and k.full_shape[k.unrollable_dims[-1]] <= 3:
k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims[-1]-k.first_reduce, 0))
k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
else:
for splits in [4]:
if k.full_shape[axis:=k.unrollable_dims[-1]]%splits == 0:
k.apply_opt(Opt(OptOps.UNROLL, axis-k.first_reduce, splits))
k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, splits))
break
# if nothing at all is upcasted and it's easy to, do an upcast

View File

@@ -114,9 +114,6 @@ class Kernel:
return ret
@property
def first_reduce(self) -> int: return self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)[0]
@property
def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None
@property