diff --git a/tinygrad/opt/kernel.py b/tinygrad/opt/kernel.py
index 2ed147a4bf..faad7bb9cb 100644
--- a/tinygrad/opt/kernel.py
+++ b/tinygrad/opt/kernel.py
@@ -249,10 +249,12 @@ class Kernel:
   # ******************** apply optimizations ********************

   def real_axis(self, opt:Opt):
-    if opt.axis is None: return -1
-    if opt.op is OptOps.UNROLL: return self.first_reduce+opt.axis
-    if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.first_reduce+self.group_for_reduces+opt.axis
-    return opt.axis
+    try:
+      if opt.axis is None: return -1
+      if opt.op is OptOps.UNROLL: return self.unrollable_dims[opt.axis]
+      if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[opt.axis]
+      return opt.axis
+    except IndexError as e: raise KernelOptError from e

   def apply_opt(self, opt:Opt, append_opt:bool=True):
     if self.finalized: raise RuntimeError("can't optimize Kernel after it's finalized")
@@ -294,13 +296,13 @@ class Kernel:
       # it's disabled for now since it makes BEAM slow for little gain
       check(self.opts.has_local, "target does not support local")
       check(axis < self.global_dims, "local is for globals")
-      self.shift_to(axis, amt, AxisType.LOCAL, insert_before=self.first_reduce)
+      self.shift_to(axis, amt, AxisType.LOCAL, insert_before=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL))+1)
     elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:  # green
       check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
       check(self.axis_types[axis] is AxisType.REDUCE, "must be reduce axis to group")
       check(not self.tensor_core, "can't group with tensor cores")
       check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
-      self.shift_to(axis, amt, AxisType.GROUP_REDUCE, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
+      self.shift_to(axis, amt, AxisType.GROUP_REDUCE, top=(opt.op is OptOps.GROUPTOP), insert_before=min(self.axes_of(AxisType.REDUCE)))
     elif opt.op is OptOps.UNROLL:  # purple
       check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "can't upcasted already upcasted")
       check(amt <= 32, "don't unroll more than 32")
@@ -362,11 +364,11 @@ class Kernel:
     if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None

     buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
-    axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
-    axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
-    if not (axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
+    axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i in self.upcastable_dims if buf0_strides[i] == 0]
+    axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i in self.upcastable_dims if buf1_strides[i] == 0]
+    if not (axis_buf0 and axis_buf1 and (len(self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (opt_level >= 1))): return None

-    axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
+    axis_choices = list(itertools.product(axis_buf0, axis_buf1, self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)))
     if not (axis < len(axis_choices)): return None

     s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2]  # s0 is n, s1 is m, s2 is k
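The kernel.py change makes `real_axis` resolve `opt.axis` through the typed axis lists (`unrollable_dims`, `axes_of(AxisType.REDUCE)`) instead of doing offset arithmetic on `first_reduce`, and an out-of-range index now surfaces as `KernelOptError` rather than a raw `IndexError`. A minimal sketch of that pattern, using illustrative stand-in classes rather than tinygrad's real `Kernel`:

```python
# Sketch only: MiniKernel and its fields are stand-ins, not tinygrad's API.
from enum import Enum, auto

class KernelOptError(Exception): pass  # plays the role of tinygrad's KernelOptError

class AxisType(Enum):
  GLOBAL = auto(); REDUCE = auto()

class MiniKernel:
  def __init__(self, axis_types):
    self.axis_types = axis_types
  def axes_of(self, *ts):
    # shape positions of all axes whose type is in ts, in shape order
    return [i for i,t in enumerate(self.axis_types) if t in ts]
  def real_reduce_axis(self, axis: int) -> int:
    # opt.axis counts within the reduce axes; map it to a shape position,
    # turning a bad index into the search-visible KernelOptError
    try: return self.axes_of(AxisType.REDUCE)[axis]
    except IndexError as e: raise KernelOptError from e

k = MiniKernel([AxisType.GLOBAL, AxisType.GLOBAL, AxisType.REDUCE])
assert k.real_reduce_axis(0) == 2   # the only reduce axis sits at shape position 2
try: k.real_reduce_axis(1)          # there is no second reduce axis
except KernelOptError: print("rejected as KernelOptError, not IndexError")
```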
diff --git a/tinygrad/opt/search.py b/tinygrad/opt/search.py
index d3bb084471..41d32ca6e2 100644
--- a/tinygrad/opt/search.py
+++ b/tinygrad/opt/search.py
@@ -121,7 +121,9 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:

   for i,a in enumerate(kernel_actions):
     if a.axis is not None and a.op is not OptOps.TC:
-      if ((ax:=lin.real_axis(a)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
+      try: ax = lin.real_axis(a)
+      except KernelOptError: continue
+      if (ax >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
     lin2 = lin.copy()
     try:
       lin2.apply_opt(a)
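On the search side, `get_kernel_actions` can no longer assume `real_axis` always returns an integer, so it filters failing actions with try/except before the existing bound check. A standalone sketch of that filtering loop, with hypothetical helper names:

```python
# Sketch only: real_axis/filter_actions here are illustrative, not tinygrad's search.py.
class KernelOptError(Exception): pass

def real_axis(reduce_axes: list[int], axis: int) -> int:
  try: return reduce_axes[axis]
  except IndexError as e: raise KernelOptError from e

def filter_actions(reduce_axes: list[int], candidate_axes: list[int]) -> list[int]:
  kept = []
  for axis in candidate_axes:
    try: ax = real_axis(reduce_axes, axis)   # may raise for axes this kernel lacks
    except KernelOptError: continue          # drop the action instead of crashing the search
    kept.append(ax)
  return kept

print(filter_actions([3, 4], [0, 1, 2]))  # -> [3, 4]; the axis=2 action is dropped
```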