Files
tinygrad/tinygrad/engine/schedule.py

422 lines
23 KiB
Python

import sys, atexit, functools, itertools
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.device import Buffer
# to_uop lowers the lazy graph recursively (once per source), so deep graphs need a higher interpreter recursion limit
sys.setrecursionlimit(10000)
# per-device cap on the number of buffers one kernel may bind; create_schedule_with_vars raises once a kernel reaches it
BUF_LIMIT = dict(METAL=32)
# **** ScheduleItem return type
@dataclass(frozen=True)
class ScheduleItem:
  """One scheduled kernel: its AST, the Buffers it binds (in DEFINE_GLOBAL order), its metadata and its assign preloads."""
  ast: UOp
  bufs: Tuple[Buffer, ...]
  metadata: Tuple[Metadata, ...]
  assign_preloads: Tuple[UOp, ...]
  @property
  def outputs(self) -> Tuple[Buffer, ...]:
    """Buffers this kernel writes to (write-only or read/write) in the schedule."""
    written = self.output_idxs
    return tuple(buf for idx, buf in enumerate(self.bufs) if idx in written)
  @property
  def inputs(self) -> Tuple[Buffer, ...]:
    """Buffers this kernel only reads from in the schedule."""
    written = self.output_idxs
    return tuple(buf for idx, buf in enumerate(self.bufs) if idx not in written)
  @functools.cached_property
  def output_idxs(self) -> Tuple[int, ...]:
    """Positions in bufs that are written: the arg of each SINK source's src[0], or just (0,) for a non-SINK ast."""
    if self.ast.op is not Ops.SINK: return (0,)
    return tuple(store.src[0].arg for store in self.ast.src)
# **** small wrapper for LazyBuffer -> UOp
@dataclass(frozen=True)
class ScheduleContext:
  """Graph-wide state filled in while to_uop lowers LazyBuffers into the big UOp graph (the containers are mutated in place)."""
  # Buffer -> the BUFFER uop standing for it
  buf_uops: Dict[Buffer, UOp] = field(default_factory=dict)
  # BUFFER uop -> Metadata of the LazyBuffer that produces it
  ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict)
  # DEFINE_VAR of a BIND -> the value it is bound to
  var_vals: Dict[Variable, int] = field(default_factory=dict)
  # every BUFFER uop that is the target of an ASSIGN in this schedule
  assigns: Set[UOp] = field(default_factory=set)
  # Buffer -> LazyBuffer, used to find the LazyBuffers whose srcs are dropped once they are scheduled
  lazybufs: Dict[Buffer, LazyBuffer] = field(default_factory=dict)
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp],
           double_reduces:Dict[UOp, None], cache:Dict[LazyBuffer, UOp]) -> UOp:
  """
  Lower a LazyBuffer (recursively through its srcs) into the big-graph UOp representation.

  - a non-base LazyBuffer becomes a VIEW of its lowered base
  - a CONST becomes VALID(st).where(const, 0)
  - a realized buffer becomes PRELOAD(BUFFER, st)
  - anything else becomes LOAD(BUFFER, st, STORE(BUFFER, contiguous st, value))

  ctx: shared state; BUFFER uops, metadata, bound variable values, ASSIGN targets and LazyBuffers are registered in it.
  children: filled with BUFFER uop -> {BUFFER uops of the unrealized buffers that have a direct src LOADing it}.
  allbufs: filled with BUFFER uop -> the LOAD(..., STORE(...)) uop built for it.
  double_reduces: filled with BUFFER uops of a REDUCE_AXIS whose first src is a non-base view of another REDUCE_AXIS.
  cache: memo of lowered LazyBuffers; CONSTs and realized buffers are not memoized here.
  """
  if (r:=cache.get(buf)) is not None: return r
  # a non-base LazyBuffer is a VIEW on top of its lowered base
  if buf is not buf.base:
    cache[buf] = ret = to_uop(buf.base, ctx, children, allbufs, double_reduces, cache).view(buf.st)
    return ret
  # make things that can't be images not images
  if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
                                            not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
    if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to {buf.dtype.base}")
    # hack the underlying buffer too
    buf.dtype = buf.buffer.dtype = buf.dtype.base
    # NOTE(review): this check runs after the dtype was already rewritten above; if it fires, buf/buf.buffer stay modified
    assert not buf.is_realized(), "can't fixup allocated buffer"
    buf.buffer.options = None
  # image dtypes are lowered as their base dtype
  dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
  # consts are always fused and generated
  if buf.op is Ops.CONST:
    # a UOp arg (presumably a BIND) is unbound and its value recorded in ctx.var_vals
    if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()])
    return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(UOp.const(dtype, val), 0)
  # everything else has BUFFER
  # the BUFFER arg is (registration order, (device, size, dtype)); setdefault keeps the first uop made for a Buffer
  ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp(Ops.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype))))
  # if the buffer is already realized we just load it
  if buf.is_realized(): return UOp(Ops.PRELOAD, dtype, (ubuf, buf.st.to_uop()))
  # everything else needs sources
  src = tuple(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs)
  if buf.op in {Ops.REDUCE_AXIS, Ops.CONTIGUOUS}: ret = UOp(buf.op, dtype, src, buf.arg)
  elif buf.op is Ops.ASSIGN:
    ctx.assigns.add(ubuf)
    # the ASSIGN writes src[1] into this buffer's own BUFFER uop
    ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
  elif buf.op in GroupOp.Meta: ret = UOp(buf.op, buf.dtype, (ubuf, *src), buf.arg)
  else: ret = UOp(cast(Ops, buf.op), dtype, src)
  # forced realizes are wrapped in CONTIGUOUS
  if buf.forced_realize: ret = UOp(Ops.CONTIGUOUS, dtype, (ret,))
  # the value is STOREd contiguously into the buffer and read back through buf.st
  cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
  if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
  ctx.lazybufs[b] = buf
  # things for fuse.py
  allbufs[ubuf] = ret
  if buf.op is Ops.REDUCE_AXIS and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[ubuf] = None
  # record this buffer as a child of every buffer one of its direct srcs LOADs
  for x in src:
    if x.base.op is Ops.LOAD: children[x.base.buf_uop][ubuf] = None
  return ret
