Files
tinygrad/tinygrad/codegen/opt/postrange.py
George Hotz 6d3385c284 print special ops in postrange (#13318)
* print special ops in postrange

* fix on OSX
2025-11-17 14:43:23 -08:00

350 lines
20 KiB
Python

from __future__ import annotations
import math, itertools
from collections import defaultdict
from typing import cast, Final
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
from tinygrad.device import Buffer
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.renderer import Renderer
# rewrite rule that clears the tag of every uop; untagged uops return None (no rewrite)
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: None if x.tag is None else x.replace(tag=None))])
class Scheduler:
  """Applies optimization ops (`Opt`) to a kernel AST by rewriting its RANGE uops.

  Axes are not stored on the object: `rngs` re-derives them from the RANGE uops in `self.ast` on every access,
  so every optimization is expressed as a substitution on `self.ast`.
  """
  def __init__(self, ast:UOp, ren:Renderer):
    self.ast, self.ren = ast, ren
    # carry over state recorded in the AST's arg (the KernelInfo written by get_optimized_ast), if present
    self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False
    self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else []
  @property
  def rngs(self) -> list[UOp]:
    """RANGE uops of the AST with vmax > 0, ordered by axis type first, then by the rest of the arg."""
    # always in order by axistype
    return sorted([u for u in self.ast.backward_slice if u.op is Ops.RANGE and u.vmax > 0], key=lambda x: (axis_to_pos[x.arg[-1]],) + x.arg[0:-1])
  @property
  def shape_len(self) -> int: return len(self.rngs)
  @property
  def full_shape(self): return [ssimplify(x.src[0]) for x in self.rngs]  # size of each axis, taken from the RANGE's src[0]
  @property
  def axis_types(self) -> list[AxisType]: return [x.arg[-1] for x in self.rngs]
  @property
  def maxarg(self): return max([x.arg[0] for x in self.rngs], default=0)  # largest axis id in use; shift_to allocates maxarg+1
  # strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
  def shape_str(self) -> list[str]:
    """Name each axis with its axis letter followed by a per-axis-type counter."""
    ret: list[str] = []
    cnt: dict[AxisType, int] = {}
    for x in self.axis_types:
      cnt[x] = cnt.get(x, -1) + 1
      ret.append(f"{axis_letters[x]}{cnt[x]}")
    return ret
  def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]:
    """Map axis names (as produced by shape_str) back to axis indices. Raises ValueError for an unknown name."""
    # compute the names once instead of re-deriving every range for each lookup
    names = self.shape_str()
    return tuple([names.index(x) for x in nms])
  def copy(self):
    """Return a new Scheduler over the same AST and renderer with an independent copy of the opt state."""
    ret = Scheduler(self.ast, self.ren)
    ret.dont_use_locals = self.dont_use_locals
    ret.applied_opts = self.applied_opts[:]
    return ret
  # per-function-name counter; get_optimized_ast uses it to suffix kernels whose names collide
  kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
  def get_optimized_ast(self, name_override:str|None=None):
    """Flatten ranges and return the AST tagged with a KernelInfo holding the name and applied opts.

    Without name_override the name is built from the kernel type ('r' with a reduce, 'E' otherwise), the SPECIAL
    sizes and the colored range sizes; repeated names get an 'n<count>' suffix. Note: this also updates self.ast.
    """
    if name_override is not None: name = name_override
    else:
      k_type = "r" if self.reduceop is not None else "E"
      special_uops = sorted([x for x in self.ast.toposort() if x.op is Ops.SPECIAL], key=lambda x: x.arg)
      special_ops = [colored(str(x.vmax+1), "blue" if x.arg[0] == "g" else "cyan") for x in special_uops]
      name = k_type + colored('_', 'BLACK').join(['']+special_ops+[colored(x.src[0].render(), color) for x,color in zip(self.rngs, self.colors())])
      Scheduler.kernel_cnt[(function_name := to_function_name(name))] += 1
      num = f"n{Scheduler.kernel_cnt[function_name]-1}" if Scheduler.kernel_cnt[function_name] > 1 else ""
      name += colored(num, 'BLACK')
    self.ast = graph_rewrite(self.ast, pm_flatten_range, name="flatten range")
    return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)
  def _output_rngs(self) -> list[UOp]:
    """Non-REDUCE ranges used by src[1:] of each END directly under the AST."""
    return flatten([[r for r in UOp.sink(*s.src[1:]).ranges if r.arg[-1] != AxisType.REDUCE] for s in self.ast.src if s.op is Ops.END])
  def _globalizable_rngs(self) -> list[UOp]:
    """LOOP output ranges that appear in every BUFFERIZE of the AST."""
    ret = [r for r in self._output_rngs() if r.arg[-1] == AxisType.LOOP]
    # exclude any output ranges from global that don't appear in all BUFFERIZE
    for x in self.ast.toposort():
      if x.op is Ops.BUFFERIZE:
        ret = [r for r in ret if r in x.ranges]
    return ret
  def convert_loop_to_global(self):
    """Retype every globalizable LOOP range as GLOBAL. Does nothing when `self.ren.has_local` is False."""
    if not self.ren.has_local: return None
    globalizable_rngs = self._globalizable_rngs()
    rng = [x.replace(arg=x.arg[0:-1]+(AxisType.GLOBAL,)) if x in globalizable_rngs else x for x in self.rngs]
    self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))
  def colors(self) -> list[str]:
    """Display color for each axis (parallel to self.rngs)."""
    output_rngs = self._output_rngs()
    globalizable_rngs = self._globalizable_rngs()
    ret = []
    for x,r in zip(self.axis_types, self.rngs):
      if self.dont_use_locals and x == AxisType.GLOBAL: ret.append("BLUE")
      elif r not in output_rngs and x == AxisType.LOOP: ret.append("BLACK")
      elif r not in globalizable_rngs and x == AxisType.LOOP: ret.append("white")
      else: ret.append(axis_colors[x])
    return ret
  def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])
  def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng=None):
    """Split `rng` into a range of size old_sz = size/amount and a new `amount`-sized range of type `new_type`.

    With top=False the new range is the low part (rng -> old*amount + new); with top=True it is the high part
    (rng -> new*old_sz + old). `input_new_rng`, if given, is used instead of creating a fresh range.
    Returns (replaced_rng, new_rng). Raises KernelOptError if `amount` does not divide the range size.
    """
    if (old_sz:=rng.src[0].divides(amount)) is None:
      raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}")
    new_rng = UOp.range(amount, self.maxarg+1, new_type) if input_new_rng is None else input_new_rng
    replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),))
    sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng)
    self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[:-1]} {amount} {str(new_type).split('.')[1].lower()}")
    return replaced_rng, new_rng
  def ranges_of(self, *axis_type:AxisType) -> list[UOp]: return [r for r in self.rngs if r.arg[-1] in axis_type]
  def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in axis_type]
  def upcast_size(self) -> int:
    """Product of the sizes of all UPCAST and UNROLL axes."""
    full_shape = self.full_shape  # hoisted: each full_shape access re-derives every range
    return prod(full_shape[a] for a in self.axes_of(AxisType.UPCAST, AxisType.UNROLL))
  # copied from kernel.py
  @property
  def upcastable_dims(self) -> list[int]:
    """Indices of GLOBAL/LOCAL/LOOP axes whose size is a concrete int > 1."""
    full_shape = self.full_shape
    return [i for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) if isinstance(s:=full_shape[i], int) and s > 1]
  @property
  def unrollable_dims(self) -> list[int]:
    """Indices of GROUP_REDUCE/REDUCE axes whose size is a concrete int > 1."""
    full_shape = self.full_shape
    return [i for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE) if isinstance(s:=full_shape[i], int) and s > 1]
  def real_axis(self, op:OptOps, axis:int|None):
    """Translate an Opt axis into an index into self.rngs.

    UNROLL axes index unrollable_dims, GROUP/GROUPTOP axes index the REDUCE axes, other ops index rngs directly.
    Returns -1 when there is no axis (axis is None or op is TC). Raises KernelOptError for an invalid axis.
    """
    try:
      if axis is None or op is OptOps.TC: return -1
      if op is OptOps.UNROLL: return self.unrollable_dims[axis]
      if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis]
      # negative axes are rejected as well: apply_opt treats a negative result as "no range" and would crash later
      check(0 <= axis < self.shape_len, f"invalid axis on {axis=} {op=} {self.shape_len=}")
      return axis
    except IndexError as e: raise KernelOptError from e
  def apply_opt(self, opt:Opt, append_opt:bool=True):
    """Apply one optimization to self.ast.

    Raises KernelOptError when the opt is not valid here. On success the opt is appended to applied_opts (unless
    append_opt is False). Returns shift_to's (replaced, new) range pair for axis-splitting ops, the tensor-core axes
    for TC, and None otherwise.
    """
    if opt.op is OptOps.NOLOCALS:
      check(all(x not in {AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE} for x in self.axis_types), "no locals can't have locals")
      if append_opt: self.applied_opts.append(opt)
      self.dont_use_locals = True
      return
    if opt.op in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}:
      check(self.ren.has_local, "locals needed for opt")
    rng = self.rngs[real_axis] if (real_axis:=self.real_axis(opt.op, opt.axis)) >= 0 else UOp(Ops.NOOP)
    opt_to_at = {
      OptOps.LOCAL: AxisType.LOCAL, OptOps.UPCAST: AxisType.UPCAST,
      OptOps.UNROLL: AxisType.UNROLL, OptOps.GROUP: AxisType.GROUP_REDUCE,
      OptOps.GROUPTOP: AxisType.GROUP_REDUCE, OptOps.THREAD: AxisType.THREAD}
    ret = None
    if opt.op in opt_to_at:
      # arg == 0 means "use the whole axis"
      amt:int = int(rng.vmax+1) if opt.arg == 0 else cast(int, opt.arg)
      # copied from kernel.py. prevents METAL compiler hangs
      if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
                                        (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
        full_shape = self.full_shape
        upcast_local_sz = prod([full_shape[a] for a in self.axes_of(AxisType.UPCAST, AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)])
        smem_sz = amt*upcast_local_sz*self.reduceop.dtype.itemsize
        check(smem_sz <= self.ren.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.ren.shared_max}")
      if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP}):
        # We currently don't support a group within another reduce, TODO: fix if-contexts
        reduces = [u for u in self.ast.backward_slice if u.op is Ops.REDUCE and rng in merge_dicts([r.ranges for r in u.src[1:]])]
        # without this, a missing REDUCE surfaced as a raw IndexError instead of a KernelOptError
        check(len(reduces) > 0, "no REDUCE found for the group axis")
        check(not any(u.arg[-1] in (AxisType.REDUCE, AxisType.UNROLL, AxisType.GROUP_REDUCE) for u in reduces[0].ranges),
              "cannot have a GROUP_REDUCE inside another reduce")
      if opt.op is OptOps.UNROLL:
        check(amt <= 32, "don't unroll more than 32")
        check(rng.arg[-1] in {AxisType.GROUP_REDUCE, AxisType.REDUCE}, "unroll is for GROUP_REDUCE/REDUCE")
      if opt.op is OptOps.UPCAST:
        check((self.ren is not None and self.ren.device == "DSP") or amt <= 16, "don't upcast more than 16")
        check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP}, f"upcast is for GLOBAL/LOCAL/LOOP, not {rng.arg[-1]}")
      if opt.op is OptOps.LOCAL:
        check(not self.dont_use_locals, "can't use locals")
        check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOOP}, "local is for globals")
      if opt.op is OptOps.THREAD:
        check(self.ren is not None and self.ren.has_threads, "target does not support threads")
        check(self.ren is not None and self.ren.global_max is not None and amt <= self.ren.global_max[0], "too many threads")
        check(all(x is not AxisType.THREAD for x in self.axis_types), "already threaded")
        check(rng in self._globalizable_rngs(), "can't apply range to this dim")
      if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:
        check(all(x.op is not OptOps.TC for x in self.applied_opts), "no grouping with tensor cores") # TODO: why is this wrong?
        check(not self.dont_use_locals, "can't use locals")
        check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
      ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
    elif opt.op is OptOps.TC:
      check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
      check(opt.axis is not None, "tensor core opts must have an axis")
      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
      check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.ren.tensor_cores), "tensor core opts must have valid tc_select")
      check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
      check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid")
      try: ret = self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt)
      except ValueError as e: raise KernelOptError(str(e)) from e
      check(ret is not None, "no tensor core available")
    elif opt.op is OptOps.PADTO:
      check(rng.src[0].op is Ops.CONST, "only pad const axes")
      check(rng.arg[-1] not in {AxisType.UPCAST, AxisType.UNROLL}, "cannot pad upcasted") # TODO: why is this wrong?
      check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread")
      # ok to pad SUM if all parent ALU ops have f(0) = 0
      if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
        check(r.arg[0] is Ops.ADD and not r.op_in_backward_slice_with_self(*GroupOp.UnsafePad), f"cannot pad {r}")
      new_sz = round_up(int(rng.vmax+1), cast(int, opt.arg))
      check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
      replaced_rng = UOp.range(new_sz, *rng.arg)
      replaces = {rng:replaced_rng}
      # indices past the original size are masked out of every buffer access that depends on this range
      valid = replaced_rng < rng.vmax+1
      for b in self.bufs:
        if rng in (i:=b.src[1].get_idx()).backward_slice_with_self:
          replaces[b] = b.replace(src=(b.src[0],(valid&b.src[1].get_valid()).where(i, UOp.invalid())))
      self.ast = self.ast.substitute(replaces, f"padto {rng.arg[:-1]} {opt.arg}")
    elif opt.op is OptOps.SWAP:
      try:
        altrng:UOp = self.rngs[opt.arg]
      except (IndexError, TypeError) as e:
        # TypeError covers a missing/non-int arg, which previously escaped as a non-KernelOptError
        raise KernelOptError(f"invalid swap axis {opt.arg}") from e
      check(rng.arg[-1] == AxisType.GLOBAL and altrng.arg[-1] == AxisType.GLOBAL, "swap only for globals")
      # tag the swapped ranges so the substitution does not rewrite them twice, then strip the tags
      self.ast = self.ast.substitute({rng:rng.replace(arg=(*altrng.arg[0:-1], rng.arg[-1]), tag=1),
                                      altrng:altrng.replace(arg=(*rng.arg[0:-1], altrng.arg[-1]), tag=1)},
                                     name=f"swap {rng.arg[:-1]} {altrng.arg[:-1]}")
      self.ast = graph_rewrite(self.ast, remove_tags, name="swap remove tags")
    else:
      raise KernelOptError(f"unsupported opt {opt.op}")
    if append_opt: self.applied_opts.append(opt)
    return ret
  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> None|list[UOp]:
    """Try to lower the first REDUCE (an ADD over a MUL, optionally behind a CAST) onto a tensor core.

    `axis` selects one (in1 range, in0 range, reduce range) combination; tc_select == -1 tries every tensor core of the
    renderer in order; opt_level >= 2 allows padding axes up to the tensor-core dims. With use_tensor_cores == 2 only
    the axis splits are applied and the reduce is not replaced by a WMMA. Returns the tensor-core axes on success,
    None if no tensor core matched. A failed candidate leaves self.ast unchanged.
    """
    if not (reduceops := self.reduceops): raise KernelOptError("no reduce ops for TensorCore")
    reduceop = reduceops[0]
    if use_tensor_cores and reduceop.arg is Ops.ADD:
      mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]
      if mul.op is not Ops.MUL: return None
      in0, in1 = mul.src
      try:
        tensor_cores = self.ren.tensor_cores if tc_select == -1 else [self.ren.tensor_cores[tc_select]]
      except IndexError:
        raise KernelOptError(f"invalid tensor core choice {tc_select}") from None
      for tc in tensor_cores:
        if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
          # tensor cores have three ranges. X, Y, and REDUCE
          in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: -x.arg[0])
          in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0])
          red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0])
          if DEBUG >= 3:
            print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
                  f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
          if not len(in0_ranges) or not len(in1_ranges) or not len(red_ranges): continue
          # pick ranges
          # NOTE: why are in1 and in0 switched?
          axis_choices = list(itertools.product(in1_ranges, in0_ranges, red_ranges))
          if not (axis < len(axis_choices)): continue
          axes = list(axis_choices[axis])
          # snapshot the AST: a failed attempt (tag + partial PADTO) must be undone before trying the next tensor core,
          # otherwise the next candidate sees stale ranges and an already-tagged reduce
          ast_before_tc = self.ast
          # tag the reduceop
          self.ast = self.ast.substitute({reduceop: reduceop.replace(tag="TC")})
          # do optimizations and save the ranges
          try:
            for i,a in enumerate(axes):
              idx = self.rngs.index(a)
              if (a.vmax+1) % tc.dims[i] != 0:
                if opt_level < 2: raise KernelOptError("tc padding requires opt_level >= 2")
                # apply_opt should return the updated range?
                self.apply_opt(Opt(OptOps.PADTO, idx, tc.dims[i]), append_opt=False) # PADTO might fail
                axes[i] = self.rngs[idx]
          except KernelOptError:
            self.ast = ast_before_tc
            continue
          # we create the warp as a whole thing, in case some of these ranges are moved/removed later
          warp = UOp.range(tc.threads, -1, AxisType.WARP)
          ne: list[UOp] = []
          for o in tc.opts:
            if o[0] == "l":
              axes[int(o[1])], new_range = self.shift_to(axes[int(o[1])], 2, AxisType.LOCAL, input_new_rng=warp%2)
              warp //= 2
            elif o[0] == "u":
              axes[int(o[1])], new_range = self.shift_to(axes[int(o[1])], 2, AxisType.UPCAST)
            else: raise RuntimeError(f"unsupported opt {o[0]} in tensor cores")
            ne.append(new_range)
          for _, amt in tc.get_reduce_axes():
            axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL)
            ne.append(new_range)
          if use_tensor_cores != 2:
            # fix the srcs
            reduceop = get_single_element([x for x in self.ast.toposort() if x.op is Ops.REDUCE and x.tag == "TC"])
            tne = [x.replace(tag=1) for x in ne]
            ret = reduceop.substitute(dict(zip(ne, tne)))
            srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
            srcs = [x.substitute(dict(zip(tne, [ne[i] for i in argsort(p)]))) for x,p in zip(srcs, tc.permutes_for_shape_str(tc.base_shape_str()))]
            # get reduce/upcast axes for the tensor cores
            tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
            base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())])
            tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)])
            # axes to range number (was done in lowerer)
            tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes])
            tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes])
            # construct the op
            # TODO: remove tc_upcast_axes from the arg
            # do the reduce_axes always disappear? i think they don't
            # they need to be moved into the WMMA srcs
            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.ren.device, tc.threads, tc_upcast_axes, ()) #, tc_reduce_axes)
            wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
              UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0], tag=1),
              UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1], tag=1),
              UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg, tag=1)
            tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2], tag=1)
            # preserve extra reduces
            reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes]
            if len(reduce_ranges): tc_uop = UOp(Ops.REDUCE, tc_uop.dtype, (tc_uop,)+tuple(reduce_ranges), Ops.ADD)
            self.ast = self.ast.substitute({reduceop: tc_uop})
          return axes
    return None
  # helpers for hand_coded_optimizations
  @property
  def reduceops(self) -> list[UOp]: return [x for x in self.ast.backward_slice if x.op is Ops.REDUCE]
  @property
  def reduceop(self) -> UOp|None:
    """The first REDUCE re-expressed as a REDUCE_AXIS uop with arg (reduce op, ()), or None if there is no REDUCE."""
    if not (red := self.reduceops): return None
    return UOp(Ops.REDUCE_AXIS, red[0].dtype, red[0].src, (red[0].arg, ()))
  @property
  def bufs(self) -> list[UOp]: return [x for x in self.ast.toposort() if x.op is Ops.INDEX][::-1]  # INDEX uops, reverse toposort order
  @property
  def output_shape(self):
    """full_shape with REDUCE/UNROLL/GROUP_REDUCE axes collapsed to 1."""
    return [s if at not in {AxisType.REDUCE, AxisType.UNROLL, AxisType.GROUP_REDUCE} else 1 for s,at in zip(self.full_shape, self.axis_types)]
  @property
  def upcasted(self) -> int: return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL))
  @property
  def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))
