import sys, pickle, atexit
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Tuple, List, Dict, Optional, Set, DefaultDict
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.shapetracker import ShapeTracker

# creation can recurse a lot
sys.setrecursionlimit(10000)

# optionally log the ops to disk
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None

# TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort
@dataclass(frozen=True)
class _LBScheduleItem:
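  # one future kernel (or LoadOp): the ASTs to run, the LazyBuffers it writes and reads, and any bound Variable values it needs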
|
|
ast: Tuple[LazyOp, ...]
|
|
outputs: Tuple[LazyBuffer, ...]
|
|
inputs: Tuple[LazyBuffer, ...]
|
|
var_vals: Dict[Variable, int]
|
|
|
|
# recursively create a lazyop
def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
                      realizes:Dict[LazyBuffer, None], cache, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
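  # builds the LazyOp AST for one kernel by walking buf's graph:
  #   consts become BufferOps.CONST, realized/unfused buffers become BufferOps.LOAD (and are appended to inputs),
  #   CONTIGUOUS/ASSIGN outputs are passed through, a ReduceOp resets st to its input shape, everything else fuses inline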
  if (buf, st) in cache: return cache[(buf, st)]
  if buf != buf.base:
    st = buf.st + st
    buf = buf.base
  # all buffers here are base now
  assert buf.op is not None

  # consts are always fused and generated
  if buf.op is LoadOps.CONST:
    unbound_st, st_var_vals = st.simplify().unbind()
    var_vals.update(st_var_vals)
    if isinstance(buf.arg, Variable): var_vals.__setitem__(*buf.arg.unbind())
    return LazyOp(BufferOps.CONST, (), ConstBuffer(buf.arg, buf.dtype, unbound_st))

  # if we aren't fusing it, it's a load and we add it to the inputs
  if buf.realized or (buf in realizes and buf not in outbufs):
    unbound_st, st_var_vals = st.simplify().unbind()
    var_vals.update(st_var_vals)
    if assign_to is not None and buf is assign_to:
      assert assign_idx is not None
      if not unbound_st.contiguous:
        # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
        if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
                ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
          raise RuntimeError(f"must be contiguous for assign {unbound_st}")
      return LazyOp(BufferOps.LOAD, (), MemBuffer(assign_idx, buf.dtype, unbound_st))
    if buf not in inputs: inputs.append(buf)
    return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outbufs)+inputs.index(buf), buf.dtype, unbound_st))

  # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
  if buf.op is LoadOps.CONTIGUOUS:
    assert buf in outbufs
    return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, cache)
  if buf.op is LoadOps.ASSIGN:
    assert buf in outbufs
    assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
    assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
    return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, cache, assign_to=buf.srcs[1], assign_idx=outbufs.index(buf))

  # if it's a reduce, we have to change the shapetracker
  if buf.op in ReduceOps:
    assert st.contiguous, "ReduceOps late fusion must be contiguous"
    st = ShapeTracker.from_shape(buf.srcs[0].shape)

  # otherwise we fuse it like normal
  cache[(buf, st)] = ret = \
    LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outbufs, var_vals, st, realizes, cache, assign_to, assign_idx) for x in buf.srcs), buf.arg)
  return ret

def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
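  # lowers one group of outputs (a single kernel, or a single LoadOp like COPY/EMPTY) into an _LBScheduleItem, one STORE per output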
  inputs: List[LazyBuffer] = []
  ast: List[LazyOp] = []
  var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
  if outs[0].op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}:
    ast, inputs = [LazyOp(outs[0].op, (), outs[0].arg)], [x.base for x in outs[0].srcs]
  else:
    for i, out in enumerate(outs):
      output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
      output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
      op = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, cache={})
      output_view, vv = output_view.simplify().unbind()
      if vv: var_vals.update(vv)
      ast.append(LazyOp(BufferOps.STORE, (op, ), MemBuffer(i, out.dtype, output_view)))
  return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals)

# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
                simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
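  # DFS over the LazyBuffer graph: fills allbufs and the children map, falls back invalid ImageDTypes to float32,
  # and marks buffers that must realize (LoadOps, forced_realize, bases that get expanded, COPY/VIEW sources); pads go to simple_pads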
  if buf in allbufs or buf.base.realized: return
  if GRAPH: log_lazybuffer(buf, scheduled)
  if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
                                            not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
    if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
    buf.dtype = dtypes.float32 # NOTE: this is what makes the dtype above not match
    # hack the underlying buffer too
    if buf.base is buf:
      assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
      buf.buffer.dtype = dtypes.float32
      buf.buffer.options = None
  if buf.base != buf:
    # realize all places where the buffer is expanded
    if prod(buf.base.st.shape) < prod(buf.st.shape):
      if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
          prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
        simple_pads.add(buf.base)
      else:
        realizes[buf.base] = None
    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
  if buf.forced_realize: realizes[buf] = None
  allbufs[buf] = None
  if buf.op in LoadOps: realizes[buf.base] = None
  if buf.op is LoadOps.COPY:
    assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
    realizes[buf.srcs[0].base] = None
  if buf.op is LoadOps.VIEW: realizes[buf.srcs[0].base] = None
  for x in buf.srcs:
    children[x.base][buf] = None
    _recurse_lb(x, realizes, allbufs, simple_pads, children)

def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
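  # a pad can stay fused only if nothing between buf and its realized parents is in UNSAFE_PAD_OPS (ops that would misbehave on the padded zeros)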
  if buf in realizes or buf.realized: return True
  # NOTE: this broke to_image_idx and coder with JIT
  if buf.op in UNSAFE_PAD_OPS: return False
  return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)

def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int],
                    Dict[LazyBuffer, _LBScheduleItem]]:
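  # the core pass: decides which LazyBuffers realize, pairs each reduce with the elementwise ops fused around it (forcing a realize
  # when it can't be paired cleanly), groups outputs into kernels, preschedules each group, and returns the kernel dependency graph,
  # its in_degrees, and the prescheduled items for the toposort in create_schedule_with_vars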
  # start by just realizing the buffers passed in
  realizes: Dict[LazyBuffer, None] = {x.base: None for x in outs if not x.base.realized}
  allbufs: Dict[LazyBuffer, None] = {}
  simple_pads: Set[LazyBuffer] = set()
  children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
  assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None}

  # check if we have to realize pads
  for p in simple_pads:
    if not _is_padding_okay(p, realizes):
      realizes[p] = None

  # find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
  for r in allbufs.keys():
    if r != r.base or r.op not in ReduceOps or r in realizes: continue

    # follow the reduce down
    child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
    realized_children: Dict[LazyBuffer, ShapeTracker] = {}
    forced_realize = False
    can_chase = True
    while not forced_realize and len(child_set):
      next_child_set = {}
      for tr,st in child_set.items():
        if tr in realizes:
          realized_children[tr] = st
          # can only reduce contiguous
          # max one reduceop per kernel
          if not st.contiguous or st.size != r.st.size or tr in reduce_for_op:
            can_chase = tr not in reduce_for_op
            forced_realize = True
            break
          if len(realized_children) > 1:
            rc_parents, rc_children = deque(realized_children), deque(realized_children)
            while rc_parents and not forced_realize:
              # max one reduceop per kernel
              if (p:=rc_parents.pop()).op in ReduceOps: forced_realize = True
              else: rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
            realized_descendants: Set[LazyBuffer] = set()
            while rc_children and not forced_realize:
              if (c:=rc_children.pop()).op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
                realized_descendants.clear()
                break
              if c in realizes and c not in (*realized_children, tr): realized_descendants.add(c)
              rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
            realized_children.update((rd, st) for rd in realized_descendants)
          continue
        for tr_next in children[tr].keys():
          if not tr_next.realized:
            # max one reduceop per kernel
            if tr_next.op in ReduceOps:
              forced_realize = True
              break
            st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
            if len(st_childs) > 1:
              forced_realize = True
              break
            next_child_set[tr_next] = st + st_childs[0].st
      child_set = next_child_set
    if not forced_realize and any(x.op is LoadOps.ASSIGN for x in realized_children):
      parents = deque((r, *realized_children))
      while parents and not forced_realize:
        if (p:=parents.pop().base).realized or p in realizes:
          if p in assign_targets and assign_targets[p] not in realized_children: forced_realize, can_chase = True, False
          continue
        parents.extend(p.srcs)
    if forced_realize:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = tr.st
        while len(children[tr]) == 1:
          tr_next = next(iter(children[tr].keys()))
          st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
          if len(st_childs) > 1: break
          if st.size != st_childs[0].st.size: break
          st = st + st_childs[0].st
          if not st.contiguous or tr_next.op in ReduceOps: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
        if tr.op is UnaryOps.CAST and tr.arg[0].itemsize > tr.srcs[0].dtype.itemsize:
          tr = tr.srcs[0].base
      reduce_for_op[tr] = r
      realizes[tr] = None
    else: reduce_for_op.update((tr, r) for tr in realized_children)

  output_groups: DefaultDict[Tuple, List[LazyBuffer]] = defaultdict(list)
  for r in realizes:
    if r.realized is not None or r.op is LoadOps.CONST or r in seen: continue
    output_groups[(reduce_for_op[r], ) if r in reduce_for_op and MULTIOUTPUT else (r, )].append(r)

  # preschedule all buffers in realizes
  prescheduled = {group[0]:_schedule_group(tuple(group), realizes, reduce_for_op) for group in output_groups.values()}
  schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}

  # breadth first ordering
  graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
  in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
  for key, lsi in prescheduled.items():
    if key not in in_degree: in_degree[key] = 0
    # realize outputs after all parents are realized
    scheduled_parents = set(schedule_targets[x].outputs[0] for x in lsi.inputs if x in schedule_targets)
    for x in scheduled_parents:
      graph[x].append(key)
      in_degree[key] += 1
    # realize outputs before a parent is assigned to
    parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in lsi.inputs if x in assign_targets)
    for assign in parents_assigns:
      graph[key].append(assign)
      in_degree[assign] += 1

  return graph, in_degree, prescheduled

SCHEDULES: List = []
def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
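  # toposorts the prescheduled kernels from _graph_schedule (breadth first via in_degree) into ScheduleItems, merging bound Variable values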
  if seen is None: seen = set()
  graph, in_degree, prescheduled = _graph_schedule(outs, seen)
  queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)
  schedule: List[ScheduleItem] = []
  var_vals: Dict[Variable, int] = {}
  kernel_number = GlobalCounters.kernel_count
  while queue:
    ps = queue.popleft()
    for buf in ps.outputs: seen.add(buf)
    if GRAPH:
      kernel_number += 1
      for out in ps.outputs: realized_lazybuffer(out, kernel_number)
    var_vals = merge_dicts([var_vals, ps.var_vals])
    for out in ps.outputs: del out.srcs # can only schedule once
    schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0)))
    if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
    for x in graph[ps.outputs[0]]:
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(prescheduled[x])

  if SAVE_SCHEDULE:
    def _save():
      print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
      pickle.dump(SCHEDULES, open(fp, "wb"))
    if len(SCHEDULES) == 0: atexit.register(_save)
    SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
  # confirm everything was scheduled correctly
  if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
    raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
  return schedule, var_vals

def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
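  # same as create_schedule_with_vars, but only valid when the graph has no unbound Variables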
  schedule, var_vals = create_schedule_with_vars(outs, seen)
  assert len(var_vals) == 0
  return schedule
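
# Illustrative usage sketch (not part of this module; assumes the Tensor frontend from the same tree):
#   from tinygrad.tensor import Tensor
#   out = (Tensor.rand(4, 4) @ Tensor.rand(4, 4)).sum().lazydata
#   for si in create_schedule([out]): print(si.ast[0].op)  # one ScheduleItem per kernel or LoadOp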