Kernel.axes_of helper [pr] (#11243)

look up dim based on AxisType
This commit is contained in:
chenyu
2025-07-14 22:17:43 -04:00
committed by GitHub
parent 968f6b2a2e
commit 0e2422d216
2 changed files with 18 additions and 21 deletions

View File

@@ -79,8 +79,8 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
else: break
# if last reduce dim is small(ish), loop unroll the reduce
upcast_size = prod(s for s,t in zip(k.full_shape, k.axis_types) if t in (AxisType.UPCAST, AxisType.UNROLL))
if k.unrollable_dims and (upcast_size <= 4 or (AxisType.UNROLL not in k.axis_types)) and (upcast_size < 64):
upcast_size = prod(k.full_shape[a] for a in k.axes_of(AxisType.UPCAST, AxisType.UNROLL))
if k.unrollable_dims and (upcast_size <= 4 or not k.axes_of(AxisType.UNROLL)) and (upcast_size < 64):
if (s:=k.full_shape[k.unrollable_dims[-1]]) <= 32:
k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims[-1]-k.first_reduce, 0))
# if it's small, upcast a second reduce dimension too
@@ -105,8 +105,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
k.apply_opt(Opt(OptOps.NOLOCALS))
else:
# prioritize making expand axes local
local_axis_ranking = [(any(st.views[-1].strides[axis] == 0 for st in k.sts), axis) \
for axis,t in enumerate(k.axis_types) if t in (AxisType.GLOBAL, AxisType.LOOP)]
local_axis_ranking = [(any(st.views[-1].strides[axis] == 0 for st in k.sts), axis) for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP)]
to_local: list[tuple[int, int]] = []
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
local_size = prod(sz for _, sz in to_local)

View File

@@ -11,7 +11,7 @@ from tinygrad.device import Device
from tinygrad.opt.tc import TensorCore
from tinygrad.renderer import Renderer
from tinygrad.dtype import ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, DEBUG, TC_SELECT, TC_OPT, AMX
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape, get_contraction
from tinygrad.kernelize.kernelize import view_left
@@ -115,10 +115,7 @@ class Kernel:
return ret
@property
def first_reduce(self) -> int:
for i in range(self.first_upcast):
if self.axis_types[i] in (AxisType.GROUP_REDUCE, AxisType.REDUCE): return i
return self.first_upcast
def first_reduce(self) -> int: return next(iter(self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)), self.first_upcast)
@property
def first_upcast(self) -> int: return self.shape_len-self.upcasted
@@ -132,22 +129,23 @@ class Kernel:
@property
def shape_len(self) -> int: return len(self.sts[0].shape)
def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in argfix(axis_type)]
@property
def global_dims(self) -> int: return sum([1 for x in self.axis_types if x == AxisType.GLOBAL])
def global_dims(self) -> int: return len(self.axes_of(AxisType.GLOBAL))
@property
def local_dims(self) -> int: return sum([1 for x in self.axis_types if x == AxisType.LOCAL])
def local_dims(self) -> int: return len(self.axes_of(AxisType.LOCAL))
@property
def upcasted(self) -> int: return sum([1 for x in self.axis_types if x in {AxisType.UPCAST, AxisType.UNROLL}])
def upcasted(self) -> int: return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL))
@property
def group_for_reduces(self) -> int: return sum([1 for x in self.axis_types if x == AxisType.GROUP_REDUCE])
def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))
# heuristic helpers
@property
def upcastable_dims(self) -> list[int]: return [i for i,(a,s) in enumerate(zip(self.axis_types, self.full_shape)) \
if a in (AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) and isinstance(s, int) and s > 1]
def upcastable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) \
if isinstance(s:=self.full_shape[i], int) and s > 1]
@property
def unrollable_dims(self) -> list[int]: return [i for i,(a,s) in enumerate(zip(self.axis_types, self.full_shape)) \
if a in (AxisType.REDUCE, AxisType.GROUP_REDUCE) and isinstance(s, int) and s > 1]
def unrollable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE) \
if isinstance(s:=self.full_shape[i], int) and s > 1]
# ******************** colors and names ********************
@@ -286,8 +284,8 @@ class Kernel:
if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
(self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
acc_sz = self.reduceop.dtype.itemsize
upcast_sz = prod([s for s,t in zip(self.full_shape, self.axis_types) if t is AxisType.UPCAST])
local_sz = prod([s for s,t in zip(self.full_shape, self.axis_types) if t is AxisType.LOCAL])
upcast_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST)])
local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.LOCAL)])
smem_sz = amt*acc_sz*upcast_sz*local_sz
check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
@@ -465,9 +463,9 @@ class Kernel:
return ret.replace(arg=KernelInfo(kernel_name, tuple(self.axis_types), self.dont_use_locals, tuple(self.applied_opts)))
if op.op is Ops.REDUCE_AXIS:
reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
axes = tuple(i for i in range(0, self.shape_len) if self.axis_types[i] in {AxisType.REDUCE, AxisType.UNROLL} and
axes = tuple(i for i in self.axes_of(AxisType.REDUCE, AxisType.UNROLL) if
resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
grouped_axes = tuple(i for i in range(0, self.shape_len) if self.axis_types[i] is AxisType.GROUP_REDUCE and
grouped_axes = tuple(i for i in self.axes_of(AxisType.GROUP_REDUCE) if
resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
if (tc := self.tensor_core) and self.use_tensor_cores == 1:
# get reduce/upcast axes for the tensor cores