From 06fb35a1e5ddfc3dd09c03951aee547e8a06517f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 21 Feb 2026 15:39:59 +0800 Subject: [PATCH] don't graph_rewrite into calls (#14931) * don't graph_rewrite into calls * optional * pm_gate_kernel_sink removed --- tinygrad/engine/allocations.py | 5 +---- tinygrad/engine/schedule.py | 2 +- tinygrad/schedule/indexing.py | 4 ++-- tinygrad/schedule/rangeify.py | 12 +++++++----- tinygrad/uop/ops.py | 14 +++++++++----- 5 files changed, 20 insertions(+), 17 deletions(-) diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py index c79bf0e70d..f63e88f2b6 100644 --- a/tinygrad/engine/allocations.py +++ b/tinygrad/engine/allocations.py @@ -2,9 +2,6 @@ from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewr from tinygrad.dtype import ImageDType from tinygrad.helpers import prod, DEBUG, argsort -# these are the only uops that can get replaced in the tensor graph -from tinygrad.schedule.rangeify import pm_gate_kernel_sink - def tag_uop(ctx:tuple[list[UOp], set[UOp], dict[UOp, UOp], set[UOp]], x:UOp): if x.tag is not None or x in ctx[1]: return None if x.tag is None and x.op is Ops.CALL: @@ -25,7 +22,7 @@ def apply_after(ctx, u): ctx[2][u] = u.src[0] # CONTIGUOUS and ASSIGN + parents are the only nodes that get updated -add_tags = pm_gate_kernel_sink+PatternMatcher([ +add_tags = PatternMatcher([ (UPat(Ops.COPY, name="u"), disk_copy_is_buffer), (UPat(Ops.AFTER, name="u"), apply_after), (UPat({Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), tag_uop), diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 026de189f5..52c3182699 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -116,7 +116,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None: # verify Tensors match the spec (on big_sink, we only need to do this if cache misses) if SPEC: type_verify(big_sink, tensor_spec) - big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm") + big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm", rewrite_into_calls=True) pre_schedule, buf_uops_sink = create_schedule(get_rangeify(big_sink_cache)) if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink) else: diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 15919caa5d..60ee6882ce 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -3,7 +3,7 @@ import functools, itertools from dataclasses import dataclass, field from tinygrad.dtype import dtypes, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches -from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink, pm_gate_kernel_sink +from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored @@ -21,7 +21,7 @@ def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp): # you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce if buf.base in x.backward_slice_with_self: ctx[x] = None -pm_generate_realize_map = pm_gate_kernel_sink+PatternMatcher([ +pm_generate_realize_map = PatternMatcher([ # always realize SINK src (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)), # always realize diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 68a5d48878..df573f1743 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field, replace import itertools from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace -from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink +from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate from tinygrad.uop.symbolic import symbolic from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS @@ -76,11 +76,14 @@ mop_cleanup = PatternMatcher([ (UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"), lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))), ]) +pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ]) def resolve_call(c:UOp) -> UOp|None: # don't resolve real kernel calls, sink or program if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None if c.src[0].op is Ops.PROGRAM: return None - params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg) + params: list[UOp] = [] + graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params) + params = sorted(params, key=lambda x: x.arg) args = c.src[1:] # TODO: this check belongs in spec, not here if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}") @@ -486,9 +489,8 @@ def get_rangeify(sink:UOp) -> UOp: # bufferize -> store lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1 - tsink = graph_rewrite(tsink, pm_gate_kernel_sink+pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, - name="bufferize to store") - tsink = graph_rewrite(tsink, pm_gate_kernel_sink+split_kernels, bottom_up=True, name="split kernels") + tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store") + tsink = graph_rewrite(tsink, split_kernels, bottom_up=True, name="split kernels") # WAR deps: if kernel U reads buffer S, and S is also written by another kernel, S's write must wait for U to finish afters = [u for u in tsink.toposort() if u.op is Ops.AFTER] diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 00d31e3955..a981a6ea8f 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -8,7 +8,7 @@ from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDT from tinygrad.dtype import storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CAPTURE_PROCESS_REPLAY -from tinygrad.helpers import strip_parens, colored, ansilen, printable, panic +from tinygrad.helpers import strip_parens, colored, ansilen, printable if TYPE_CHECKING: from tinygrad.device import Buffer, MultiBuffer from tinygrad.renderer import Estimates @@ -1229,12 +1229,13 @@ if TRACK_MATCH_STATS or PROFILE: SENTINEL: Final[UOp] = cast(UOp, object()) class BottomUpGate(Exception): pass class RewriteContext: - def __init__(self, pm, bpm, ctx=None): + def __init__(self, pm, bpm, ctx=None, rewrite_into_calls=False): self.pm: PatternMatcher|None = pm self.bpm: PatternMatcher|None = bpm self.bpm_cache: dict[UOp, UOp|None] = {} self.ctx = ctx self.replace: dict[UOp, UOp] = {} + self.rewrite_into_calls = rewrite_into_calls # no cache needed: pm_rewrite is called at most once per UOp due to the replace dict check in unified_rewrite def pm_rewrite(self, x:UOp) -> UOp|None: return unwrap(self.pm).rewrite(x, self.ctx) @@ -1269,6 +1270,10 @@ class RewriteContext: if n in waitlist: stack.extend(waitlist.pop(n)) continue stack.append((n, 1, new_n)) + # NOTE: CALL is handled as a special case. + # The function that is called is not included in the graph_rewrite. + # If you want to graph_rewrite a call, you can + if new_n.op is Ops.CALL and not self.rewrite_into_calls: self.replace[new_n.src[0]] = new_n.src[0] for x in reversed(new_n.src): if x in on_stack: continue stack.append((x, 0, x)) @@ -1307,8 +1312,8 @@ class RewriteContext: return self.replace[root] @profile_matches -def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None) -> UOp: - rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx) +def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None, rewrite_into_calls=False) -> UOp: + rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx, rewrite_into_calls=rewrite_into_calls) return rewrite_ctx.unified_rewrite(sink) def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype) @@ -1343,7 +1348,6 @@ _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get _remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)]) def gate_kernel_sink(x:UOp) -> bool: return not (x.op is Ops.SINK and isinstance(x.arg, KernelInfo)) -pm_gate_kernel_sink = PatternMatcher([(UPat(Ops.SINK, name="sink"), lambda sink: None if gate_kernel_sink(sink) else panic(BottomUpGate))]) def do_unbind(ctx:dict[Variable, int], x:UOp): v,i = x.unbind()