def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
  """Construct one Buffer on device `dname` for each DEFINE_GLOBAL in `ast`, ordered by the DEFINE_GLOBAL's arg."""
  def _buffer_dtype(g:UOp):
    # image buffers keep their full ImageDType; every other buffer uses the base dtype
    return g.dtype if isinstance(g.dtype, ImageDType) else g.dtype.base
  ordered_globals = sorted((u for u in ast.backward_slice if u.op is Ops.DEFINE_GLOBAL), key=lambda u: u.arg)
  return [Buffer(dname, g.ptrdtype.size, _buffer_dtype(g)) for g in ordered_globals]
def apply_opts(ast:UOp, ren:Renderer) -> UOp:
  """Optimize a kernel AST for renderer `ren` and return it finalized by Scheduler.get_optimized_ast.

  An AST that already carries a tag is returned untouched. Otherwise opts are chosen in priority order: the explicit
  `opts_to_apply` of the KernelInfo arg, then BEAM search (BEAM >= 1), then the hand-coded heuristics (unless NOOPT or
  opts were already applied). The KernelInfo name is kept unless it is "test".
  """
  if ast.tag is not None: return ast
  info = ast.arg
  sched = Scheduler(ast, ren)
  sched.convert_loop_to_global()
  if info is not None and info.opts_to_apply is not None:
    for o in info.opts_to_apply: sched.apply_opt(o)
  elif BEAM >= 1:
    from tinygrad.codegen.opt.search import beam_search
    sched = beam_search(sched, bufs_from_ast(ast, ren.device), BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
  elif not NOOPT and (info is None or info.applied_opts == ()):
    from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
    # NOTE: hand_coded_optimizations doesn't support multiblock opts yet
    if all(u.op is not Ops.BUFFERIZE for u in ast.backward_slice):
      sched = hand_coded_optimizations(sched)
  keep_name = info is not None and info.name != "test"
  return sched.get_optimized_ast(name_override=info.name if keep_name else None)