scheduler cleanups + better cycle assert [pr] (#9172)

* scheduler cleanups + better cycle assert [pr]

* type_verify after assign fixup

* don't need base

* always realize sink parents
This commit is contained in:
qazal
2025-02-19 14:30:58 +02:00
committed by GitHub
parent cf315d544b
commit e4a8bf28ea

View File

@@ -1,6 +1,6 @@
import sys, functools, atexit, pickle
from collections import defaultdict, deque
from dataclasses import dataclass, field
from dataclasses import dataclass
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
from tinygrad.codegen.symbolic import symbolic_simple
@@ -37,11 +37,6 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
def fix_image(u:UOp):
if isinstance(dt:=u.dtype, ImageDType) and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in unwrap(u.st).unit_stride_axes())):
if DEBUG >= 2: print(f"forcing image {u.dtype} with shape {u.shape} to {u.dtype.base}")
return u.replace(dtype=u.dtype.base)
sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
@@ -62,7 +57,8 @@ sym = symbolic_simple+PatternMatcher([
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)),
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
# make things that can't be images not images
(UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: fix_image(u) if isinstance(u.dtype, ImageDType) else None),
(UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
# remove contiguous if we can just view the buffer
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
@@ -88,9 +84,9 @@ remove_movement_ops = merge_views+PatternMatcher([
@dataclass(frozen=True)
class GrouperContext:
assigns: dict[UOp, UOp] = field(default_factory=dict) # this holds all the ASSIGN UOps in this schedule
realizes: dict[UOp, None] = field(default_factory=dict) # this holds all the Tensor uops we realize in this schedule
children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
assigns: dict[UOp, UOp] # maps realized buffers to assigns
realizes: dict[UOp, None] # all the simplified tensor uops we realize
children: defaultdict[UOp, dict[UOp, None]] # children graph of tensor uops
def realize(ctx:GrouperContext, tr:UOp) -> None: ctx.realizes[tr] = None
@@ -106,8 +102,7 @@ def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None:
do_realize = PatternMatcher([
# always realize SINK parents
(UPat(Ops.SINK, name="s"),
lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
# realize before expand or unsafe pad ops
@@ -141,7 +136,7 @@ create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"),
def group_realizes(sink:UOp) -> dict[UOp, None]:
# start by adding uops that always realize
sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext())
sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext({}, {}, defaultdict(dict)))
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: dict[UOp, UOp] = {}
double_reduces: list[UOp] = []
@@ -332,8 +327,8 @@ fix_kernel_ops = PatternMatcher([
])
def load_buf(ctx:list[UOp], x:UOp):
if x.base not in ctx: ctx.append(x.base)
return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.base.size), (), ctx.index(x.base)), unwrap(x.st).to_uop()))
if x not in ctx: ctx.append(x)
return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
add_buffer_ops = PatternMatcher([
# LOAD
@@ -420,9 +415,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
for s in u.src[1].src:
if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
raise RuntimeError(f"cycle detected in graph, kernel must either depend on ASSIGN or BUFFER for {k}")
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep: sched_sink = sched_sink.substitute(assign_rep)
if assign_rep:
sched_sink = sched_sink.substitute(assign_rep)
type_verify(list(sched_sink.toposort), kernel_spec)
# display the final graph
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
@@ -451,7 +448,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
if in_degree[x] == 0: queue.append(x)
# confirm everything was scheduled correctly
if len(schedule) != (groups:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
# capture process replay
if CAPTURE_PROCESS_REPLAY: