@@ -79,8 +79,8 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
       else: break
 
   # if last reduce dim is small(ish), loop unroll the reduce
-  upcast_size = prod(s for s,t in zip(k.full_shape, k.axis_types) if t in (AxisType.UPCAST, AxisType.UNROLL))
-  if k.unrollable_dims and (upcast_size <= 4 or (AxisType.UNROLL not in k.axis_types)) and (upcast_size < 64):
+  upcast_size = prod(k.full_shape[a] for a in k.axes_of(AxisType.UPCAST, AxisType.UNROLL))
+  if k.unrollable_dims and (upcast_size <= 4 or not k.axes_of(AxisType.UNROLL)) and (upcast_size < 64):
     if (s:=k.full_shape[k.unrollable_dims[-1]]) <= 32:
       k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims[-1]-k.first_reduce, 0))
       # if it's small, upcast a second reduce dimension too
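The refactor swaps the zip-and-filter pattern for `k.axes_of(...)`, which returns the indices of axes whose type matches; indexing `full_shape` by those indices yields the same product, and an empty result replaces the membership test. A minimal self-contained sketch of the equivalence, using made-up shapes and a cut-down `AxisType` (not tinygrad's real class):

    from enum import Enum, auto
    from math import prod

    class AxisType(Enum):  # stand-in with only the members used here
        GLOBAL = auto(); LOOP = auto(); UPCAST = auto(); UNROLL = auto()

    full_shape = (64, 32, 4, 8)  # hypothetical per-axis extents
    axis_types = (AxisType.GLOBAL, AxisType.LOOP, AxisType.UPCAST, AxisType.UNROLL)

    def axes_of(*ts): return [i for i, t in enumerate(axis_types) if t in ts]

    # old form: zip shapes with types, filter by type
    old = prod(s for s, t in zip(full_shape, axis_types) if t in (AxisType.UPCAST, AxisType.UNROLL))
    # new form: look up shapes by the matching axis indices
    new = prod(full_shape[a] for a in axes_of(AxisType.UPCAST, AxisType.UNROLL))
    assert old == new == 32  # 4 * 8

    # the membership test rewrites the same way
    assert (AxisType.UNROLL not in axis_types) == (not axes_of(AxisType.UNROLL))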
@@ -105,8 +105,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
       k.apply_opt(Opt(OptOps.NOLOCALS))
     else:
       # prioritize making expand axes local
-      local_axis_ranking = [(any(st.views[-1].strides[axis] == 0 for st in k.sts), axis) \
-                            for axis,t in enumerate(k.axis_types) if t in (AxisType.GLOBAL, AxisType.LOOP)]
+      local_axis_ranking = [(any(st.views[-1].strides[axis] == 0 for st in k.sts), axis) for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP)]
       to_local: list[tuple[int, int]] = []
       for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
         local_size = prod(sz for _, sz in to_local)
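The one-liner keeps the ranking semantics: each candidate axis is tagged with whether any buffer sees it with stride 0 (an expanded/broadcast axis), and the sort key `(-x[0], -x[1])` puts those axes first, highest axis index first within each group. A tiny sketch with made-up `(has_zero_stride, axis)` pairs:

    # hypothetical ranking for a 4-axis kernel
    local_axis_ranking = [(False, 0), (True, 1), (False, 2), (True, 3)]
    order = [axis for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1]))]
    assert order == [3, 1, 2, 0]  # expand axes first, then by descending axis index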
@@ -11,7 +11,7 @@ from tinygrad.device import Device
 from tinygrad.opt.tc import TensorCore
 from tinygrad.renderer import Renderer
 from tinygrad.dtype import ImageDType
-from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, DEBUG, TC_SELECT, TC_OPT, AMX
+from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import strides_for_shape, get_contraction
 from tinygrad.kernelize.kernelize import view_left
@@ -115,10 +115,7 @@ class Kernel:
     return ret
 
   @property
-  def first_reduce(self) -> int:
-    for i in range(self.first_upcast):
-      if self.axis_types[i] in (AxisType.GROUP_REDUCE, AxisType.REDUCE): return i
-    return self.first_upcast
+  def first_reduce(self) -> int: return next(iter(self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)), self.first_upcast)
   @property
   def first_upcast(self) -> int: return self.shape_len-self.upcasted
 
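`next(iter(xs), default)` yields the first element of `xs`, or `default` when it is empty, which is exactly what the removed scan loop computed. (The loop only scanned up to `first_upcast`, while the one-liner considers every axis; the two agree as long as reduce axes precede the upcast/unroll ones, which the kernel's axis ordering maintains.) A standalone illustration with made-up values:

    axes = [2, 3]     # pretend GROUP_REDUCE/REDUCE live on axes 2 and 3
    first_upcast = 5  # hypothetical fallback
    assert next(iter(axes), first_upcast) == 2
    assert next(iter([]), first_upcast) == 5  # no reduce axes -> fall back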
@@ -132,22 +129,23 @@ class Kernel:
   @property
   def shape_len(self) -> int: return len(self.sts[0].shape)
 
+  def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in argfix(axis_type)]
   @property
-  def global_dims(self) -> int: return sum([1 for x in self.axis_types if x == AxisType.GLOBAL])
+  def global_dims(self) -> int: return len(self.axes_of(AxisType.GLOBAL))
   @property
-  def local_dims(self) -> int: return sum([1 for x in self.axis_types if x == AxisType.LOCAL])
+  def local_dims(self) -> int: return len(self.axes_of(AxisType.LOCAL))
   @property
-  def upcasted(self) -> int: return sum([1 for x in self.axis_types if x in {AxisType.UPCAST, AxisType.UNROLL}])
+  def upcasted(self) -> int: return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL))
   @property
-  def group_for_reduces(self) -> int: return sum([1 for x in self.axis_types if x == AxisType.GROUP_REDUCE])
+  def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))
 
   # heuristic helpers
   @property
-  def upcastable_dims(self) -> list[int]: return [i for i,(a,s) in enumerate(zip(self.axis_types, self.full_shape)) \
-      if a in (AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) and isinstance(s, int) and s > 1]
+  def upcastable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) \
+      if isinstance(s:=self.full_shape[i], int) and s > 1]
   @property
-  def unrollable_dims(self) -> list[int]: return [i for i,(a,s) in enumerate(zip(self.axis_types, self.full_shape)) \
-      if a in (AxisType.REDUCE, AxisType.GROUP_REDUCE) and isinstance(s, int) and s > 1]
+  def unrollable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE) \
+      if isinstance(s:=self.full_shape[i], int) and s > 1]
 
   # ******************** colors and names ********************
 
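The newly imported `argfix` is what lets `axes_of` take its axis types either as varargs or as a single tuple/list. Roughly, it behaves like the sketch below (a simplified rendition, not the exact helper from `tinygrad.helpers`):

    def argfix(*x):
        # accept f(1, 2, 3) and f((1, 2, 3)) / f([1, 2, 3]) interchangeably
        if x and isinstance(x[0], (tuple, list)):
            if len(x) != 1: raise ValueError(f"bad arg {x}")
            return tuple(x[0])
        return x

    assert argfix(1, 2, 3) == (1, 2, 3)
    assert argfix((1, 2, 3)) == (1, 2, 3)
    assert argfix([1, 2, 3]) == (1, 2, 3)

So `k.axes_of(AxisType.UPCAST, AxisType.UNROLL)` and `k.axes_of((AxisType.UPCAST, AxisType.UNROLL))` return the same list.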
@@ -286,8 +284,8 @@ class Kernel:
     if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
        (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
       acc_sz = self.reduceop.dtype.itemsize
-      upcast_sz = prod([s for s,t in zip(self.full_shape, self.axis_types) if t is AxisType.UPCAST])
-      local_sz = prod([s for s,t in zip(self.full_shape, self.axis_types) if t is AxisType.LOCAL])
+      upcast_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST)])
+      local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.LOCAL)])
       smem_sz = amt*acc_sz*upcast_sz*local_sz
       check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
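The shared-memory bound is a straight product: group amount times accumulator itemsize times the upcast and local extents. A worked example with assumed numbers (float32 accumulator, GROUP of 16, 4-wide upcast, 32 local threads):

    amt, acc_sz = 16, 4  # GROUP amount; float32 itemsize
    upcast_sz = 4        # product of the UPCAST axis extents
    local_sz = 32        # product of the LOCAL axis extents
    smem_sz = amt * acc_sz * upcast_sz * local_sz
    assert smem_sz == 8192  # would pass a typical 32-48 KiB shared_max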
@@ -465,9 +463,9 @@ class Kernel:
       return ret.replace(arg=KernelInfo(kernel_name, tuple(self.axis_types), self.dont_use_locals, tuple(self.applied_opts)))
     if op.op is Ops.REDUCE_AXIS:
       reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
-      axes = tuple(i for i in range(0, self.shape_len) if self.axis_types[i] in {AxisType.REDUCE, AxisType.UNROLL} and
+      axes = tuple(i for i in self.axes_of(AxisType.REDUCE, AxisType.UNROLL) if
                    resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
-      grouped_axes = tuple(i for i in range(0, self.shape_len) if self.axis_types[i] is AxisType.GROUP_REDUCE and
+      grouped_axes = tuple(i for i in self.axes_of(AxisType.GROUP_REDUCE) if
                    resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
       if (tc := self.tensor_core) and self.use_tensor_cores == 1:
         # get reduce/upcast axes for the tensor cores
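In both the old and new forms, an axis counts as reduced when its extent differs between the reduce's input and output ShapeTrackers (`sts[reduce_idx]` vs `sts[reduce_idx + 1]`); `resolve` evaluates that comparison even when the dims are symbolic. A toy illustration with hypothetical concrete shapes:

    in_shape = (64, 32, 16)   # pre-reduce shape
    out_shape = (64, 1, 16)   # axis 1 reduced away
    reduce_axes = tuple(i for i in range(len(in_shape)) if in_shape[i] != out_shape[i])
    assert reduce_axes == (1,)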