From d24466c84425becfae700e39206a60cfc52ce5ea Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 16 Apr 2026 13:25:07 +0800 Subject: [PATCH] CALL with return value is FUNCTION (#15758) * CALL with return value is FUNCTION (GPT try) * cleanups --- test/unit/test_call.py | 6 ++--- tinygrad/callify.py | 13 +++++---- tinygrad/gradient.py | 11 ++++---- tinygrad/schedule/indexing.py | 4 +-- tinygrad/schedule/multi.py | 23 ++++++++-------- tinygrad/schedule/rangeify.py | 10 +++---- tinygrad/uop/__init__.py | 3 ++- tinygrad/uop/ops.py | 51 ++++++++++++++++------------------- tinygrad/uop/spec.py | 12 +++++---- tinygrad/uop/symbolic.py | 2 +- tinygrad/viz/js/worker.js | 2 +- tinygrad/viz/serve.py | 9 ++++--- 12 files changed, 72 insertions(+), 74 deletions(-) diff --git a/test/unit/test_call.py b/test/unit/test_call.py index a889be4a4e..89164c3424 100644 --- a/test/unit/test_call.py +++ b/test/unit/test_call.py @@ -220,9 +220,9 @@ class TestCallSchedule(unittest.TestCase): a = Tensor.empty(4, 8) b = Tensor.empty(4, 8) r0, r1 = f(a), f(b) - # find the CALL nodes - c0 = next(u for u in r0.uop.toposort() if u.op is Ops.CALL) - c1 = next(u for u in r1.uop.toposort() if u.op is Ops.CALL) + # find the FUNCTION nodes + c0 = next(u for u in r0.uop.toposort() if u.op is Ops.FUNCTION) + c1 = next(u for u in r1.uop.toposort() if u.op is Ops.FUNCTION) # the function bodies (src[0]) should have identical keys — unique consts must not leak through self.assertEqual(c0.src[0].key, c1.src[0].key) diff --git a/tinygrad/callify.py b/tinygrad/callify.py index f0231ba4d0..6b34ae4674 100644 --- a/tinygrad/callify.py +++ b/tinygrad/callify.py @@ -96,8 +96,7 @@ def contiguous_mops_to_view(c:UOp, src:UOp): def transform_precompiled_call(c:UOp) -> UOp|None: if not c.arg.precompile: return None - if c.src[0].op is Ops.SINK: return None - assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled call, got {c.src[0].op}" + assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled FUNCTION, got {c.src[0].op}" input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:]) # add the outputs to the call @@ -107,20 +106,20 @@ def transform_precompiled_call(c:UOp) -> UOp|None: targets = [o.param_like(len(c.src)-1+i).shrink_to(s.shape) for i,(o,s) in enumerate(zip(outs, srcs))] fxn = UOp.sink(*[t.after(t.store(s)) for t,s in zip(targets, srcs)]) - # create the new thing for the big graph - new_call = c.replace(src=(fxn, *input_buffers, *outs), tag=None) + # body switches from TUPLE to SINK, so the node becomes an opaque CALL (not FUNCTION) + new_call = UOp(Ops.CALL, c.dtype, (fxn, *input_buffers, *outs), c.arg) rets = tuple(o.after(new_call) for o in outs) # if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape - # NOTE: must use resolved shapes from the CALL (which substitutes PARAMs with external args), not raw body shapes + # NOTE: must use resolved shapes from the FUNCTION (which substitutes PARAMs with external args), not raw body shapes rets = tuple(r.shrink_to(rs.shape) for r,rs in zip(rets, resolved)) return UOp.maketuple(*rets) # NOTE: adding rules to here is bad. these all need to run before the schedule cache pm_early_transform_tensor_graph = PatternMatcher([ - # transform precompiled CALLs - (UPat(Ops.CALL, name="c"), transform_precompiled_call), + # transform precompiled FUNCTIONs into CALLs (body becomes SINK with stores) + (UPat(Ops.FUNCTION, name="c"), transform_precompiled_call), # resolve TUPLE+GETTUPLE (for precompiled calls) (UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]), diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index f505697a93..9de7db1217 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -93,17 +93,18 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp grads: dict[UOp, UOp] = {root: root_grad} for t0 in reversed(walk): if t0 not in grads or grads[t0].op is Ops.NOOP: continue - # GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL + # GETTUPLE: accumulate gradient into a TUPLE UOp on the FUNCTION, process when we hit the FUNCTION if t0.op is Ops.GETTUPLE: - k = t0.src[0] # the CALL - assert k.op is Ops.CALL and k.src[0].op is Ops.TUPLE + k = t0.src[0] # the FUNCTION + assert k.op is Ops.FUNCTION and k.src[0].op is Ops.TUPLE n_outputs = len(k.src[0].src) prev = grads[k].src if k in grads else tuple(UOp(Ops.NOOP) for _ in range(n_outputs)) grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg and prev[i].op is not Ops.NOOP else grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs))) continue - # CALL: pass needed param set so backward only computes required gradients - if t0.op is Ops.CALL: + # FUNCTION/CALL: pass needed param set so backward only computes required gradients + # (FUNCTION uses implicit TUPLE gradient or grad_fxn; CALL requires an explicit grad_fxn) + if t0.op in {Ops.FUNCTION, Ops.CALL}: needed = {i for i, arg in enumerate(t0.src[1:]) if arg in targets or in_target_path.get(arg, False)} lgrads:tuple[UOp|None, ...]|None = call_gradient(grads[t0], t0, needed) else: diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 7d5635de37..25aa9551d2 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -9,7 +9,7 @@ from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM, - Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL} + Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.FUNCTION} def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None @@ -164,7 +164,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue # no ranges on kernels, they are internal - if x.op in {Ops.CALL, Ops.LINEAR}: continue + if x.op in {Ops.CALL, Ops.FUNCTION, Ops.LINEAR}: continue # only STORE+AFTER has range if x.op is Ops.AFTER and all(s.op is not Ops.STORE for s in x.src[1:]): continue diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 00e502ea76..fde893bfe7 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -1,5 +1,5 @@ from tinygrad.helpers import all_same, prod, getenv, ALLREDUCE_CAST -from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite, should_resolve_call +from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite from tinygrad.dtype import dtypes from tinygrad.schedule.allreduce import handle_allreduce @@ -116,11 +116,11 @@ def store_after_multi(dest:UOp, src:UOp): return dest.after(dest.store(src.src[0 def passthrough_multi(root:UOp, multi:UOp): return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis) -def rewrite_into_call(call:UOp): - if not should_resolve_call(call): return None +def rewrite_into_function(call:UOp): + if call.arg.precompile: return None new_body = graph_rewrite(call.src[0], multi_pm, name="subcall") new_args = tuple(a.src[0] if a.op is Ops.MULTI else a for a in call.src[1:]) - # after multi resolution, TUPLE elements may be MULTI — strip MULTI from body, create per-shard CALL, wrap each GETTUPLE in its own MULTI + # after multi resolution, TUPLE elements may be MULTI — strip MULTI from body, create per-shard FUNCTION, wrap each GETTUPLE in its own MULTI assert new_body.op is Ops.TUPLE if any(s.op is Ops.MULTI for s in new_body.src): shard_call = call.replace(src=(UOp.maketuple(*[s.src[0] if s.op is Ops.MULTI else s for s in new_body.src]),)+new_args) @@ -149,17 +149,16 @@ multi_pm = PatternMatcher([ # resolve TUPLE+GETTUPLE (needed in multi) (UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]), - # GETTUPLE on MULTI: passthrough MULTI (e.g. when CALL was replaced by MULTI(GETTUPLE(...))) + # GETTUPLE on MULTI: passthrough MULTI (e.g. when FUNCTION was replaced by MULTI(GETTUPLE(...))) (UPat(Ops.GETTUPLE, src=(UPat(Ops.MULTI, name="multi"),), name="g"), - lambda g, multi: multi.src[0].gettuple(g.arg).multi(multi.axis) if multi.src[0].op in {Ops.CALL, Ops.TUPLE} + lambda g, multi: multi.src[0].gettuple(g.arg).multi(multi.axis) if multi.src[0].op in {Ops.FUNCTION, Ops.TUPLE} else multi), - # rewrite into calls explicitly for MULTI - (UPat(Ops.CALL, name="call"), rewrite_into_call), - (UPat((Ops.CALL, Ops.AFTER), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi), - # we just remove the MULTI from non-value-producing CALLs (custom kernels, etc.) — TUPLE body CALLs are handled by rewrite_into_call + # rewrite into FUNCTION calls explicitly for MULTI (value-producing) + (UPat(Ops.FUNCTION, name="call"), rewrite_into_function), + (UPat((Ops.CALL, Ops.FUNCTION, Ops.AFTER), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi), + # just strip the MULTI from non-value-producing CALLs (custom kernels, etc.) — FUNCTION is handled by rewrite_into_function (UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root: - UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg) - if root.src[0].op is not Ops.TUPLE else None), + UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)), (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi), # remove MULTI from STORE diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 92afbf8aa6..84af8c1d0d 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace import itertools from tinygrad.dtype import dtypes, PtrDType, AddrSpace, Invalid from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo -from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call, identity_element +from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, identity_element from tinygrad.uop.symbolic import symbolic from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element @@ -126,8 +126,8 @@ mop_cleanup = PatternMatcher([ ]) pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ]) -def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None: - if not should_resolve_call(c): return None +def resolve_function(c:UOp, allow_param_mismatch=True) -> UOp|None: + if c.arg.precompile: return None params: list[UOp] = [] graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params, name="gather params") params = sorted(params, key=lambda x: x.arg) @@ -150,8 +150,8 @@ earliest_rewrites = mop_cleanup+PatternMatcher([ (UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))), lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None), - # resolve calls - (UPat(Ops.CALL, name="c"), resolve_call), + # resolve FUNCTION calls (inline the body) + (UPat(Ops.FUNCTION, name="c"), resolve_function), # resolve TUPLE+GETTUPLE (UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]), diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index aaabfe3756..bbc7dc0992 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -26,7 +26,8 @@ class Ops(FastEnum): # uops that aren't rendered NOOP = auto(); REWRITE_ERROR = auto() - PARAM = auto(); CALL = auto() + # FUNCTION has a TUPLE body and is gradient-able; CALL is an opaque kernel invocation + PARAM = auto(); FUNCTION = auto(); CALL = auto() # renderer # LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 978f55cbef..dbb0f1657a 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -26,7 +26,8 @@ axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3, AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5} -range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.COPY: 2, Ops.BUFFER_VIEW: 1} +range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.FUNCTION: 1, + Ops.COPY: 2, Ops.BUFFER_VIEW: 1} # https://en.wikipedia.org/wiki/Identity_element def identity_element(op:Ops, dt:DType) -> PyConst: return dt.const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op]) @@ -178,7 +179,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): if not visited: if gate is None or gate(node): stack.append((node, True)) # push node back on stack to process after its srcs - for s in reversed(node.src if enter_calls or node.op is not Ops.CALL else node.src[1:]): + for s in reversed(node.src if enter_calls or node.op not in {Ops.CALL, Ops.FUNCTION} else node.src[1:]): stack.append((s, False)) # push srcs on the stack else: cache[node] = None # second time i'm seeing this node, add it to returned toposort return cache @@ -215,17 +216,17 @@ class UOp(OpMixin, metaclass=UOpMetaClass): # late ops don't have shape case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \ - Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE: + Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION: return None case Ops.GETTUPLE: - # GETTUPLE extracts from a TUPLE (possibly through a CALL) - in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0] + # GETTUPLE extracts from a TUPLE (possibly through a FUNCTION) + in_tuple = self.src[0].src[0] if self.src[0].op is Ops.FUNCTION else self.src[0] assert in_tuple.op is Ops.TUPLE inner_shape = in_tuple.src[self.arg]._shape if inner_shape is None: return None - # if through a CALL, substitute internal PARAMs in the shape with corresponding args - if self.src[0].op is Ops.CALL: + # if through a FUNCTION, substitute internal PARAMs in the shape with corresponding args + if self.src[0].op is Ops.FUNCTION: return tuple(graph_rewrite(s, _pm_resolve_params, self.src[0].src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape) return inner_shape @@ -262,8 +263,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass): case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END: return self.src[0]._shape - case Ops.CALL: return None - # TODO: disallow shape changing bitcast case Ops.BITCAST: ps = self.src[0]._shape @@ -418,8 +417,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass): def maketuple(*srcs:UOp): # pylint: disable=no-self-argument return UOp(Ops.TUPLE, dtypes.void, srcs) def gettuple(self, idx:int) -> UOp: - in_tuple = self.src[0] if self.op is Ops.CALL else self - assert in_tuple.op is Ops.TUPLE, f"gettuple requires CALL or TUPLE source, got {self.op}" + in_tuple = self.src[0] if self.op is Ops.FUNCTION else self + assert in_tuple.op is Ops.TUPLE, f"gettuple requires FUNCTION or TUPLE source, got {self.op}" return UOp(Ops.GETTUPLE, in_tuple.src[idx].dtype, (self,), idx) def group(*srcs:UOp|None): # pylint: disable=no-self-argument if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0] @@ -554,7 +553,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): if self.op is Ops.MULTI: return self.arg # GETTUPLE: axis comes from the specific TUPLE element, not src[0] if self.op is Ops.GETTUPLE: - in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0] + in_tuple = self.src[0].src[0] if self.src[0].op is Ops.FUNCTION else self.src[0] return in_tuple.src[self.arg].axis if in_tuple.op is Ops.TUPLE else None # PARAM: axis is stored as a MULTI source if self.op is Ops.PARAM: @@ -931,13 +930,16 @@ class UOp(OpMixin, metaclass=UOpMetaClass): if self.axis is not None: p = p.replace(src=p.src + (UOp(Ops.MULTI, arg=self.axis),)) return p - _NO_TUPLE_WRAP = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.BUFFER_VIEW, Ops.CUSTOM_FUNCTION, Ops.TUPLE} + # opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE) + _OPAQUE_CALL_BODIES = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.BUFFER_VIEW, Ops.CUSTOM_FUNCTION} def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None, precompile:bool=False, precompile_backward:bool=False) -> UOp: assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}" - # value-producing bodies are always wrapped in TUPLE so CALL dtype is always void - body = self if self.op in UOp._NO_TUPLE_WRAP else UOp.maketuple(self) - return UOp(Ops.CALL, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward)) + if self.op in UOp._OPAQUE_CALL_BODIES: + return UOp(Ops.CALL, dtypes.void, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward)) + # value-producing bodies are always wrapped in TUPLE so FUNCTION dtype is always void + body = self if self.op is Ops.TUPLE else UOp.maketuple(self) + return UOp(Ops.FUNCTION, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward)) def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]: contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs) placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)] @@ -968,13 +970,6 @@ class CallInfo: gf = id(self.grad_fxn) if self.grad_fxn else None return f"CallInfo({gf}, {self.metadata}, {repr(self.name)}, {self.precompile}, {self.precompile_backward})" -def should_resolve_call(c:UOp) -> bool: - # don't resolve real kernel calls, sink or program - if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return False - if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.CUSTOM_FUNCTION}: return False - if c.arg.precompile: return False - return True - # ******** ops in python ******** def safe_exp2(x): @@ -1357,7 +1352,7 @@ class RewriteContext: continue # no rewrite, process children then come back to rebuild stack.append((n, True)) - if not self.enter_calls and n.op is Ops.CALL: self.replace[n.src[0]] = n.src[0] + if not self.enter_calls and n.op in {Ops.CALL, Ops.FUNCTION}: self.replace[n.src[0]] = n.src[0] for x in reversed(n.src): if x not in self.replace: stack.append((x, False)) else: @@ -1394,10 +1389,10 @@ class RewriteContext: if n in waitlist: stack.extend(waitlist.pop(n)) continue stack.append((n, 1, new_n)) - # NOTE: CALL is handled as a special case. + # NOTE: CALL/FUNCTION are handled as a special case. # The function that is called is not included in the graph_rewrite. # If you want to graph_rewrite a call, you can - if not self.enter_calls and new_n.op is Ops.CALL: self.replace[new_n.src[0]] = new_n.src[0] + if not self.enter_calls and new_n.op in {Ops.CALL, Ops.FUNCTION}: self.replace[new_n.src[0]] = new_n.src[0] for x in reversed(new_n.src): if x in on_stack: continue stack.append((x, 0, x)) @@ -1618,7 +1613,7 @@ def pyrender(ast:UOp) -> str: cmap = consumer_map_from_toposort(lst) not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE} always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE, - Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END} + Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.FUNCTION, Ops.WHERE, Ops.END} to_render: set[UOp] = {ast} for u in lst: @@ -1626,7 +1621,7 @@ def pyrender(ast:UOp) -> str: for s in u.src: to_render.add(s) if u.op is Ops.STORE: to_render.add(u.src[1]) if u.op in {Ops.REDUCE, Ops.REDUCE_AXIS}: to_render.add(u.src[0]) - if u.op is Ops.CALL: raise NotImplementedError("call can't be pyrendered") + if u.op in {Ops.CALL, Ops.FUNCTION}: raise NotImplementedError("call can't be pyrendered") if u.op in not_rendered: continue # checking the consumers is not enough, you have to make sure it's not used twice by the one consumer if len(cmap[u]) == 1 and len([x for x in list(cmap[u].keys())[0].src if x is u]) == 1 and u.op not in always_rendered: continue diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index b9450f0f18..a29c97efc1 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -95,7 +95,7 @@ _tensor_spec = PatternMatcher([ (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),)), lambda: True), # KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER - (UPat(Ops.CALL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True), + (UPat((Ops.CALL, Ops.FUNCTION), src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True), # MSELECT chooses one of the multi buffers (UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)), @@ -132,14 +132,16 @@ _tensor_spec = PatternMatcher([ # AFTER if things were kernelized (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True), - # allow CALL/PARAM/CUSTOM_FUNCTION — CALL dtype is always void + # allow CALL/FUNCTION/PARAM/CUSTOM_FUNCTION — both CALL and FUNCTION dtype is always void + # FUNCTION must have a TUPLE body in src[0] (invariant enforced by UOp.call); CALL bodies are opaque (UPat(Ops.CALL, dtypes.void), lambda: True), + (UPat(Ops.FUNCTION, dtypes.void, src=(UPat(Ops.TUPLE),), allow_any_len=True), lambda: True), (UPat(Ops.PARAM), lambda: True), (UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)), - # TUPLE must have void dtype, GETTUPLE can only appear on CALL or TUPLE + # TUPLE must have void dtype, GETTUPLE can only appear on FUNCTION or TUPLE (UPat(Ops.TUPLE, dtypes.void), lambda: True), - (UPat(Ops.GETTUPLE, src=(UPat((Ops.CALL, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)), + (UPat(Ops.GETTUPLE, src=(UPat((Ops.FUNCTION, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)), # ** for custom kernels ** @@ -274,7 +276,7 @@ full_spec = PatternMatcher([ (UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True), # linearizer: outputs + intermediate KERNELs - (UPat(Ops.CALL, dtype=dtypes.void), lambda: True), + (UPat((Ops.CALL, Ops.FUNCTION), dtype=dtypes.void), lambda: True), # where on index in rhs position is fine (UPat(Ops.WHERE, dtype=dtypes.weakint, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.weakint))), lambda: True), diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index c276b424b5..9f391ae9b5 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -285,7 +285,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ ((UPat.var("x", dtypes.weakint) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)), # only RANGE/IF/STORE/KERNEL have side effects (UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+ - tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR, Ops.BUFFERIZE} + tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.FUNCTION, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR, Ops.BUFFERIZE} else y.src for y in x.src[1:]])))), # after with 1 src is just src[0] (UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s), diff --git a/tinygrad/viz/js/worker.js b/tinygrad/viz/js/worker.js index 9e838dbf7b..645cca58b1 100644 --- a/tinygrad/viz/js/worker.js +++ b/tinygrad/viz/js/worker.js @@ -54,7 +54,7 @@ const layoutUOp = (g, { graph, change }, opts) => { width = Math.max(width, ctx.measureText(line).width); height += lineHeight; } - const callNode = label.startsWith("CALL\n"); + const callNode = label.startsWith("CALL\n") || label.startsWith("FUNCTION\n"); if (callNode) callCount++; g.setNode(k, {...rectDims(width, height), label, ref, id:k, color, tag, callNode}); // add edges diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index a1c4098a83..81d265cd8d 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -50,7 +50,8 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6", - Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040", + Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040", + Ops.LINEAR: "#7DF4FF", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"} @@ -136,7 +137,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]: label += f"\n({multirange_str(rngs, color=True)})" if u._shape is not None: label += f"\n{shape_to_str(u.shape)}" - if u.op is Ops.CALL: + if u.op in {Ops.CALL, Ops.FUNCTION}: label += f"\n{u.src[0].key.hex()[:8]}" if u.op in {Ops.INDEX, Ops.BUFFERIZE}: if len(u.toposort()) < 30: label += f"\n{u.render()}" @@ -147,9 +148,9 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]: label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs]) except Exception: label += "\n" - if (ref:=data.ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{data.ctxs[ref]['name']}" + if (ref:=data.ref_map.get(u.src[0]) if u.op in {Ops.CALL, Ops.FUNCTION} else None) is not None: label += f"\ncodegen@{data.ctxs[ref]['name']}" # NOTE: kernel already has metadata in arg - if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata) + if TRACEMETA >= 2 and u.metadata is not None and u.op not in {Ops.CALL, Ops.FUNCTION}: label += "\n"+str(u.metadata) # limit SOURCE labels line count if u.op is Ops.SOURCE and len(lines:=label.split("\n")) > 40: label = "\n".join(lines[:30]) + "\n..."