remove .global_dims that are for locating GLOBAL [pr] (#11264)

chenyu, 2025-07-16 11:19:31 -04:00 (committed by GitHub)
parent e6c016ddd0
commit 59b52d49d7
2 changed files with 12 additions and 11 deletions
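
The core of the change: checks that used to infer an axis's kind from its position (`axis < self.global_dims`, which assumes every GLOBAL axis is packed below the `global_dims` boundary) now read the tag in `self.axis_types` directly. A minimal standalone sketch of the two styles; the enum subset and the `axis_types` list are illustrative stand-ins, not the kernel's real state:

from enum import Enum, auto

class AxisType(Enum):  # illustrative subset of the kernel's axis tags
  GLOBAL = auto(); LOCAL = auto(); REDUCE = auto()

# hypothetical axis layout: two globals, one local, one reduce
axis_types = [AxisType.GLOBAL, AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE]
global_dims = sum(1 for t in axis_types if t is AxisType.GLOBAL)  # derived count, == 2

axis = 1
# old check: positional, relies on GLOBAL axes sitting at the front
assert axis < global_dims
# new check: read the tag directly, no layout assumption
assert axis_types[axis] is AxisType.GLOBAL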

Changed file 1 of 2:

@@ -246,13 +246,13 @@ class Kernel:
   # ******************** apply optimizations ********************
-  def real_axis(self, opt:Opt):
+  def real_axis(self, op:OptOps, axis:int|None):
     try:
-      if opt.axis is None: return -1
-      if opt.op is OptOps.UNROLL: return self.unrollable_dims[opt.axis]
-      if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[opt.axis]
-      check(opt.axis < self.shape_len, "invalid axis")
-      return opt.axis
+      if axis is None: return -1
+      if op is OptOps.UNROLL: return self.unrollable_dims[axis]
+      if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis]
+      check(axis < self.shape_len, "invalid axis")
+      return axis
     except IndexError as e: raise KernelOptError from e
   def apply_opt(self, opt:Opt, append_opt:bool=True):
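
Splitting the `Opt` parameter into `op` and `axis` lets the same resolver handle SWAP's `arg`, which names an axis rather than a shift amount. A toy sketch of the new call shape; `MiniKernel` is a stand-in with only the pieces needed here:

from enum import Enum, auto

class OptOps(Enum):  # illustrative subset
  SWAP = auto(); UNROLL = auto()

class KernelOptError(Exception): pass

class MiniKernel:
  shape_len = 4  # toy stand-in for the real Kernel state
  def real_axis(self, op:OptOps, axis:int|None) -> int:
    if axis is None: return -1
    if not axis < self.shape_len: raise KernelOptError("invalid axis")
    return axis

k = MiniKernel()
# SWAP carries two axes: opt.axis and opt.arg both resolve through real_axis now
axis, amt = k.real_axis(OptOps.SWAP, 0), k.real_axis(OptOps.SWAP, 2)
assert (axis, amt) == (0, 2)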
@@ -271,9 +271,9 @@
       self.applied_opts.append(opt)
       return
-    axis = self.real_axis(opt)
-    if opt.op is OptOps.SWAP: amt = cast(int, opt.arg) # arg is an axis in the SWAPs
+    axis = self.real_axis(opt.op, opt.axis)
+    if opt.op is OptOps.SWAP: amt = self.real_axis(opt.op, cast(int, opt.arg)) # arg is an axis in the SWAPs
     elif opt.arg is not None:
       check(isinstance(opt.arg, int), "arg should be int")
       amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
@@ -293,7 +293,7 @@
       # NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
       # it's disabled for now since it makes BEAM slow for little gain
       check(self.opts.has_local, "target does not support local")
-      check(axis < self.global_dims, "local is for globals")
+      check(self.axis_types[axis] is AxisType.GLOBAL, "local is for globals")
       self.shift_to(axis, amt, AxisType.LOCAL, insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL))+1)
     elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
       check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
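
For context on the unchanged `shift_to` call above: `insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL))+1` places the new LOCAL axis immediately after the existing global/local block. A standalone illustration of that index arithmetic, with `axes_of` inlined as a generator over a hypothetical layout:

from enum import Enum, auto

class AxisType(Enum):
  GLOBAL = auto(); LOCAL = auto(); REDUCE = auto()

axis_types = [AxisType.GLOBAL, AxisType.GLOBAL, AxisType.REDUCE]  # hypothetical layout
insert_at = max(i for i,t in enumerate(axis_types) if t in (AxisType.GLOBAL, AxisType.LOCAL)) + 1
axis_types.insert(insert_at, AxisType.LOCAL)  # new local lands right after the global/local block
assert axis_types == [AxisType.GLOBAL, AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE]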
@@ -315,7 +315,8 @@
       check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
       self.dont_use_locals = True
     elif opt.op is OptOps.SWAP:
-      check(axis < amt < self.global_dims, f"swap is only for globals with axis < amt, getting {amt=}, {axis=}, {self.global_dims=}")
+      check(axis < amt, f"swap is only for axis < amt, getting {amt=}, {axis=}")
+      check(self.axis_types[axis]==self.axis_types[amt]==AxisType.GLOBAL, f"swap is for globals {self.axis_types[axis]=}, {self.axis_types[amt]=}")
       permute = list(range(self.shape_len))
       permute[axis], permute[amt] = permute[amt], permute[axis]
       self.permute(tuple(permute))
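
The permute construction itself is untouched: SWAP builds an identity permutation and exchanges the two entries. A standalone example with a plain tuple shape:

shape = (4, 8, 16, 3)  # example full_shape
axis, amt = 1, 2       # the two GLOBAL axes being swapped (axis < amt)
permute = list(range(len(shape)))
permute[axis], permute[amt] = permute[amt], permute[axis]
assert tuple(shape[i] for i in permute) == (4, 16, 8, 3)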

Changed file 2 of 2:

@@ -121,7 +121,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
   for i,a in enumerate(kernel_actions):
     if a.axis is not None and a.op is not OptOps.TC:
-      try: ax = lin.real_axis(a)
+      try: ax = lin.real_axis(a.op, a.axis)
       except KernelOptError: continue
       if (ax >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
       lin2 = lin.copy()
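
On the search side the only change is passing the two fields through; actions whose axis cannot be resolved are still skipped via `KernelOptError`. A sketch of that filtering pattern, with toy stand-ins for `Opt` and the kernel:

from dataclasses import dataclass

class KernelOptError(Exception): pass

@dataclass(frozen=True)
class Opt:  # toy stand-in for the real Opt
  op: str
  axis: int|None
  arg: int|None = None

class ToyKernel:
  shape_len = 3
  def real_axis(self, op:str, axis:int|None) -> int:
    if axis is None: return -1
    if axis >= self.shape_len: raise KernelOptError("invalid axis")
    return axis

lin = ToyKernel()
kernel_actions = [Opt("UPCAST", 1, 4), Opt("UPCAST", 7, 4)]
valid = []
for a in kernel_actions:
  try: ax = lin.real_axis(a.op, a.axis)  # new two-argument form
  except KernelOptError: continue        # unresolvable axis: drop the action
  valid.append((a, ax))
assert [ax for _,ax in valid] == [1]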