# **** AST graph rewrite
# ** helpers for doing movementops on uops
def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp:
  """Put ShapeTracker `arg` as a VIEW on u and push it towards the loads with view_left (match-stat tracking disabled)."""
  with Context(TRACK_MATCH_STATS=0):
    viewed = u.view(arg)
    return graph_rewrite(viewed, view_left)
def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
  """
  Permute input_st so the reduced axes come last; the kept axes stay in their original relative order.

  Returns the permuted ShapeTracker and the shape of the trailing reduced axes (in the order given by `axis`).
  An empty `axis` yields the unpermuted ShapeTracker and an empty reduce shape.
  """
  permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
  tmp = input_st.permute(permute_axis)
  # slice from an explicit start index: `shape[-len(axis):]` would return the whole shape when axis is empty (-0 == 0)
  return tmp, tmp.shape[len(tmp.shape)-len(axis):]
# ** movementops rewrite rules
def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
  """
  Rewrite VIEW(st) on top of REDUCE_AXIS(src) as a REDUCE_AXIS over a swizzled src, followed by a contiguous VIEW of st.shape.

  Every view of st is extended with the reduce dimensions appended at the end, so the view can be applied to the reduce
  input instead of its output. The new reduce then runs over those trailing dimensions.
  """
  # input with the reduced axes moved last, plus the shape of those reduced axes
  tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(src.st).shape), r.axis_arg)
  prshape = prod(rshape)
  strides = strides_for_shape(rshape)
  # per view of st: append rshape; outer strides and offset are scaled by the reduce size, reduced dims get contiguous
  # strides, and an existing mask is extended to cover the full range of every reduced dim
  nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
                    v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
  # update input_st and axis
  new_input_st = tmp + ShapeTracker(tuple(nv))
  # only the length of new_rshape is used: the reduced dims are the trailing len(new_rshape) axes of new_input_st
  _, new_rshape = permute_reduce(new_input_st, r.axis_arg)
  new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
  return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
  """
  Rewrite REDUCE_AXIS(VIEW(src)) as VIEW(REDUCE_AXIS(src)) when the VIEW is a plain reshape (contiguous, same size).

  The reduced output shape is computed on the swizzled shape; the new reduce runs on the view's source over every axis
  where the source shape differs from that output shape.
  """
  view_st = unwrap(swizzle.st)
  in_st = unwrap(src.st)
  assert view_st.contiguous, "can't push a non contiguous VIEW down to STORE"
  assert prod(view_st.shape) == prod(in_st.shape), "can't push expands down to STORE"
  reduced_shape = view_st.reduce(root.axis_arg)
  axes = tuple(dim for dim, (before, after) in enumerate(zip(in_st.shape, reduced_shape)) if before != after)
  unswizzled = swizzle.src[0]
  return unswizzled.r(root.arg[0], axes).view(ShapeTracker.from_shape(reduced_shape))
def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
  """
  Move the VIEW sources of an elementwise op (or STORE) below it.

  All viewed sources must share one output shape and element counts. They are replaced by their unviewed source, every
  other source is swizzled to the common input shape, and the result is viewed back to the output shape unless root is a
  STORE. Returns None when no source is a view (i.e. every source is its own base).
  """
  swizzles = [s for s in root.src if s.base is not s]
  if not swizzles: return None
  swizzle_shapes = [(unwrap(s.st).shape, unwrap(s.src[0].st).shape) for s in swizzles]
  size_keys = [(out_shape, prod(out_shape), prod(in_shape)) for out_shape, in_shape in swizzle_shapes]
  assert all_same(size_keys), f"swizzles must have the same size {swizzle_shapes}"
  new_shape, new_input_shape = swizzle_shapes[0]
  new_src = tuple(s.src[0] if s in swizzles else apply_swizzle(s, ShapeTracker.from_shape(new_input_shape)) for s in root.src)
  ret = root.replace(src=new_src)
  if ret.op is Ops.STORE: return ret
  return ret.view(ShapeTracker.from_shape(new_shape))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
  """Fuse REDUCE_AXIS(REDUCE_AXIS(x)) with the same reduce op into one REDUCE_AXIS over the concatenated axes."""
  alu = first_reduce.arg[0]
  assert root.arg[0] == alu, "can't merge reduceops with different alu"
  assert all(p.op is not Ops.REDUCE_AXIS for p in first_reduce.parents), "can't merge more than two reduceops at a time"
  merged_axes = root.axis_arg + first_reduce.axis_arg
  return first_reduce.src[0].r(alu, merged_axes)
# fold VIEW(VIEW(x)) into one VIEW whose ShapeTracker is the inner one composed with the outer one (s0.st + s1.st)
merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
# push VIEW to loads
view_left = merge_views+PatternMatcher([
  # VIEW before elementwise ops
  # VIEW(e(a, b, ...)) -> e(VIEW(a), VIEW(b), ...) for ALU/CAST/BITCAST/ASSIGN
  (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
  # early merge VIEW buffer ops
  # VIEW on a buffer op: compose the view into the buffer op's own VIEW source(s) and drop the outer VIEW
  (UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.arg+v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
])
# push VIEW to stores
view_right = merge_views+PatternMatcher([
  # ASSIGN can override st
  # (when the ASSIGN has an arg, the STORE's st becomes arg[0] composed with it and the ASSIGN's arg is cleared)
  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
   lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
  # non contiguous VIEW on a reduce creates a new VIEW
  (UPat(Ops.REDUCE_AXIS, src=UPat.var("src"), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
  # push a VIEW down to STORE, through a reduce (ONLY reshapes)
  (UPat(Ops.REDUCE_AXIS, src=(UPat.var(name="src").view(name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
  # push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
  # fuse REDUCE_AXIS(REDUCE_AXIS(x)) into a single reduce
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
# ** ScheduleItem context builder
@dataclass(frozen=True)
class ScheduleItemContext:
  """Per-kernel state collected by the to_si / append_bufs rewrites while one kernel graph becomes its AST."""
  # shared BIND variable -> value map, extended with the variables unbound from ShapeTrackers
  var_vals: Dict[Variable, int]
  # BUFFER uops that are ASSIGN targets somewhere in the schedule
  assigned: Set[UOp]
  # ShapeTrackers already simplified and unbound for this kernel
  sts: Set[ShapeTracker] = field(default_factory=set)
  # BUFFER uops in DEFINE_GLOBAL order
  bufs: List[UOp] = field(default_factory=list)
  # assigned BUFFER uops this kernel reads through a PRELOAD
  assign_preloads: List[UOp] = field(default_factory=list)
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
  """
  Simplify and unbind the ShapeTracker of VIEW x, adding the bound variable values to ctx.var_vals.

  Returns a new VIEW uop when the cleaned ShapeTracker differs from x's, else None. ShapeTrackers already produced here
  (kept in ctx.sts) are left alone.
  """
  cur_st = unwrap(x.st)
  if cur_st in ctx.sts: return None
  clean_st, bound = cur_st.simplify().unbind()
  ctx.var_vals.update(bound)
  ctx.sts.add(clean_st)
  return clean_st.to_uop() if clean_st != x.st else None
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
  """Replace BUFFER x by a DEFINE_GLOBAL numbered with its position in ctx.bufs."""
  slot = len(ctx.bufs)
  ctx.bufs.append(x)
  return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), slot)
# every BUFFER in a kernel AST becomes a numbered DEFINE_GLOBAL
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
  """Turn PRELOAD x into a LOAD, remembering its buffer in ctx.assign_preloads when it is an assign target."""
  if b in ctx.assigned:
    ctx.assign_preloads.append(b)
  return x.replace(op=Ops.LOAD)
to_si = PatternMatcher([
  # simplify/unbind every ShapeTracker, collecting bound variable values
  (UPat(Ops.VIEW, name="x"), _append_st_vars),
  # PRELOAD -> LOAD
  (UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
  # a kernel that only stores a GroupOp.Meta op is replaced by that op itself
  (UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,x: x),
])
# ** fusion
lazy = PatternMatcher([
(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat.var("v"))), lambda ctx,b,v: v),
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
])
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -> Tuple[UOp, ScheduleItemContext]:
  """
  Turn one kernel's SINK of STOREs (cut out of the big graph) into the AST of a ScheduleItem.

  pre: SINK whose sources are STOREs; src[2] of each STORE is the stored value.
  var_vals: shared BIND variable -> value map; variables unbound from ShapeTrackers are added to it.
  assigned: BUFFER uops that are ASSIGN targets anywhere in the schedule.
  Returns the AST and the ScheduleItemContext collected while building it (bufs, assign_preloads, ...).
  Raises RuntimeError if an assigned buffer is read both before (PRELOAD) and after (LOAD) its assign inside this kernel,
  or if the self operand of an augmented assign is read through a view that isn't (a masked shrink of) contiguous.
  """
  # fuse and fold store -> loads
  sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.buf_uop:x.src[2] for x in pre.src})
  # assert cyclic dependency: an assigned buffer may not be read both pre-assign (PRELOAD) and post-assign (LOAD) in one kernel.
  # ops are collected per buffer in a dict: itertools.groupby only groups *consecutive* items, so it would miss mixed reads
  # whose nodes are interleaved with another assigned buffer's nodes in sparents.
  load_ops: DefaultDict[UOp, Set[Ops]] = defaultdict(set)
  for x in sink.sparents:
    if x.op in {Ops.PRELOAD, Ops.LOAD} and x.buf_uop in assigned: load_ops[x.buf_uop].add(x.op)
  for b, ops in load_ops.items():
    if len(ops) > 1:
      raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
  # do movementops
  sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
  # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
  if (assign_targets:={x.buf_uop for x in sink.sparents if x.op is Ops.ASSIGN}):
    if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
        and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.buf_uop in assign_targets):
      raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                         +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
  # convert to AST
  sink = graph_rewrite(graph_rewrite(sink, to_si, ctx:=ScheduleItemContext(var_vals, assigned)), append_bufs, ctx)
  if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append(((pre, var_vals, assigned), sink))
  return sink, ctx
# ((pre, var_vals, assigned), ast) for every full_ast_rewrite call; only appended to when RUN_PROCESS_REPLAY is set
PROCESS_REPLAY_CAPTURE: List[Tuple[Tuple, UOp]] = []
if getenv("RUN_PROCESS_REPLAY"):
  @atexit.register
  def save_process_replay():
    """At interpreter exit, write every captured record to the schedule_process_replay disk cache, keyed by the input SINK."""
    for inputs, ast in PROCESS_REPLAY_CAPTURE:
      diskcache_put("schedule_process_replay", str(inputs[0].key), (inputs, {}, ast))
# **** Schedule grouping
def uval(u:UOp) -> UOp:
  """Return the value stored by a big-graph LOAD(b, st, STORE(b, st, value)); a value with is_contiguous_base is unwrapped one level."""
  assert u.op is Ops.LOAD and len(u.src) == 3 and u.src[2].op is Ops.STORE, f"must be a LOAD of STORE {u}"
  stored = u.src[2].src[2]
  if stored.is_contiguous_base: return stored.src[0]
  return stored
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], realizes:Dict[UOp, UOp],
                    reduce_for_op:Dict[UOp, UOp], group:Dict[UOp, None], cache:Dict[Tuple[UOp, ShapeTracker], None]) -> None:
  """
  recursively search the uop for groupable children, realize the UOp if a child can't group

  tr: BUFFER uop currently visited (the walk starts at r itself).
  st: ShapeTracker accumulated along the path from r down to tr.
  r: BUFFER uop of the REDUCE_AXIS being grouped.
  group: filled in place with the realized buffers reached from r; r itself is added when some path can't be fused
         (the caller treats `r in group` as forced_realize).
  cache: (tr, st) pairs already visited, so each path state is explored once.
  """
  if (tr, st) in cache: return
  cache.setdefault((tr, st))
  rsize = unwrap(allbufs[r].st).size
  # reached an already realized buffer: it can join r's kernel unless the path isn't a plain contiguous, same-size access
  if tr in realizes and tr is not r:
    # can only fuse contiguous
    # max one reduceop per kernel
    if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
    # setdefault returns the stored None; the return just ends this path
    return group.setdefault(tr)
  for tr_next in children[tr]:
    # max one reduceop per kernel
    if (tr_next_uop:=uval(allbufs[tr_next]).base).op is Ops.REDUCE_AXIS: return group.setdefault(r)
    # can only fuse contiguous
    # (the child must read tr through exactly one ShapeTracker)
    if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if x.base.op is Ops.LOAD and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
    recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
def get_isolated_children(r:UOp, reduce_for_op:Dict[UOp, UOp], children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp],
                          realizes:Dict[UOp, UOp], group:Dict[UOp, None]) -> Dict[UOp, None]:
  """
  Check that the buffers grouped with reduce r don't depend on another reduce, and extend the group with clean descendants.

  Walks (through LOADs, never back into r) the values feeding every group member. If any of them is a REDUCE_AXIS, returns
  {} (max one reduceop per kernel). Otherwise returns group plus the descendants recursive_group finds from each member,
  unless one of those descendants is already in group, in which case group alone is returned.
  """
  pending: deque = deque(group)
  seen = set()
  while pending:
    val = uval(allbufs[pending.pop()])
    if val in seen: continue
    seen.add(val)
    # max one reduceop per kernel
    if val.op is Ops.REDUCE_AXIS: return {}
    pending.extend(x.base.buf_uop for x in val.src if x.base.op is Ops.LOAD and x.base.buf_uop is not r)
  # search descendants of the reduceop that can cleanly group
  descendants: Dict[UOp, None] = {}
  for member in group:
    recursive_group(member, unwrap(allbufs[member].st), member, children, allbufs, realizes, reduce_for_op, descendants, cache={})
  overlaps = any(d in group for d in descendants)
  return merge_dicts([group, {} if overlaps else descendants])
def group_realizes(children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], double_reduces:Dict[UOp, None],
                   realizes:Dict[UOp, UOp], assigns:Set[UOp]) -> List[List[UOp]]:
  """
  search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop

  realizes is updated in place: a reduce (or a contiguous child it was chased down to) may be added, and with
  FUSE_CONV_BW / FUSE_ARANGE some entries are deleted again so they fuse into their consumers.
  Returns the realized BUFFER uops grouped per kernel: buffers paired with the same reduce in reduce_for_op share one group,
  every other realized buffer forms a group of its own.
  """
  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: Dict[UOp, UOp] = {}
  reduce_of_const: List[UOp] = []
  for r, r_uop in allbufs.items():
    # only unrealized REDUCE_AXIS buffers are considered
    if r in realizes or (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
    group: Dict[UOp, None] = {}
    recursive_group(r, unwrap(r_uop.st), r, children, allbufs, realizes, reduce_for_op, group, cache={})
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    forced_realize = r in group
    if not forced_realize and len(group) > 1:
      group = get_isolated_children(r, reduce_for_op, children, allbufs, realizes, group)
    # can only fuse assign if no other assign_target is used in the kernel
    if not forced_realize and any(x in assigns for x in group):
      # walk the sources of r and the group, stopping at realized buffers, looking for an ASSIGN outside the group
      # NOTE(review): there is no visited set here, so shared parents in a DAG are walked once per path
      parents = deque((r, *group))
      while parents and not forced_realize:
        if (p_uop:=allbufs.get(p:=parents.pop())) is None: continue
        if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
        if p in realizes: continue
        parents.extend([x.base.src[0] for x in p_uop.src if x.base.op in {Ops.LOAD, Ops.PRELOAD}])
    # the reduce gets its own kernel: realize it, or the end of a chain of single contiguous children
    if forced_realize or not group:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = unwrap(r_uop.st)
        while len(children[tr]) == 1:
          tr_next_uop = uval(allbufs[(tr_next:=next(iter(children[tr])))])
          st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if x.base.op is Ops.LOAD and x.base.buf_uop is tr])
          if len(st_childs) > 1: break
          if st.size != st_childs[0].size: break
          st = st + st_childs[0]
          if not st.contiguous or tr_next_uop.op is Ops.REDUCE_AXIS: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
        if (tr_uop:=uval(allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
          tr = tr_uop.src[0].base.buf_uop
      group = {tr: None}
      realizes[tr] = tr
    # every buffer in the group is produced by the kernel of reduce r
    reduce_for_op.update((tr, r) for tr in group)
    # arange-like reduces (ADD over a WHERE) are candidates for fusing into their consumers below
    if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.WHERE: reduce_of_const.append(r)
  # fuse double reduces with no other child
  if FUSE_CONV_BW:
    for reduceop in double_reduces:
      top_reduce = uval(allbufs[reduceop]).src[0].base.buf_uop
      # NOTE(review): raises KeyError if top_reduce isn't in realizes at this point — confirm that can't happen
      if len(children[top_reduce]) == 1: del realizes[top_reduce]
  # unrealize arange-like reduce groups so they fuse into their children, unless a member is a contiguous base
  # or the group only feeds COPY / BUFFER_VIEW
  for rbuf in reduce_of_const:
    group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
    if any(allbufs[tr].src[2].src[2].is_contiguous_base for tr in group): continue
    kernel_children = {c for tr in group for c in children[tr] if uval(allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
    if len(kernel_children) == 0: continue
    # NOTE(review): a member already deleted by the FUSE_CONV_BW loop above would raise KeyError here — confirm
    for tr in group: del realizes[tr]
  # group the remaining realized buffers by the reduce they belong to (or by themselves)
  output_groups: DefaultDict[UOp, List[UOp]] = defaultdict(list)
  for ubuf in realizes: output_groups[reduce_for_op.get(ubuf, ubuf)].append(ubuf)
  return list(output_groups.values())
# **** Schedule creation and BFS toposort
def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp:
  """Record STORE `store` as the realization of BUFFER b in ctx and replace the LOAD with a plain LOAD(b, st) of it."""
  ctx[b] = store
  st_uop = load.st_arg.to_uop()
  return UOp(Ops.LOAD, load.dtype, (b, st_uop))
def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, **kwargs) -> Optional[UOp]:
  """
  Decide whether a LOAD/STORE read through a VIEW has to be realized first.

  Returns None to keep it fused, or the realized LOAD (see realize) with the same VIEW applied. Realization happens before
  expands (the view has more elements than the base) and before masked views that can_pad(base) rejects.
  kwargs are the b/load/store captures forwarded to realize.
  """
  base_shape = unwrap(base.st).shape
  st = unwrap(view.st)
  def realized() -> UOp: return realize(ctx, **kwargs).view(st)
  # fold simple pads
  if len(st.views) == 1 and (mask:=st.views[-1].mask) is not None and all_int(base_shape) \
      and resolve(prod(base_shape) >= prod([hi-lo for lo,hi in mask])):
    return realized() if not can_pad(base) else None
  # early realize before expand
  if resolve(prod(base_shape) < prod(st.shape)): return realized()
  # otherwise safety check pads: only a masked view the base can't pad forces a realize
  if any(v.mask is not None for v in st.views) and not can_pad(base): return realized()
  return None
def UPatLoadStore(to_store=UPat()):
  """Pattern for LOAD(b, <st>, STORE(b, <st>, to_store)), capturing the buffer as "b", the STORE as "store" and the LOAD as "load"."""
  buf = UPat.var("b")
  store = UPat.store(buf, UPat(), to_store, name="store")
  return UPat.load(buf, UPat(), store, name="load")
# first pass over the big graph: decide which buffers become kernel outputs; realize() records them in ctx (the realizes dict)
do_realize = PatternMatcher([
  # always realize meta ops
  (UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta))), realize),
  # don't realize image to image casts
  # (when the casted-from buffer was already recorded with a non-meta STORE and has an ImageDType, read its stored value
  #  directly through the outer view; presumably relies on that buffer having been visited first)
  (UPatLoadStore(UPat(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float)).view(name="v"), lambda ctx,x,v,**kwargs: r.src[2].view(v.st)
   if (r:=ctx.get(b:=x.buf_uop)) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) and r.src[2].op not in GroupOp.Meta else None),
  # realize before expand or unsafe pad ops
  (UPatLoadStore(UPat.var("base")).view(name="view"), realize_view),
  # realize before COPY or BUFFER_VIEW
  (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="view"))), name="root"),
   lambda ctx,root,u,view=None,**kwargs: root.replace(src=(u, realize(ctx,**kwargs) if view is None else realize(ctx,**kwargs).view(view.st))),),
])
# second pass: every LOAD/STORE whose buffer is in realizes gets its STORE recorded there and is cut to a plain LOAD(b, st),
# splitting the big graph into per-kernel graphs
break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),])
@track_rewrites(named=True)
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
  """
  Schedule the unrealized, non-CONST bases of `outs` into an ordered list of ScheduleItems.

  Returns the schedule, ordered with a BFS toposort over kernel dependencies, and the values of the variables bound in it.
  Side effects: sets forced_realize on every output, may rewrite image dtypes (in to_uop), and deletes `srcs` from the
  LazyBuffer of every scheduled output buffer.
  Raises RuntimeError on a cycle (inside a kernel or between kernels) or when a kernel reaches its device's BUF_LIMIT.
  """
  if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
  # forced_realize wraps each output in CONTIGUOUS (see to_uop), which do_realize always realizes
  for out in outs: out.forced_realize = True
  # create the big graph
  ctx = ScheduleContext()
  cache: Dict[LazyBuffer, UOp] = {}
  # **** TODO: delete these next 3 after big graph
  children: DefaultDict[UOp, Dict[UOp, None]] = defaultdict(dict)
  allbufs: Dict[UOp, UOp] = {}
  double_reduces: Dict[UOp, None] = {}
  big_graph = UOp.sink(*(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in outs))
  # get realizes
  realizes: Dict[UOp, UOp] = {}
  graph_rewrite(big_graph, do_realize, realizes)
  store_groups = group_realizes(children, allbufs, double_reduces, realizes, ctx.assigns)
  # split realizes into small graphs
  graph_rewrite(big_graph, break_sched, realizes)
  # one SINK of STOREs per kernel
  sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups]
  # preschedule all realizes
  # BUFFER uop arg[0] is the registration index into ctx.buf_uops, so bufs[arg[0]] maps it back to its Buffer
  bufs = list(ctx.buf_uops)
  prescheduled: List[ScheduleItem] = []
  for sink in sinks:
    # NOTE(review): built from a set, so the metadata order may vary between runs
    metadata = tuple({mx for x in sink.sparents if x.op in GroupOp.Buffer and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.buf_uop))})
    ast, ast_ctx = full_ast_rewrite(sink, ctx.var_vals, ctx.assigns)
    # NOTE(review): dropping zero-size buffers here can shift positions relative to the DEFINE_GLOBAL indices used by
    # ScheduleItem.output_idxs — confirm outputs are never zero-size
    prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=bufs[u.arg[0]]).size != 0), metadata, tuple(ast_ctx.assign_preloads)))
  # do BFS
  # output Buffer -> the ScheduleItem that writes it
  schedule_targets = {out:si for si in prescheduled for out in si.outputs}
  graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
  in_degree: DefaultDict[ScheduleItem, int] = defaultdict(int)
  for si in prescheduled:
    # realize outputs before a parent is assigned to
    # (si reads the pre-assign value of a buffer another item writes, so si must run first: edge si -> that item)
    parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(bufs[x.arg[0]])) and xsi is not si)
    for assign in parents_assigns:
      graph[si].append(assign)
      in_degree[assign] += 1
    # realize outputs after all parents are realized
    # (edge producer -> si for every other input produced by some item)
    scheduled_parents = dedup(xsi for x in si.inputs if (xsi:=schedule_targets.get(x)) is not None and xsi not in parents_assigns)
    for x in scheduled_parents:
      graph[x].append(si)
      in_degree[si] += 1
  queue = deque(si for si in prescheduled if in_degree[si] == 0)
  schedule: List[ScheduleItem] = []
  while queue:
    schedule.append(si:=queue.popleft())
    for b in si.outputs: del ctx.lazybufs[b].srcs # can only schedule once
    # NOTE(review): si.outputs[0] raises IndexError for an item without outputs
    if (m:=BUF_LIMIT.get(device:=si.outputs[0].device)) and len(si.bufs) >= m:
      if DEBUG >= 3: print(si)
      raise RuntimeError(f"Kernel for {si.metadata} exceeded the {m} buffer count limit for {device} with {len(si.bufs)} buffers.")
    for x in graph[si]:
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(x)
  # confirm everything was scheduled correctly
  # (items left with a nonzero in-degree are part of a dependency cycle)
  if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
  return schedule, ctx.var_vals
def create_schedule(outs:List[LazyBuffer]) -> List[ScheduleItem]:
  """Same as create_schedule_with_vars, for graphs that must not bind any variables (asserts the var_vals are empty)."""
  sched, bound = create_schedule_with_vars(outs)
  assert not bound
  return sched