delete grouper and kernelize (#12517)

* delete grouper and kernelize

* +sys.setrecursionlimit
qazal
2025-10-08 12:27:26 +03:00
committed by GitHub
parent 942022c309
commit 7e0b14243e
9 changed files with 13 additions and 514 deletions

View File

@@ -81,7 +81,6 @@ print("******** third, the UOp ***********")
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.schedule import create_schedule_with_vars
 from tinygrad.helpers import RANGEIFY
-from tinygrad.schedule.kernelize import get_kernelize_map
 from tinygrad.schedule.rangeify import get_rangeify_map
 # allocate some values + load in values
@@ -95,7 +94,7 @@ out = a + b
 s = UOp(Ops.SINK, dtypes.void, (out,))
 # group the computation into kernels
-becomes_map = get_rangeify_map(s) if RANGEIFY else get_kernelize_map(s)
+becomes_map = get_rangeify_map(s)
 # the compute maps to an assign
 assign = becomes_map[a+b].base
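
With the RANGEIFY branch gone, the tutorial's scheduling path collapses to a single call. A minimal end-to-end sketch of that path, assuming a tinygrad checkout at this commit (all names are taken from the hunks in this diff, not re-verified against HEAD):

from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.schedule import create_schedule_with_vars
from tinygrad.engine.realize import run_schedule
from tinygrad.schedule.rangeify import get_rangeify_map

out = Tensor([1.0, 2.0]) + Tensor([3.0, 4.0])      # build a tiny compute graph
s = UOp(Ops.SINK, dtypes.void, (out.uop,))         # root it with a SINK
becomes_map = get_rangeify_map(s)                  # group the graph into kernels (now the only path)
sched, var_vals = create_schedule_with_vars(s.substitute(becomes_map))
run_schedule(sched)                                # execute the kernels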

View File

@@ -10,7 +10,7 @@ Directories are listed in order of how they are processed.
 Group UOps into kernels.
-::: tinygrad.schedule.kernelize.get_kernelize_map
+::: tinygrad.schedule.rangeify.get_rangeify_map
 options:
 members: false
 show_labels: false

View File

@@ -2,9 +2,7 @@ import sys
 from tinygrad import Tensor, fetch, GlobalCounters, dtypes
 from tinygrad.uop.ops import UOp
 from tinygrad.nn.onnx import OnnxRunner
-from tinygrad.schedule.kernelize import get_kernelize_map
 from tinygrad.schedule.rangeify import get_rangeify_map
-from tinygrad.helpers import RANGEIFY
 from tinygrad.engine.schedule import create_schedule_with_vars
 from tinygrad.engine.realize import run_schedule
@@ -35,7 +33,7 @@ if __name__ == "__main__":
 if not in_target_path[s]:
 independent_set[s] = None
 independent = UOp.sink(*independent_set.keys())
-kernelized = (get_rangeify_map if RANGEIFY else get_kernelize_map)(independent)
+kernelized = get_rangeify_map(independent)
 independent = independent.substitute(kernelized)
 schedule, var_vals = create_schedule_with_vars(independent)
 run_schedule(schedule)

View File

@@ -12,10 +12,11 @@ from tinygrad import nn, dtypes, Device, Tensor, Variable
 from tinygrad.device import is_dtype_supported
 from tinygrad.dtype import DType, ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
+from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, graph_rewrite
 from tinygrad.uop.symbolic import symbolic_simple
 from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp, RANGEIFY
-from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel
+from tinygrad.codegen.opt.swizzler import merge_views
+from tinygrad.schedule.rangeify import get_rangeify_map, Kernel
 from tinygrad.engine.schedule import create_schedule_with_vars
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from test.helpers import expect_rangeify_fails, expect_nonrangeify_fails
@@ -29,7 +30,7 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
 else:
 assert isinstance(t, UOp), f"can't schedule {t}"
 sink = UOp.sink(t) if t.op is not Ops.SINK else t
-becomes_map = get_kernelize_map(sink)
+becomes_map = get_rangeify_map(sink)
 sched, _ = create_schedule_with_vars(sink.substitute(becomes_map))
 # test lowering all the ScheduleItems to ExecItems
 kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink])
@@ -68,9 +69,6 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
 np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
 np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
-@track_rewrites(name=True)
-def schedule_graph_rewrite(big_sink:UOp): return get_kernelize_map(big_sink)[big_sink]
 class TestSchedule(unittest.TestCase):
 def test_arange_avgpool2d(self, kcount=1):
 x = Tensor.arange(25).reshape(1,1,5,5).cast(dtypes.float32)
@@ -2244,17 +2242,11 @@ class TestCopyFolding(unittest.TestCase):
 a = Tensor.empty(4).uop
 b = a.copy_to_device(a.device)
 check_schedule(b, 0, filter_sink=False)
-b = schedule_graph_rewrite(b)
-# NOTE: Tensor.empty(4) always creates a VIEW(BUFFER) with ShapeTracker((4,)), we simplify this to just a BUFFER
-# in the scheduler because buffer already has shape (4,)
-self.assertIs(b, a.base)
 def test_copy_to_same_device_alt(self):
 a = Tensor.empty(4, 4).uop
 b = a.copy_to_device(a.device)
 check_schedule(b, 0, filter_sink=False)
-b = schedule_graph_rewrite(b)
-self.assertIs(b.base, a.base)
 def test_copy_to_same_device_sched(self):
 a = Tensor.ones(4).contiguous().realize().uop.as_buf()
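
For reference, a schedule-count test under the new single path follows the shape of check_schedule above. A hedged sketch patterned on that helper (names from the hunks in this diff; the expected count is illustrative, not a verified result):

def count_kernels(t: Tensor) -> int:
    sink = UOp.sink(t.uop)
    becomes_map = get_rangeify_map(sink)                # new single entry point
    sched, _ = create_schedule_with_vars(sink.substitute(becomes_map))
    # count only real compiled kernels, mirroring check_schedule's filter
    return len([si for si, ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner)])

assert count_kernels(Tensor.empty(8) + 1) == 1          # illustrative expectation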

View File

@@ -2,7 +2,7 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
 from tinygrad.helpers import all_same, prod, unwrap, colored
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
-from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
+from tinygrad.schedule.rangeify import ALWAYS_CONTIGUOUS
 from tinygrad.dtype import ImageDType, dtypes
 merge_views = PatternMatcher([

View File

@@ -1,119 +0,0 @@
from tinygrad.uop.ops import Ops, UOp, resolve, can_pad, GroupOp, UPat, PatternMatcher, graph_rewrite
from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, FUSE_CONV_BW
from tinygrad.shape.shapetracker import ShapeTracker
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD}
# **** Grouper decides which of the UOps realize
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
for s in rb.src:
if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
st = unwrap(view.st)
# always realize unsafe pad ops before masked view
if any(v.mask is not None for v in st.views) and not can_pad(tr, ctx): return realize(ctx, tr)
# fold simple pads
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(tr.shape) and resolve(prod(tr.shape) >= prod([y-x for x,y in m])): return
# realize before expand
if resolve(prod(tr.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, tr)
do_realize = PatternMatcher([
# always realize SINK parents
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
# realize before expand or unsafe pad ops
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
# realize parents of COPY, MSELECT, MSTACK
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
])
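# (illustrative note, not part of the original file: do_realize runs as
#  graph_rewrite(sink, do_realize, ctx=realizes), with the ctx dict used as an
#  output parameter. e.g. if an elementwise a+b is read through an expanding
#  view, prod(input.shape) < prod(st.shape) in realize_before_view, so a+b is
#  marked in `realizes` and the expand reads a materialized buffer instead of
#  recomputing the sum once per expanded element, unless DONT_REALIZE_EXPAND.)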
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:dict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
if (tr, st) in cache: return
cache.setdefault((tr, st))
rsize = unwrap(r.st).size
if tr in realizes and tr is not r:
# can only fuse contiguous
# max one reduceop per kernel
if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
return group.setdefault(tr)
for tr_next in children.get(tr, {}):
# max one reduceop per kernel
if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
# can only fuse contiguous
if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
def group_realizes(sink:UOp) -> dict[UOp, None]:
# start by adding uops that always realize
realizes: dict[UOp, None] = {}
sink = graph_rewrite(sink, do_realize, ctx=realizes, name="do_realize")
if DONT_GROUP_REDUCES: return realizes
# construct children graph (only for bases)
children: dict[UOp, dict[UOp, None]] = {}
assigns: dict[UOp, None] = {}
for u in (toposort:=sink.toposort()):
if u.op in {Ops.VIEW, Ops.SINK}: continue
if u.op is Ops.ASSIGN: assigns[u.buf_uop] = None
for s in u.src: children.setdefault(s.base, {})[u] = None
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: dict[UOp, UOp] = {}
double_reduces: list[UOp] = []
for r in toposort:
if r.op is not Ops.REDUCE_AXIS: continue
if len(r.arg) == 3 and r.arg[2] is True: continue
if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
if r in realizes: continue
group: dict[UOp, None] = {}
recursive_group(r, unwrap(r.st), r, children, realizes, reduce_for_op, group, cache={})
# max one reduceop per kernel
can_chase = all(tr not in reduce_for_op for tr in group)
for u in r.toposort(gate=lambda u: u not in realizes):
if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST:
can_chase = False
break
# TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
forced_realize = r in group
# can only have one output
if not forced_realize and len(group) > 1: forced_realize = True
# can only fuse assign if no other assign_target is used in the kernel
if not forced_realize and (assign_targets:={x.buf_uop for x in group if x.op is Ops.ASSIGN}):
parents = [r, *group]
while parents and not forced_realize:
p = parents.pop().base
if p.op is Ops.BUFFER and p in assigns and p not in assign_targets: forced_realize, can_chase = True, False
if p in realizes: continue
parents.extend(p.src)
if forced_realize or not group:
tr = r
if can_chase:
# can chase this down to contiguous children
st = unwrap(tr.st)
while len(lst:=children.get(tr, {})) == 1:
tr_next = next(iter(lst))
st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
if len(st_childs) > 1: break
if st.size != st_childs[0].size: break
st = st + st_childs[0]
if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
tr = tr_next
# don't cast to higher size before store (tr cannot be realized if forced_realize)
if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
tr = tr.src[0].base
group = {tr: None}
realizes[tr] = None
reduce_for_op.update((tr, r) for tr in group)
# fuse double reduces with no other child
for reduceop in double_reduces:
top_reduce = reduceop.src[0].base
if len(children.get(top_reduce, {})) == 1: del realizes[top_reduce]
return realizes
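
Taken together, this was the old grouping policy: each REDUCE_AXIS tries to absorb its contiguous, same-size elementwise children into one kernel, and anything ambiguous is force-realized. Two hedged illustrations of the behavior implied by the code above (inferred from reading it, not re-run):

# x.sum(axis=1).exp(): EXP is a contiguous child of the reduce with matching
# size, so recursive_group absorbs it -- one kernel.
# x.sum(axis=1).sum(axis=0): "max one reduceop per kernel" forces the inner sum
# to realize; only with FUSE_CONV_BW set and a single-child top reduce does the
# double-reduce pass at the end fuse the pair.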

View File

@@ -1,374 +0,0 @@
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo
from tinygrad.uop.spec import type_verify, tensor_uop_spec
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import all_int, all_same, prod, dedup, unwrap, getenv, pluralize, DEBUG, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
from tinygrad.schedule.rangeify import Kernel
from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
from tinygrad.codegen.opt import Opt
# creation can recurse a lot
import sys
sys.setrecursionlimit(10000)
# **** schedule simplifier
def simplify_stride0_reduce(reduce:UOp, x:UOp):
# must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
if any(v.mask is not None for v in unwrap(x.st).views): return None
# must have all stride 0 in the relevant axis (NOTE: can do partial)
if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None
prshape = prod(x.shape[i] for i in reduce.arg[1])
ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape)))
match reduce.arg[0]:
case Ops.ADD: return ret*prshape
case Ops.MUL: return ret.pow(prshape)
case Ops.MAX: return ret # NOTE: Ops.MAX is passthrough
def split_reduceop(reduce:UOp, x:UOp):
if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape))<getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
# if there are few globals, make some reduces into globals by splitting into two kernels
# cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
# ~2**10 should be enough if GROUP is used
# 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum.
# split is moved to the end to provide maximum locality for the second phase reduce.
real_strides = unwrap(x.st).real_strides(ignore_valid=True)
if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
if x.shape[i]%d==0 and real_strides[i]!=0]): return None
dim_to_split, divisor = split_candidates[0]
splitted_shape = x.shape[:dim_to_split]+(divisor,)+(x.shape[dim_to_split]//divisor,)+x.shape[dim_to_split+1:]
splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d!=dim_to_split]+[dim_to_split]))
if DEBUG >= 3: print(f"split {divisor}: {x.shape} -> {splitted.shape} -> {reduce.shape}")
# reduce original axes, then split
return splitted.r(*reduce.arg).r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)
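# (worked example, not part of the original file: take x.shape=(8, 65536) with an
#  ADD reduce over axis 1, so reduce.shape=(8, 1). prod(x.shape)//prod(reduce.shape)
#  = 65536 >= 32768, so the split applies. the divisor search starts at
#  min(256, 2**22//8) = 256, and 65536 % 256 == 0, giving divisor=256 on axis 1.
#  x is reshaped to (8, 256, 256) with the split dim permuted last, reduced over
#  the original axis to (8, 1, 256), then the 256 partials are reduced in a
#  second, cheap kernel and reshaped back to (8, 1).)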
def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
return base.copy_to_device(copy.device).view(view.arg)
kernelize_sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# reduce on stride 0 is collapsed
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
# split_reduceop
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
# COPY(CONST) creates a new CONST on the destination device
(UPat(Ops.COPY, name="root", src=(UPat.cvar("x"), UPat(Ops.DEVICE))), lambda root,x: root.const_like(x.arg)),
# non device changing COPY is a NOOP
(UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None),
# store a shrink before COPY, otherwise view after the COPY
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view),
# remove cast to image when it's already a contiguous image
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
# CAST before masking constants
(UPat.cvar("x").view().cast(name="c"), lambda x,c: x.cast(c.dtype).view(c.src[0].arg)),
# make things that can't be images not images
(UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
# remove contiguous if we can just view the buffer
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
# contiguous/buffer/copy/assign is already contiguous
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"), lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,),
(t.size, x.st.views[0].offset)).reshape(t.shape) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
# double ASSIGN to same target is one ASSIGN
(UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))))), lambda x,t: t.assign(x.contiguous())),
# ASSIGN to unrealized replaces the UOp
(UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))), lambda x,t: x.contiguous() if t.base.op not in {Ops.BUFFER, Ops.BUFFER_VIEW} and
not (t.base.op is Ops.MSTACK and all(x.op is Ops.BUFFER for x in t.base.src)) else None),
# put CAST to smaller dtype before EXPAND
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
if cast.dtype.itemsize <= vm.dtype.itemsize and resolve(prod(vm.shape) > vm.st.real_size()) else None),
# put UnaryOps before EXPANDs, if it can fuse with the input
(UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="inp"),), name="v"),), name="alu"),
lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
])
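# (illustrative firing of the rules above, not part of the original file:
#  Tensor(3.0, device="CPU").to("CUDA") builds COPY(CONST, DEVICE); the
#  "COPY(CONST) creates a new CONST" rule rewrites it to a fresh CONST on the
#  destination device, so no copy kernel is ever scheduled. likewise a COPY
#  whose source device already matches the destination is dropped as a NOOP.)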
# support for using a contiguous permuted view instead of the parent view if one exists
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
replace_contiguous = PatternMatcher([
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, name="src"),), name="contig"), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
# **** create kernels
def create_kernel(x:UOp, b:UOp|None=None):
if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ()))
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
# we have to shrink the buffer back to the symbolic shape
return buffer.assign(kernel).reshape(tuple(d.vmax if isinstance(d, UOp) else d for d in x.shape)).shrink(tuple((0, d) for d in x.shape))
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI, Ops.BIND}
def append_to_kernel(x:UOp):
new_srcs: list[UOp] = []
metadata = x.arg.metadata
for s in x.src:
if s.op in DONT_PLACE_IN_KERNEL: new_srcs.append(s)
else:
new_srcs.extend(s.src)
# NOTE: because const and device are shared UOps they don't change metadata
# NOTE: if it's a reshape after ASSIGN we're not fusing that parent kernel
if s.base.op not in {Ops.CONST, Ops.DEVICE} and (not (s.op is Ops.RESHAPE and s.base.op is Ops.ASSIGN)) and (m:=s.metadata): metadata += m
if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(dedup(metadata))))
create_kernels = PatternMatcher([
# always give assign/contiguous a kernel
(UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
(UPat(Ops.CONTIGUOUS, name="x"), create_kernel),
# walk back the local graph until we reach a realized source
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
# push RESHAPE through MSELECT
(UPat(Ops.MSELECT, src=(UPat(Ops.RESHAPE, name="r"),), name="ms"), lambda ms,r: r.src[0].mselect(ms.arg).reshape(r.arg)),
# push RESHAPE through MSTACK
(UPat(Ops.MSTACK, src=UPat(Ops.RESHAPE), name="ms"),
lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)),
])
def add_stores(ctx, sink: UOp):
stores = []
for i,x in enumerate(sink.src):
gbl = UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i)
# if this is an assign then we already have a buffer with a view that should be the target of the store
if x.op is Ops.ASSIGN: stores.append(UOp.store(gbl.view(unwrap(s.st)), s))
# otherwise we have to create the shapetracker and shrink it to the correct symbolic shape
else: stores.append(
UOp.store(gbl.reshape(tuple(int(d.vmax) if isinstance(d,UOp) else d for d in s.shape)).shrink(tuple((0,d) for d in s.shape)),s))
return UOp.sink(*stores, arg=sink.arg)
# **** fix kernel AST
def unbind_view(x:UOp):
if any(x.op is Ops.BIND for x in x.arg.vars()): return x.replace(arg=x.arg.unbind()[0])
return None
replace_buffers = PatternMatcher([
# sink on contig creates a KernelInfo
(UPat(Ops.CONTIGUOUS, name="c").sink(name="s"),
lambda s,c: s.replace(src=(c.replace(arg=None),), arg=KernelInfo(opts_to_apply=c.arg)) \
if s.arg is None and c.arg is not None and isinstance(c.arg[0], Opt) else None),
# replace ASSIGN with the target BUFFER
(UPat(Ops.ASSIGN, src=(UPat((Ops.BUFFER, Ops.LOAD)), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
# HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?)
(UPat(Ops.MSTACK, name="x"), lambda x: x.src[0]),
# LOAD
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).load()),
# no SINK for meta ops
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
# STORE (except for meta ops)
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), add_stores),
# remove CONTIGUOUS/DEVICE from kernel AST
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
# passthrough ASSIGN (but let MSTACK process first)
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.MSTACK}), UPat()), name="x"), lambda x: x.src[1]),
# remove any BINDs from VIEWS
(UPat(Ops.VIEW, src=(UPat(), UPat((Ops.BIND, Ops.DEFINE_VAR))), allow_any_len=True, name="x"), lambda x: x.replace(src=x.src[0:1])),
# remove any BINDs from DEFINE_VARs
(UPat(Ops.BIND, name="x"), lambda x: x.src[0]),
# remove BINDs from ShapeTrackers
(UPat(Ops.VIEW, name="x"), unbind_view),
])
def fix_kernel_ast(k:UOp) -> UOp|None:
if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
# replace buffer with define_global + add load/store last
bufs = []
for s in k.src:
if s.op is Ops.BIND: continue
s = s.buf_uop
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
bufs.append(s)
# replace global memory ops with the BUFFER they write to
# NOTE: merge_views is needed to unbind the reshapes
ast = graph_rewrite(k.arg.ast, merge_views+replace_buffers, bufs, bottom_up=True, name="replace buffers")
if ast.op is Ops.SINK and not all_same([x.device for x in k.src if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
return k.replace(arg=Kernel(ast, k.arg.metadata))
create_ast = PatternMatcher([
(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),
(UPat(Ops.DEFINE_VAR, src=(UPat(),), allow_any_len=True, name="x"), lambda x: x.replace(src=())),
])
# ** add metadata of KERNEL outputs
def append_metadata(root:UOp, k:UOp):
if not root.metadata or (new_metadata:=tuple(dedup(k.arg.metadata+root.metadata))) == k.arg.metadata: return None
return root.replace(src=(root.src[0], k.replace(arg=Kernel(k.arg.ast, new_metadata)))+root.src[2:])
replace_metadata = PatternMatcher([(UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.KERNEL, name="k")), name="root", allow_any_len=True), append_metadata),])
pm_fuse = PatternMatcher([
# FUSE on CONTIGUOUS removes FUSE
(UPat(Ops.CONTIGUOUS, name="c").fuse(), lambda c: c),
# FUSE triggers swizzle on reduceop
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").or_casted(),), name="view").fuse(),
lambda r,src,view: ret.cast(view.dtype) if (ret:=swizzle_reduceop(r, src, view, fuse=True)) is not None else None),
# FUSE on reduce (without view) adds fuse marker to grouper
(UPat(Ops.REDUCE_AXIS, name="r").fuse(),
lambda r: r.replace(src=(r.src[0].fuse(),), arg=r.arg+(True,)) if len(r.arg) == 2 else None),
# remove FUSE and insert CONTIGUOUS if it's an unsafe pad
(UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="alu"),), name="view").fuse(),
lambda alu, view: alu.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None),
# FUSE elementwise.
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST}, name="alu"),), name="view").fuse(),
lambda alu, view: alu.replace(src=tuple(apply_swizzle(x.view(view.arg)).fuse() for x in alu.src))),
# push FUSE through to srcs
(UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),
])
def do_fusion(x:UOp):
found_contiguous = {}
def gate_contiguous(x):
if is_contiguous:=(x.op is Ops.CONTIGUOUS): found_contiguous[x] = x.replace(src=(UOp(Ops.VIEW, arg=x.st), UOp.unique()))
return not is_contiguous
x.toposort(gate=gate_contiguous)
del gate_contiguous
return graph_rewrite(x.substitute(found_contiguous), pm_fuse, name="local fusion").substitute({v:k for k,v in found_contiguous.items()})
def fuse_arange(root:UOp):
# skip if root is arange
if root.src[0].base.op is Ops.CONST: return None
# gather all local aranges (including any fused ones)
local_arange: list[UOp] = []
def gate_reduce(u):
if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST: local_arange.append(u)
return u.op not in {*ALWAYS_CONTIGUOUS, Ops.REDUCE_AXIS} or u is root
toposort = root.toposort(gate=gate_reduce)
if not local_arange: return None
# fuse the nearest expand child of arange
local_children: dict[UOp, list[UOp]] = {}
for u in toposort:
for s in u.src: local_children.setdefault(s, []).append(u)
fuse_rep: dict[UOp, UOp] = {}
for r in local_arange:
# skip if already fused
if len(r.arg) > 2: continue
q = list(local_children[r])
while q:
u = q.pop()
if not (curr_children:=local_children.get(u, [])): continue
for child in curr_children:
other_paths = {s for s in child.toposort() if s.op in {Ops.REDUCE_AXIS, Ops.BUFFER} and s not in {root, r}}
fuse_rep[child] = child.replace(src=tuple(s.fuse() if s is u else s for s in child.src))
if other_paths: break
else: q.extend(curr_children)
return root.substitute(fuse_rep, name="fuse_arange") if fuse_rep else None
do_fuse = PatternMatcher([
(UPat(Ops.FUSE, name="x"), do_fusion),
(UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange),
])
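# (illustrative, not part of the original file: tinygrad lowers gather-style
#  indexing to a one-hot arange compare plus reduce, roughly
#    X[idx] ~ (arange(n) == idx).where(X, 0).sum(axis)
#  fuse_arange finds such REDUCE_AXIS-over-CONST aranges and tags the nearest
#  expand child with FUSE, so the arange is recomputed inline rather than being
#  materialized into its own buffer.)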
add_contiguous = PatternMatcher([(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"),
lambda ctx,x: x.replace(tag=1).contiguous() if x in ctx and x.tag is None else None)])
# TODO: get this from the device through GrouperOpts
DEVICE_MAX_BUFS = {"METAL":32, "WEBGPU":8}
def limit_bufs(root:UOp):
# check if backend has a buffer limit
device = root.device if isinstance(root.device, str) else root.device[0].split(":")[0]
if not (MAX_BUFS:=getenv("MAX_KERNEL_BUFFERS", DEVICE_MAX_BUFS.get(device, 0))): return None
# count number of unique buffers flowing into this op
bufs: set[UOp] = set()
def gate_input(u:UOp):
if (is_load:=(u.op in {Ops.BUFFER, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.MSTACK, Ops.DEFINE_VAR})): bufs.add(u)
return not is_load
root.toposort(gate=gate_input)
# NOTE: this -1 is for the output buffer
if len(bufs)>=MAX_BUFS-1:
return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).contiguous() for s in root.src))
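# (illustrative, not part of the original file: on METAL the cap above is 32.
#  if a binary op's subtree already reads ~31 unique buffers -- e.g. a wide
#  Tensor.cat feeding one elementwise op -- the rewrite returns the op with its
#  non-buffer srcs wrapped in CONTIGUOUS, realizing those subtrees as separate
#  kernels so each kernel stays under the device's buffer-argument limit.)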
def view_add_srcs(x:UOp):
if len(avars:=x.arg.vars()) and len(x.src) == 1:
return x.replace(src=x.src+tuple(avars))
return None
finalize_contiguous = PatternMatcher([
# if an op takes more than one input, check combined LOADs don't exceed device limits
(UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
# merge contiguous
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.CONTIGUOUS),), name="x"), lambda x: x.src[0]),
# simplify views
(UPat(Ops.VIEW, src=(UPat.var('x')), name="v"), lambda x,v: x.view(new_st) if (new_st:=v.arg.simplify()) != v.arg else None),
# vars to views srcs
(UPat(Ops.VIEW, name="x"), view_add_srcs),
])
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
"""
Function to transform the Tensor UOp graph into a version with Ops.KERNEL
Args:
sink: The Ops.SINK rooting the Tensor graph.
Returns:
Map transforming each UOp in the sink to the Ops.KERNEL graph.
"""
# multi + merge_views + simplify
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+kernelize_sym+replace_contiguous, ctx={}, name="merge_views")
# display the cleaned up tensor graph
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
# insert contiguous in places determined by the realize map
realize_map = group_realizes(tensor_map[sink])
tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add_contiguous")
tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")
# group into kernels (this is context-free)
tensor_map = graph_rewrite_map(tensor_map[sink], create_kernels, input_map=tensor_map, name="create_kernels")
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}
assign_rep: dict[UOp, UOp] = {}
for u in tensor_map[sink].toposort():
if u.op is not Ops.ASSIGN: continue
kernel_assign[u.buf_uop] = u
for s in u.src[1].src:
# TODO: this is probably broken for MSELECT/MSTACK
if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep:
tensor_map = graph_rewrite_map(tensor_map[sink], _substitute, ctx=assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign")
# finally, create the AST for kernels
tensor_map = graph_rewrite_map(tensor_map[sink], create_ast+replace_metadata, bottom_up=True, input_map=tensor_map, name="create_ast")
# display the final graph
sched_sink = tensor_map[sink]
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
# verify Kernels match the spec
if __debug__: type_verify(list(sched_sink.toposort()), tensor_uop_spec)
return tensor_map
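
The deleted entry point ran a fixed pipeline of graph_rewrite_map passes: multi+fuse+merge_views+simplify, then group_realizes to choose realization points, add_contiguous/finalize_contiguous to materialize them, create_kernels to wrap work in Ops.KERNEL, a fix-up pass ordering dependent ASSIGNs, and finally create_ast. After this commit, callers get the same dict[UOp, UOp] contract from the rangeify path; the drop-in change at a call site, as the hunks above show:

# before (deleted in this commit)
from tinygrad.schedule.kernelize import get_kernelize_map
becomes_map = get_rangeify_map(sink) if RANGEIFY else get_kernelize_map(sink)

# after: one path, same contract
from tinygrad.schedule.rangeify import get_rangeify_map
becomes_map = get_rangeify_map(sink)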

View File

@@ -10,6 +10,10 @@ from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, si
 from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
 from tinygrad.codegen.opt import Opt
+# creation can recurse a lot
+import sys
+sys.setrecursionlimit(10000)
 # *****************
 # 0. do some cleanup rewrites, mostly copied from the old stuff
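
This is the commit's only addition: rangeify now raises CPython's recursion limit the same way the deleted kernelize.py did, since building deep UOp chains recurses once per node. The mechanism is plain Python, nothing tinygrad-specific:

import sys
print(sys.getrecursionlimit())  # CPython's default is 1000
sys.setrecursionlimit(10000)    # headroom for long UOp src chains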

View File

@@ -18,7 +18,6 @@ from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
 from tinygrad.schedule.rangeify import get_rangeify_map
 from tinygrad.schedule.multi import get_multi_map
-from tinygrad.schedule.kernelize import get_kernelize_map
 # *** all in scope Tensors are here. this gets relevant UOps ***
@@ -232,7 +231,7 @@ class Tensor(MathTrait):
 _apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
 big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))
-becomes_map = get_rangeify_map(big_sink) if RANGEIFY else get_kernelize_map(big_sink)
+becomes_map = get_rangeify_map(big_sink)
 _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
 return self
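
At the Tensor level the change is the same one-liner; _apply_map_to_tensors then rewrites every in-scope tensor's UOp through the returned map. A rough sketch of what that application step amounts to (internals assumed from the helper's name and usage, not shown in this diff):

for t in live_tensors:                     # every Tensor still referenced
    t.uop = t.uop.substitute(becomes_map)  # swap its graph for the KERNEL version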