mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
don't graph_rewrite into calls (#14931)
* don't graph_rewrite into calls * optional * pm_gate_kernel_sink removed
This commit is contained in:
@@ -2,9 +2,6 @@ from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewr
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.helpers import prod, DEBUG, argsort
|
||||
|
||||
# these are the only uops that can get replaced in the tensor graph
|
||||
from tinygrad.schedule.rangeify import pm_gate_kernel_sink
|
||||
|
||||
def tag_uop(ctx:tuple[list[UOp], set[UOp], dict[UOp, UOp], set[UOp]], x:UOp):
|
||||
if x.tag is not None or x in ctx[1]: return None
|
||||
if x.tag is None and x.op is Ops.CALL:
|
||||
@@ -25,7 +22,7 @@ def apply_after(ctx, u):
|
||||
ctx[2][u] = u.src[0]
|
||||
|
||||
# CONTIGUOUS and ASSIGN + parents are the only nodes that get updated
|
||||
add_tags = pm_gate_kernel_sink+PatternMatcher([
|
||||
add_tags = PatternMatcher([
|
||||
(UPat(Ops.COPY, name="u"), disk_copy_is_buffer),
|
||||
(UPat(Ops.AFTER, name="u"), apply_after),
|
||||
(UPat({Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), tag_uop),
|
||||
|
||||
@@ -116,7 +116,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
|
||||
# verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
|
||||
if SPEC: type_verify(big_sink, tensor_spec)
|
||||
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm")
|
||||
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm", rewrite_into_calls=True)
|
||||
pre_schedule, buf_uops_sink = create_schedule(get_rangeify(big_sink_cache))
|
||||
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink)
|
||||
else:
|
||||
|
||||
@@ -3,7 +3,7 @@ import functools, itertools
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches
|
||||
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink, pm_gate_kernel_sink
|
||||
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
|
||||
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
||||
|
||||
@@ -21,7 +21,7 @@ def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
|
||||
# you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce
|
||||
if buf.base in x.backward_slice_with_self: ctx[x] = None
|
||||
|
||||
pm_generate_realize_map = pm_gate_kernel_sink+PatternMatcher([
|
||||
pm_generate_realize_map = PatternMatcher([
|
||||
# always realize SINK src
|
||||
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
|
||||
# always realize
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass, field, replace
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
|
||||
@@ -76,11 +76,14 @@ mop_cleanup = PatternMatcher([
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"), lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))),
|
||||
])
|
||||
|
||||
# collects PARAM UOps: every PARAM that matches is appended to ctx (resolve_call passes a list as ctx)
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
|
||||
def resolve_call(c:UOp) -> UOp|None:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
|
||||
if c.src[0].op is Ops.PROGRAM: return None
|
||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||
params: list[UOp] = []
|
||||
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params)
|
||||
params = sorted(params, key=lambda x: x.arg)
|
||||
args = c.src[1:]
|
||||
# TODO: this check belongs in spec, not here
|
||||
if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
|
||||
@@ -486,9 +489,8 @@ def get_rangeify(sink:UOp) -> UOp:
|
||||
|
||||
# bufferize -> store
|
||||
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
|
||||
tsink = graph_rewrite(tsink, pm_gate_kernel_sink+pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True,
|
||||
name="bufferize to store")
|
||||
tsink = graph_rewrite(tsink, pm_gate_kernel_sink+split_kernels, bottom_up=True, name="split kernels")
|
||||
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
|
||||
tsink = graph_rewrite(tsink, split_kernels, bottom_up=True, name="split kernels")
|
||||
|
||||
# WAR deps: if kernel U reads buffer S, and S is also written by another kernel, S's write must wait for U to finish
|
||||
afters = [u for u in tsink.toposort() if u.op is Ops.AFTER]
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDT
|
||||
from tinygrad.dtype import storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
|
||||
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CAPTURE_PROCESS_REPLAY
|
||||
from tinygrad.helpers import strip_parens, colored, ansilen, printable, panic
|
||||
from tinygrad.helpers import strip_parens, colored, ansilen, printable
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.renderer import Estimates
|
||||
@@ -1229,12 +1229,13 @@ if TRACK_MATCH_STATS or PROFILE:
|
||||
SENTINEL: Final[UOp] = cast(UOp, object())
|
||||
class BottomUpGate(Exception): pass
|
||||
class RewriteContext:
|
||||
def __init__(self, pm, bpm, ctx=None):
|
||||
def __init__(self, pm, bpm, ctx=None, rewrite_into_calls=False):
|
||||
self.pm: PatternMatcher|None = pm
|
||||
self.bpm: PatternMatcher|None = bpm
|
||||
self.bpm_cache: dict[UOp, UOp|None] = {}
|
||||
self.ctx = ctx
|
||||
self.replace: dict[UOp, UOp] = {}
|
||||
self.rewrite_into_calls = rewrite_into_calls
|
||||
|
||||
# no cache needed: pm_rewrite is called at most once per UOp due to the replace dict check in unified_rewrite
|
||||
def pm_rewrite(self, x:UOp) -> UOp|None: return unwrap(self.pm).rewrite(x, self.ctx)
|
||||
@@ -1269,6 +1270,10 @@ class RewriteContext:
|
||||
if n in waitlist: stack.extend(waitlist.pop(n))
|
||||
continue
|
||||
stack.append((n, 1, new_n))
|
||||
# NOTE: CALL is handled as a special case.
|
||||
# The function that is called is not included in the graph_rewrite.
|
||||
# If you want to graph_rewrite into a call, pass rewrite_into_calls=True.
|
||||
if new_n.op is Ops.CALL and not self.rewrite_into_calls: self.replace[new_n.src[0]] = new_n.src[0]
|
||||
for x in reversed(new_n.src):
|
||||
if x in on_stack: continue
|
||||
stack.append((x, 0, x))
|
||||
@@ -1307,8 +1312,8 @@ class RewriteContext:
|
||||
return self.replace[root]
|
||||
|
||||
@profile_matches
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None) -> UOp:
|
||||
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None, rewrite_into_calls=False) -> UOp:
|
||||
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx, rewrite_into_calls=rewrite_into_calls)
|
||||
return rewrite_ctx.unified_rewrite(sink)
|
||||
|
||||
def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
|
||||
@@ -1343,7 +1348,6 @@ _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get
|
||||
_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
||||
def gate_kernel_sink(x:UOp) -> bool:
  """Return False only when `x` is a SINK whose arg is a KernelInfo; return True for every other UOp."""
  # anything that is not a SINK passes the gate without looking at its arg
  if x.op is not Ops.SINK: return True
  return not isinstance(x.arg, KernelInfo)
|
||||
pm_gate_kernel_sink = PatternMatcher([(UPat(Ops.SINK, name="sink"), lambda sink: None if gate_kernel_sink(sink) else panic(BottomUpGate))])
|
||||
|
||||
def do_unbind(ctx:dict[Variable, int], x:UOp):
|
||||
v,i = x.unbind()
|
||||
|
||||
Reference in New Issue
Block a user