scheduling graph_rewrite prereqs for BLOCK in ASSIGN (#8598)

* remove the BUF_LIMIT assert

* skip the base one

* work

* work

* good error

* ok comment

* shorter check
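
The "good error" bullet above refers to the contiguity check on the self operand of an augmented assign (see the second changed file below). A minimal sketch of what that check rejects, assuming a standard tinygrad install; whether the error surfaces at schedule() or only at realize() is an assumption here:

from tinygrad import Tensor

a = Tensor.rand(4, 4).realize()
a += a.T       # self operand is a non-contiguous (transposed) view of a's own buffer
a.schedule()   # expected to raise: "self operand of augmented assign must be contiguous.
               #  help: consider using .contiguous(): ... a += a.T.contiguous()"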
Author: qazal (committed by GitHub)
Date: 2025-01-14 03:01:59 -05:00
Parent: 05e54f00d3
Commit: 863abc7140
2 changed files with 29 additions and 34 deletions

Changed file 1 of 2

@@ -2,7 +2,7 @@
 # compare kernels created by HEAD against master
 from collections import defaultdict
 import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
-from typing import Callable, List, Set, Tuple, Union, cast
+from typing import Callable, List, Tuple, Union, cast
 from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
 from tinygrad.engine.schedule import ScheduleContext, schedule_uop
 from tinygrad.codegen.kernel import Kernel, Opt
@@ -30,9 +30,9 @@ class ProcessReplayWarning(Warning): pass
 # *** recreators
-def recreate_sched(ast:UOp, assigns:Set[UOp]) -> UOp:
+def recreate_sched(ast:UOp) -> UOp:
   # NOTE: process replay isn't meant to actually schedule anything
-  return schedule_uop(ast, ScheduleContext(assigns=assigns, tensor_uops=defaultdict(list))).ast
+  return schedule_uop(ast, ScheduleContext(tensor_uops=defaultdict(list))).ast
 def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str) -> str:
   k = Kernel(ast, opts=opts)
   for opt in applied_opts: k.apply_opt(opt)

Changed file 2 of 2

@@ -183,14 +183,9 @@ view_right = merge_views+PatternMatcher([
 @dataclass(frozen=True)
 class ScheduleItemContext:
-  ops_metadata: dict[UOp, Metadata]
-  assigns: set[UOp]
   var_vals: dict[Variable, int]
-  sinked: dict[UOp, UOp]
   sts: set[ShapeTracker] = field(default_factory=set)
   bufs: list[UOp] = field(default_factory=list)
-  metadata: set[Metadata] = field(default_factory=set)
-  assign_adj: dict[UOp, list[UOp]] = field(default_factory=dict)
 
 def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
   if (st:=unwrap(x.st)) in ctx.sts: return None
@@ -204,47 +199,47 @@ def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
   return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
 append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
 
-def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
-  (adj_loads:=ctx.assign_adj.setdefault(b, [])).append(x)
-  if not all_same([x.op for x in adj_loads]): raise RuntimeError(f"Detected cycle when fusing {adj_loads}. Can only fuse PRELOAD or LOAD of {b}")
-  return x.replace(op=Ops.LOAD)
-check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),])
 
 to_si = PatternMatcher([
   (UPat(Ops.VIEW, name="x"), _append_st_vars),
   (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
   # don't need contiguous or assign anymore
   (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
   (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
+  # PRELOAD becomes LOAD
+  (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
 ])
 
-add_metadata = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: None if (m:=ctx.ops_metadata.get(x)) is None else ctx.metadata.add(m)),])
-add_assign_adjacents = PatternMatcher([(UPat.load(UPat.var("b"), UPat(), name="x"), lambda ctx,b,x: ctx.assign_adj.setdefault(b, []).append(x)
-                                        if b in ctx.assigns else None)])
 
 # late folding for multi output kernels
-multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.sinked.get(b)),])
+# LOAD(BUFFER) -> the STORE value when we're doing the STORE in the same kernel
+multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
 
 def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
-  # create the ast context
-  si_ctx = ScheduleItemContext(ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src})
-  create_ctx = add_metadata if len(si_ctx.assigns) == 0 else add_metadata+add_assign_adjacents
-  sink = graph_rewrite(pre, create_ctx if len(si_ctx.sinked) == 1 else multioutput+create_ctx, si_ctx)
-  # do movement ops
-  sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
-  # convert to AST
-  sink = graph_rewrite(graph_rewrite(sink, to_si+check_preload if len(si_ctx.assigns) != 0 else to_si, si_ctx), append_bufs, si_ctx)
-  # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
-  for ubuf,ops in si_ctx.assign_adj.items():
-    if si_ctx.sinked.get(ubuf) is not None and not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
-        and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in ops):
-      raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
-                         +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+  # remove movement ops + substitute LOAD of fused STORE with just the value
+  sink = graph_rewrite(graph_rewrite(pre, multioutput+view_left, store_bufs:={x.buf_uop:x.src[2] for x in pre.src}), view_right)
+  # remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
+  ast = graph_rewrite(sink, to_si+append_bufs, si_ctx:=ScheduleItemContext(ctx.var_vals))
   # capture process replay
   if CAPTURE_PROCESS_REPLAY:
-    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, si_ctx.assigns, ContextVar._cache, sink))
-  return ScheduleItem(sink, tuple(u.buffer for u in si_ctx.bufs if u.size != 0), tuple(si_ctx.metadata),
-                      tuple(ubuf for ubuf,ops in si_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops)))
+    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, sink))
+  # deal with ASSIGN
+  assign_preloads: list[UOp] = []
+  if len(ctx.assigns) != 0:
+    for x in list(sink.toposort)[::-1]:
+      # we only allow a kernel to depend on either the before-ASSIGN or the after-ASSIGN version of a BUFFER
+      if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
+      # PRELOAD tells the toposort this kernel should run before ASSIGN
+      if x.op is Ops.PRELOAD:
+        assign_preloads.append(x.buf_uop)
+        # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
+        if x.buf_uop in store_bufs and not (st:=x.st_arg).contiguous:
+          # if it has a single view and it's equal when you shrink a contig, it's fine
+          if len(st.views) != 1 or (mask:=st.views[0].mask) is None or ShapeTracker.from_shape(st.shape).shrink(mask) != st.shrink(mask):
+            raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                               +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+  return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs if u.size != 0),
+                      tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)), tuple(dedup(assign_preloads)))
 
 PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
 if CAPTURE_PROCESS_REPLAY:
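
For reference, the per-kernel ASSIGN bookkeeping added in schedule_uop above can be summarized as a standalone sketch. This is not tinygrad code; the Node type and function name are illustrative only:

from typing import NamedTuple

class Node(NamedTuple):
  op: str   # "LOAD", "PRELOAD", "STORE", ...
  buf: int  # id of the buffer this op touches

def collect_assign_preloads(toposorted: list[Node]) -> list[int]:
  # walk the kernel's ops in reverse topological order; once a buffer has been seen as a
  # PRELOAD (the before-ASSIGN value), a plain LOAD of that same buffer would mix the
  # before and after versions of the buffer in one kernel, which is reported as a cycle
  preloads: list[int] = []
  for n in reversed(toposorted):
    if n.op == "LOAD" and n.buf in preloads: raise RuntimeError("cycle detected in graph")
    if n.op == "PRELOAD": preloads.append(n.buf)
  return preloads  # buffers this kernel must read before another kernel assigns to them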