remove first_reduce used for locate real_axis [pr] (#11245)

LOCAL goes to the last of (GLOBAL+LOCAL)+1
GROUP goes to right before first REDUCE
This commit is contained in:
chenyu
2025-07-15 09:19:38 -04:00
committed by GitHub
parent 0e2422d216
commit 034e51bd36
2 changed files with 15 additions and 11 deletions

View File

@@ -249,10 +249,12 @@ class Kernel:
# ******************** apply optimizations ********************
def real_axis(self, opt:Opt):
if opt.axis is None: return -1
if opt.op is OptOps.UNROLL: return self.first_reduce+opt.axis
if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.first_reduce+self.group_for_reduces+opt.axis
return opt.axis
try:
if opt.axis is None: return -1
if opt.op is OptOps.UNROLL: return self.unrollable_dims[opt.axis]
if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[opt.axis]
return opt.axis
except IndexError as e: raise KernelOptError from e
def apply_opt(self, opt:Opt, append_opt:bool=True):
if self.finalized: raise RuntimeError("can't optimize Kernel after it's finalized")
@@ -294,13 +296,13 @@ class Kernel:
# it's disabled for now since it makes BEAM slow for little gain
check(self.opts.has_local, "target does not support local")
check(axis < self.global_dims, "local is for globals")
self.shift_to(axis, amt, AxisType.LOCAL, insert_before=self.first_reduce)
self.shift_to(axis, amt, AxisType.LOCAL, insert_before=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL))+1)
elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
check(self.axis_types[axis] is AxisType.REDUCE, "must be reduce axis to group")
check(not self.tensor_core, "can't group with tensor cores")
check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
self.shift_to(axis, amt, AxisType.GROUP_REDUCE, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
self.shift_to(axis, amt, AxisType.GROUP_REDUCE, top=(opt.op is OptOps.GROUPTOP), insert_before=min(self.axes_of(AxisType.REDUCE)))
elif opt.op is OptOps.UNROLL: # purple
check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "can't upcasted already upcasted")
check(amt <= 32, "don't unroll more than 32")
@@ -362,11 +364,11 @@ class Kernel:
if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
if not (axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i in self.upcastable_dims if buf0_strides[i] == 0]
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i in self.upcastable_dims if buf1_strides[i] == 0]
if not (axis_buf0 and axis_buf1 and (len(self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (opt_level >= 1))): return None
axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
axis_choices = list(itertools.product(axis_buf0, axis_buf1, self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)))
if not (axis < len(axis_choices)): return None
s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k

View File

@@ -121,7 +121,9 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
for i,a in enumerate(kernel_actions):
if a.axis is not None and a.op is not OptOps.TC:
if ((ax:=lin.real_axis(a)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
try: ax = lin.real_axis(a)
except KernelOptError: continue
if (ax >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
lin2 = lin.copy()
try:
lin2.apply_opt(a)