don't graph_rewrite into calls (#14931)

* don't graph_rewrite into calls

* optional

* pm_gate_kernel_sink removed
This commit is contained in:
George Hotz
2026-02-21 15:39:59 +08:00
committed by GitHub
parent c5029fa460
commit 06fb35a1e5
5 changed files with 20 additions and 17 deletions

View File

@@ -2,9 +2,6 @@ from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewr
from tinygrad.dtype import ImageDType
from tinygrad.helpers import prod, DEBUG, argsort
# these are the only uops that can get replaced in the tensor graph
from tinygrad.schedule.rangeify import pm_gate_kernel_sink
def tag_uop(ctx:tuple[list[UOp], set[UOp], dict[UOp, UOp], set[UOp]], x:UOp):
if x.tag is not None or x in ctx[1]: return None
if x.tag is None and x.op is Ops.CALL:
@@ -25,7 +22,7 @@ def apply_after(ctx, u):
ctx[2][u] = u.src[0]
# CONTIGUOUS and ASSIGN + parents are the only nodes that get updated
add_tags = pm_gate_kernel_sink+PatternMatcher([
add_tags = PatternMatcher([
(UPat(Ops.COPY, name="u"), disk_copy_is_buffer),
(UPat(Ops.AFTER, name="u"), apply_after),
(UPat({Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), tag_uop),

View File

@@ -116,7 +116,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
# verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
if SPEC: type_verify(big_sink, tensor_spec)
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm")
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm", rewrite_into_calls=True)
pre_schedule, buf_uops_sink = create_schedule(get_rangeify(big_sink_cache))
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink)
else:

View File

@@ -3,7 +3,7 @@ import functools, itertools
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink, pm_gate_kernel_sink
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
@@ -21,7 +21,7 @@ def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
# you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce
if buf.base in x.backward_slice_with_self: ctx[x] = None
pm_generate_realize_map = pm_gate_kernel_sink+PatternMatcher([
pm_generate_realize_map = PatternMatcher([
# always realize SINK src
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
# always realize

View File

@@ -1,7 +1,7 @@
from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
@@ -76,11 +76,14 @@ mop_cleanup = PatternMatcher([
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"), lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))),
])
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
def resolve_call(c:UOp) -> UOp|None:
# don't resolve real kernel calls, sink or program
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
if c.src[0].op is Ops.PROGRAM: return None
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
params: list[UOp] = []
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params)
params = sorted(params, key=lambda x: x.arg)
args = c.src[1:]
# TODO: this check belongs in spec, not here
if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
@@ -486,9 +489,8 @@ def get_rangeify(sink:UOp) -> UOp:
# bufferize -> store
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
tsink = graph_rewrite(tsink, pm_gate_kernel_sink+pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True,
name="bufferize to store")
tsink = graph_rewrite(tsink, pm_gate_kernel_sink+split_kernels, bottom_up=True, name="split kernels")
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
tsink = graph_rewrite(tsink, split_kernels, bottom_up=True, name="split kernels")
# WAR deps: if kernel U reads buffer S, and S is also written by another kernel, S's write must wait for U to finish
afters = [u for u in tsink.toposort() if u.op is Ops.AFTER]

View File

@@ -8,7 +8,7 @@ from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDT
from tinygrad.dtype import storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CAPTURE_PROCESS_REPLAY
from tinygrad.helpers import strip_parens, colored, ansilen, printable, panic
from tinygrad.helpers import strip_parens, colored, ansilen, printable
if TYPE_CHECKING:
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.renderer import Estimates
@@ -1229,12 +1229,13 @@ if TRACK_MATCH_STATS or PROFILE:
SENTINEL: Final[UOp] = cast(UOp, object())
class BottomUpGate(Exception): pass
class RewriteContext:
def __init__(self, pm, bpm, ctx=None):
def __init__(self, pm, bpm, ctx=None, rewrite_into_calls=False):
self.pm: PatternMatcher|None = pm
self.bpm: PatternMatcher|None = bpm
self.bpm_cache: dict[UOp, UOp|None] = {}
self.ctx = ctx
self.replace: dict[UOp, UOp] = {}
self.rewrite_into_calls = rewrite_into_calls
# no cache needed: pm_rewrite is called at most once per UOp due to the replace dict check in unified_rewrite
def pm_rewrite(self, x:UOp) -> UOp|None: return unwrap(self.pm).rewrite(x, self.ctx)
@@ -1269,6 +1270,10 @@ class RewriteContext:
if n in waitlist: stack.extend(waitlist.pop(n))
continue
stack.append((n, 1, new_n))
# NOTE: CALL is handled as a special case.
# The function that is called is not included in the graph_rewrite.
# If you want to graph_rewrite into a call, pass rewrite_into_calls=True to graph_rewrite.
if new_n.op is Ops.CALL and not self.rewrite_into_calls: self.replace[new_n.src[0]] = new_n.src[0]
for x in reversed(new_n.src):
if x in on_stack: continue
stack.append((x, 0, x))
@@ -1307,8 +1312,8 @@ class RewriteContext:
return self.replace[root]
@profile_matches
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None) -> UOp:
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None, rewrite_into_calls=False) -> UOp:
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx, rewrite_into_calls=rewrite_into_calls)
return rewrite_ctx.unified_rewrite(sink)
# convert a sint to a UOp: a python int becomes a constant of `dtype`, anything else is cast to `dtype`
def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
@@ -1343,7 +1348,6 @@ _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get
_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
# returns False only for a SINK whose arg is a KernelInfo, True for every other UOp
def gate_kernel_sink(x:UOp) -> bool: return not (x.op is Ops.SINK and isinstance(x.arg, KernelInfo))
pm_gate_kernel_sink = PatternMatcher([(UPat(Ops.SINK, name="sink"), lambda sink: None if gate_kernel_sink(sink) else panic(BottomUpGate))])
def do_unbind(ctx:dict[Variable, int], x:UOp):
v,i = x.unbind()