Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-23 13:58:00 -05:00)
remove first_reduce used for locate real_axis [pr] (#11245)
LOCAL is inserted at max(GLOBAL+LOCAL)+1, i.e. right after the last global/local axis; GROUP is inserted right before the first REDUCE axis.
@@ -249,10 +249,12 @@ class Kernel:
   # ******************** apply optimizations ********************

   def real_axis(self, opt:Opt):
-    if opt.axis is None: return -1
-    if opt.op is OptOps.UNROLL: return self.first_reduce+opt.axis
-    if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.first_reduce+self.group_for_reduces+opt.axis
-    return opt.axis
+    try:
+      if opt.axis is None: return -1
+      if opt.op is OptOps.UNROLL: return self.unrollable_dims[opt.axis]
+      if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[opt.axis]
+      return opt.axis
+    except IndexError as e: raise KernelOptError from e

   def apply_opt(self, opt:Opt, append_opt:bool=True):
     if self.finalized: raise RuntimeError("can't optimize Kernel after it's finalized")
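The old real_axis located the target axis arithmetically with first_reduce offsets; the new one indexes into explicit per-type axis lists and turns an out-of-range opt.axis into a KernelOptError. A minimal standalone sketch of that pattern (ToyKernel, its axis layout, and real_reduce_axis are illustrative assumptions, not tinygrad's real classes):

  from enum import Enum, auto

  class AxisType(Enum):
    GLOBAL = auto(); LOCAL = auto(); GROUP_REDUCE = auto(); REDUCE = auto()

  class KernelOptError(Exception): pass

  class ToyKernel:
    def __init__(self, axis_types): self.axis_types = axis_types
    def axes_of(self, *types):
      # indices of all axes whose type is in `types`, in shape order
      return [i for i,t in enumerate(self.axis_types) if t in types]
    def real_reduce_axis(self, opt_axis):
      # GROUP/GROUPTOP case: opt_axis indexes into the REDUCE axes only;
      # a bad index becomes KernelOptError instead of a silently wrong axis
      try: return self.axes_of(AxisType.REDUCE)[opt_axis]
      except IndexError as e: raise KernelOptError from e

  k = ToyKernel([AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE, AxisType.REDUCE])
  assert k.real_reduce_axis(1) == 3  # second REDUCE axis lives at shape index 3

The payoff is that callers such as get_kernel_actions (last hunk below) can treat an invalid axis as an ordinary KernelOptError rather than receiving an offset that points at the wrong axis.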
@@ -294,13 +296,13 @@ class Kernel:
       # it's disabled for now since it makes BEAM slow for little gain
       check(self.opts.has_local, "target does not support local")
       check(axis < self.global_dims, "local is for globals")
-      self.shift_to(axis, amt, AxisType.LOCAL, insert_before=self.first_reduce)
+      self.shift_to(axis, amt, AxisType.LOCAL, insert_before=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL))+1)
     elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
       check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
       check(self.axis_types[axis] is AxisType.REDUCE, "must be reduce axis to group")
       check(not self.tensor_core, "can't group with tensor cores")
       check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
-      self.shift_to(axis, amt, AxisType.GROUP_REDUCE, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
+      self.shift_to(axis, amt, AxisType.GROUP_REDUCE, top=(opt.op is OptOps.GROUPTOP), insert_before=min(self.axes_of(AxisType.REDUCE)))
     elif opt.op is OptOps.UNROLL: # purple
       check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "can't upcasted already upcasted")
       check(amt <= 32, "don't unroll more than 32")
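Both shift_to call sites now derive their insertion point from the current axis layout rather than from first_reduce bookkeeping, exactly as the commit message says: LOCAL lands right after the last GLOBAL/LOCAL axis, GROUP_REDUCE right before the first REDUCE axis. A self-contained sketch with an assumed toy layout:

  from enum import Enum, auto
  class AxisType(Enum): GLOBAL = auto(); LOCAL = auto(); GROUP_REDUCE = auto(); REDUCE = auto()

  axis_types = [AxisType.GLOBAL, AxisType.GLOBAL, AxisType.LOCAL,
                AxisType.GROUP_REDUCE, AxisType.REDUCE, AxisType.REDUCE]
  def axes_of(*types): return [i for i,t in enumerate(axis_types) if t in types]

  # OptOps.LOCAL: insert right after the last existing GLOBAL/LOCAL axis
  assert max(axes_of(AxisType.GLOBAL, AxisType.LOCAL)) + 1 == 3
  # OptOps.GROUP/GROUPTOP: insert right before the first REDUCE axis
  assert min(axes_of(AxisType.REDUCE)) == 4

Computing the position from axes_of keeps both invariants true even when grouped axes already exist, which the old first_reduce arithmetic had to track separately via group_for_reduces.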
@@ -362,11 +364,11 @@ class Kernel:
     if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None

     buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
-    axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
-    axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
-    if not (axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
+    axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i in self.upcastable_dims if buf0_strides[i] == 0]
+    axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i in self.upcastable_dims if buf1_strides[i] == 0]
+    if not (axis_buf0 and axis_buf1 and (len(self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (opt_level >= 1))): return None

-    axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
+    axis_choices = list(itertools.product(axis_buf0, axis_buf1, self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)))
     if not (axis < len(axis_choices)): return None

     s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
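The tensor-core search now draws its n/m candidates from upcastable_dims (axes along which one input buffer has stride 0, i.e. is broadcast) and its k candidates from the reduce axes, instead of slicing everything before first_reduce. A toy enumeration of that candidate product (the shapes, strides, and dim lists below are made-up values for illustration):

  import itertools

  full_shape = (4, 8, 16)
  upcastable_dims, reduce_dims = [0, 1], [2]   # stand-ins for the kernel's axis queries
  buf0_strides = (0, 16, 1)   # buf0 is broadcast along axis 0 -> candidate n axis
  buf1_strides = (8, 0, 1)    # buf1 is broadcast along axis 1 -> candidate m axis

  axis_buf0 = [(i, full_shape[i], buf1_strides[i]) for i in upcastable_dims if buf0_strides[i] == 0]
  axis_buf1 = [(i, full_shape[i], buf0_strides[i]) for i in upcastable_dims if buf1_strides[i] == 0]
  # every (n, m, k) combination is one tensor-core placement candidate
  axis_choices = list(itertools.product(axis_buf0, axis_buf1, reduce_dims))
  assert axis_choices == [((0, 4, 8), (1, 8, 16), 2)]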
@@ -121,7 +121,9 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:

   for i,a in enumerate(kernel_actions):
     if a.axis is not None and a.op is not OptOps.TC:
-      if ((ax:=lin.real_axis(a)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
+      try: ax = lin.real_axis(a)
+      except KernelOptError: continue
+      if (ax >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
     lin2 = lin.copy()
     try:
       lin2.apply_opt(a)
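Since real_axis can now raise, the action filter in get_kernel_actions splits the old walrus one-liner into an explicit try/except so that actions whose axis no longer exists are dropped before any copy is made. A condensed control-flow sketch (valid_actions is an illustrative stand-in, assuming lin, kernel_actions, Opt, OptOps, and KernelOptError are in scope from tinygrad):

  def valid_actions(lin, kernel_actions):
    for a in kernel_actions:
      if a.axis is not None and a.op is not OptOps.TC:
        try: ax = lin.real_axis(a)            # may now raise KernelOptError
        except KernelOptError: continue       # axis no longer exists: drop action
        if ax >= lin.shape_len or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions):
          continue                            # out-of-range or redundant action
      yield a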