diff --git a/tinygrad/opt/kernel.py b/tinygrad/opt/kernel.py
index 6d4c2ca0dc..ab4e133942 100644
--- a/tinygrad/opt/kernel.py
+++ b/tinygrad/opt/kernel.py
@@ -246,13 +246,13 @@ class Kernel:
 
   # ******************** apply optimizations ********************
 
-  def real_axis(self, opt:Opt):
+  def real_axis(self, op:OptOps, axis:int|None):
     try:
-      if opt.axis is None: return -1
-      if opt.op is OptOps.UNROLL: return self.unrollable_dims[opt.axis]
-      if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[opt.axis]
-      check(opt.axis < self.shape_len, "invalid axis")
-      return opt.axis
+      if axis is None: return -1
+      if op is OptOps.UNROLL: return self.unrollable_dims[axis]
+      if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis]
+      check(axis < self.shape_len, "invalid axis")
+      return axis
     except IndexError as e: raise KernelOptError from e
 
   def apply_opt(self, opt:Opt, append_opt:bool=True):
@@ -271,9 +271,9 @@ class Kernel:
       self.applied_opts.append(opt)
       return
 
-    axis = self.real_axis(opt)
+    axis = self.real_axis(opt.op, opt.axis)
 
-    if opt.op is OptOps.SWAP: amt = cast(int, opt.arg)  # arg is an axis in the SWAPs
+    if opt.op is OptOps.SWAP: amt = self.real_axis(opt.op, cast(int, opt.arg))  # arg is an axis in the SWAPs
     elif opt.arg is not None:
       check(isinstance(opt.arg, int), "arg should be int")
       amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
@@ -293,7 +293,7 @@ class Kernel:
       # NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
       # it's disabled for now since it makes BEAM slow for little gain
       check(self.opts.has_local, "target does not support local")
-      check(axis < self.global_dims, "local is for globals")
+      check(self.axis_types[axis] is AxisType.GLOBAL, "local is for globals")
       self.shift_to(axis, amt, AxisType.LOCAL, insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL))+1)
     elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:  # green
       check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
@@ -315,7 +315,8 @@ class Kernel:
       check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
       self.dont_use_locals = True
     elif opt.op is OptOps.SWAP:
-      check(axis < amt < self.global_dims, f"swap is only for globals with axis < amt, getting {amt=}, {axis=}, {self.global_dims=}")
+      check(axis < amt, f"swap is only for axis < amt, getting {amt=}, {axis=}")
+      check(self.axis_types[axis]==self.axis_types[amt]==AxisType.GLOBAL, f"swap is for globals {self.axis_types[axis]=}, {self.axis_types[amt]=}")
       permute = list(range(self.shape_len))
       permute[axis], permute[amt] = permute[amt], permute[axis]
       self.permute(tuple(permute))
diff --git a/tinygrad/opt/search.py b/tinygrad/opt/search.py
index 41d32ca6e2..6d854d734e 100644
--- a/tinygrad/opt/search.py
+++ b/tinygrad/opt/search.py
@@ -121,7 +121,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
 
   for i,a in enumerate(kernel_actions):
     if a.axis is not None and a.op is not OptOps.TC:
-      try: ax = lin.real_axis(a)
+      try: ax = lin.real_axis(a.op, a.axis)
       except KernelOptError: continue
       if (ax >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
     lin2 = lin.copy()
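
Usage sketch (not part of the patch): with the new two-argument signature, both the axis of an Opt and, for SWAP, the target axis carried in arg are resolved through real_axis(op, axis) before the axis_types checks run. The snippet below assumes an already-constructed Kernel instance k, assumes Opt/OptOps are importable from tinygrad.opt.kernel as used in the files above, and the axis numbers are illustrative only.

# illustrative sketch only -- `k` is assumed to be an already-built Kernel,
# and the import path is an assumption based on the files touched by this diff
from tinygrad.opt.kernel import Opt, OptOps

swap = Opt(OptOps.SWAP, 0, 1)          # swap axis 0 with axis 1; arg holds the target axis
src = k.real_axis(swap.op, swap.axis)  # resolve the source axis with the new (op, axis) signature
dst = k.real_axis(swap.op, swap.arg)   # the SWAP target in arg now goes through real_axis too
k.apply_opt(swap)                      # per the new checks: needs src < dst and both axes GLOBAL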