mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
scheduling graph_rewrite prereqs for BLOCK in ASSIGN (#8598)
* remove the BUF_LIMIT assert * skip the base one * work * work * good error * ok comment * shorter check
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# compare kernels created by HEAD against master
|
||||
from collections import defaultdict
|
||||
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
|
||||
from typing import Callable, List, Set, Tuple, Union, cast
|
||||
from typing import Callable, List, Tuple, Union, cast
|
||||
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
|
||||
from tinygrad.engine.schedule import ScheduleContext, schedule_uop
|
||||
from tinygrad.codegen.kernel import Kernel, Opt
|
||||
@@ -30,9 +30,9 @@ class ProcessReplayWarning(Warning): pass
|
||||
|
||||
# *** recreators
|
||||
|
||||
def recreate_sched(ast:UOp, assigns:Set[UOp]) -> UOp:
|
||||
def recreate_sched(ast:UOp) -> UOp:
|
||||
# NOTE: process replay isn't meant to actually schedule anything
|
||||
return schedule_uop(ast, ScheduleContext(assigns=assigns, tensor_uops=defaultdict(list))).ast
|
||||
return schedule_uop(ast, ScheduleContext(tensor_uops=defaultdict(list))).ast
|
||||
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str) -> str:
|
||||
k = Kernel(ast, opts=opts)
|
||||
for opt in applied_opts: k.apply_opt(opt)
|
||||
|
||||
@@ -183,14 +183,9 @@ view_right = merge_views+PatternMatcher([
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItemContext:
|
||||
ops_metadata: dict[UOp, Metadata]
|
||||
assigns: set[UOp]
|
||||
var_vals: dict[Variable, int]
|
||||
sinked: dict[UOp, UOp]
|
||||
sts: set[ShapeTracker] = field(default_factory=set)
|
||||
bufs: list[UOp] = field(default_factory=list)
|
||||
metadata: set[Metadata] = field(default_factory=set)
|
||||
assign_adj: dict[UOp, list[UOp]] = field(default_factory=dict)
|
||||
|
||||
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
|
||||
if (st:=unwrap(x.st)) in ctx.sts: return None
|
||||
@@ -204,47 +199,47 @@ def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
|
||||
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
|
||||
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
|
||||
|
||||
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
|
||||
(adj_loads:=ctx.assign_adj.setdefault(b, [])).append(x)
|
||||
if not all_same([x.op for x in adj_loads]): raise RuntimeError(f"Detected cycle when fusing {adj_loads}. Can only fuse PRELOAD or LOAD of {b}")
|
||||
return x.replace(op=Ops.LOAD)
|
||||
check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),])
|
||||
|
||||
to_si = PatternMatcher([
|
||||
(UPat(Ops.VIEW, name="x"), _append_st_vars),
|
||||
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
|
||||
# don't need contiguous or assign anymore
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
|
||||
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
|
||||
# PRELOAD becomes LOAD
|
||||
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
|
||||
])
|
||||
|
||||
add_metadata = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: None if (m:=ctx.ops_metadata.get(x)) is None else ctx.metadata.add(m)),])
|
||||
add_assign_adjacents = PatternMatcher([(UPat.load(UPat.var("b"), UPat(), name="x"), lambda ctx,b,x: ctx.assign_adj.setdefault(b, []).append(x)
|
||||
if b in ctx.assigns else None)])
|
||||
|
||||
# late folding for multi output kernels
|
||||
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.sinked.get(b)),])
|
||||
# LOAD(BUFFER) -> the STORE value if we're doing the STORE in the same kernel
|
||||
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
|
||||
|
||||
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
|
||||
# create the ast context
|
||||
si_ctx = ScheduleItemContext(ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src})
|
||||
create_ctx = add_metadata if len(si_ctx.assigns) == 0 else add_metadata+add_assign_adjacents
|
||||
sink = graph_rewrite(pre, create_ctx if len(si_ctx.sinked) == 1 else multioutput+create_ctx, si_ctx)
|
||||
# do movement ops
|
||||
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
|
||||
# convert to AST
|
||||
sink = graph_rewrite(graph_rewrite(sink, to_si+check_preload if len(si_ctx.assigns) != 0 else to_si, si_ctx), append_bufs, si_ctx)
|
||||
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
for ubuf,ops in si_ctx.assign_adj.items():
|
||||
if si_ctx.sinked.get(ubuf) is not None and not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
|
||||
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in ops):
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
# remove movement ops + substitute LOAD of fused STORE with just the value
|
||||
sink = graph_rewrite(graph_rewrite(pre, multioutput+view_left, store_bufs:={x.buf_uop:x.src[2] for x in pre.src}), view_right)
|
||||
# remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
|
||||
ast = graph_rewrite(sink, to_si+append_bufs, si_ctx:=ScheduleItemContext(ctx.var_vals))
|
||||
# capture process replay
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, si_ctx.assigns, ContextVar._cache, sink))
|
||||
return ScheduleItem(sink, tuple(u.buffer for u in si_ctx.bufs if u.size != 0), tuple(si_ctx.metadata),
|
||||
tuple(ubuf for ubuf,ops in si_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops)))
|
||||
with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, sink))
|
||||
# deal with ASSIGN
|
||||
assign_preloads: list[UOp] = []
|
||||
if len(ctx.assigns) != 0:
|
||||
for x in list(sink.toposort)[::-1]:
|
||||
# we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
|
||||
if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
|
||||
# PRELOAD tells the toposort this kernel should run before ASSIGN
|
||||
if x.op is Ops.PRELOAD:
|
||||
assign_preloads.append(x.buf_uop)
|
||||
# if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
|
||||
if x.buf_uop in store_bufs and not (st:=x.st_arg).contiguous:
|
||||
# if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
if len(st.views) != 1 or (mask:=st.views[0].mask) is None or ShapeTracker.from_shape(st.shape).shrink(mask) != st.shrink(mask):
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs if u.size != 0),
|
||||
tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)), tuple(dedup(assign_preloads)))
|
||||
|
||||
PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
|
||||
Reference in New Issue
Block a user