mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
scheduler cleanups + better cycle assert [pr] (#9172)
* scheduler cleanups + better cycle assert [pr] * type_verify after assign fixup * don't need base * always realize sink parents
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import sys, functools, atexit, pickle
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
|
||||
from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
@@ -37,11 +37,6 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
|
||||
if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
|
||||
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
|
||||
|
||||
def fix_image(u:UOp):
|
||||
if isinstance(dt:=u.dtype, ImageDType) and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in unwrap(u.st).unit_stride_axes())):
|
||||
if DEBUG >= 2: print(f"forcing image {u.dtype} with shape {u.shape} to {u.dtype.base}")
|
||||
return u.replace(dtype=u.dtype.base)
|
||||
|
||||
sym = symbolic_simple+PatternMatcher([
|
||||
# UOp with size 0 is zero
|
||||
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
|
||||
@@ -62,7 +57,8 @@ sym = symbolic_simple+PatternMatcher([
|
||||
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)),
|
||||
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
|
||||
# make things that can't be images not images
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: fix_image(u) if isinstance(u.dtype, ImageDType) else None),
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
|
||||
and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
|
||||
# remove contiguous if we can just view the buffer
|
||||
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
|
||||
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
|
||||
@@ -88,9 +84,9 @@ remove_movement_ops = merge_views+PatternMatcher([
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GrouperContext:
|
||||
assigns: dict[UOp, UOp] = field(default_factory=dict) # this holds all the ASSIGN UOps in this schedule
|
||||
realizes: dict[UOp, None] = field(default_factory=dict) # this holds all the Tensor uops we realize in this schedule
|
||||
children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
|
||||
assigns: dict[UOp, UOp] # maps realized buffers to assigns
|
||||
realizes: dict[UOp, None] # all the simplified tensor uops we realize
|
||||
children: defaultdict[UOp, dict[UOp, None]] # children graph of tensor uops
|
||||
|
||||
def realize(ctx:GrouperContext, tr:UOp) -> None: ctx.realizes[tr] = None
|
||||
|
||||
@@ -106,8 +102,7 @@ def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None:
|
||||
|
||||
do_realize = PatternMatcher([
|
||||
# always realize SINK parents
|
||||
(UPat(Ops.SINK, name="s"),
|
||||
lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
|
||||
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
|
||||
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
|
||||
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
|
||||
# realize before expand or unsafe pad ops
|
||||
@@ -141,7 +136,7 @@ create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"),
|
||||
|
||||
def group_realizes(sink:UOp) -> dict[UOp, None]:
|
||||
# start by adding uops that always realize
|
||||
sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext())
|
||||
sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext({}, {}, defaultdict(dict)))
|
||||
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
|
||||
reduce_for_op: dict[UOp, UOp] = {}
|
||||
double_reduces: list[UOp] = []
|
||||
@@ -332,8 +327,8 @@ fix_kernel_ops = PatternMatcher([
|
||||
])
|
||||
|
||||
def load_buf(ctx:list[UOp], x:UOp):
|
||||
if x.base not in ctx: ctx.append(x.base)
|
||||
return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.base.size), (), ctx.index(x.base)), unwrap(x.st).to_uop()))
|
||||
if x not in ctx: ctx.append(x)
|
||||
return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
|
||||
|
||||
add_buffer_ops = PatternMatcher([
|
||||
# LOAD
|
||||
@@ -420,9 +415,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
for s in u.src[1].src:
|
||||
if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
|
||||
if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
|
||||
raise RuntimeError(f"cycle detected in graph, kernel must either depend on ASSIGN or BUFFER for {k}")
|
||||
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
|
||||
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
|
||||
if assign_rep: sched_sink = sched_sink.substitute(assign_rep)
|
||||
if assign_rep:
|
||||
sched_sink = sched_sink.substitute(assign_rep)
|
||||
type_verify(list(sched_sink.toposort), kernel_spec)
|
||||
|
||||
# display the final graph
|
||||
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
|
||||
@@ -451,7 +448,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
# confirm everything was scheduled correctly
|
||||
if len(schedule) != (groups:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
|
||||
if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
|
||||
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
|
||||
# capture process replay
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
|
||||
Reference in New Issue
Block a user