CALL with return value is FUNCTION (#15758)

* CALL with return value is FUNCTION (GPT try)

* cleanups
George Hotz
2026-04-16 13:25:07 +08:00
committed by GitHub
parent 218d6b8988
commit d24466c844
12 changed files with 72 additions and 74 deletions

View File

@@ -220,9 +220,9 @@ class TestCallSchedule(unittest.TestCase):
a = Tensor.empty(4, 8)
b = Tensor.empty(4, 8)
r0, r1 = f(a), f(b)
-# find the CALL nodes
-c0 = next(u for u in r0.uop.toposort() if u.op is Ops.CALL)
-c1 = next(u for u in r1.uop.toposort() if u.op is Ops.CALL)
+# find the FUNCTION nodes
+c0 = next(u for u in r0.uop.toposort() if u.op is Ops.FUNCTION)
+c1 = next(u for u in r1.uop.toposort() if u.op is Ops.FUNCTION)
# the function bodies (src[0]) should have identical keys — unique consts must not leak through
self.assertEqual(c0.src[0].key, c1.src[0].key)
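
Why the test compares src[0].key: the schedule cache can only reuse a traced function when two structurally identical bodies hash to the same key, so per-call unique ids must never be captured inside the body. A self-contained toy of the idea (hypothetical Node/key stand-ins, not tinygrad's real UOp machinery):

import hashlib
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str
  src: tuple["Node", ...] = ()
  arg: object = None

def key(n: Node) -> bytes:
  # hash op/arg plus child keys: identical structure -> identical key
  h = hashlib.sha256(repr((n.op, n.arg)).encode())
  for s in n.src: h.update(key(s))
  return h.digest()

body_a = Node("TUPLE", (Node("ADD", (Node("PARAM", arg=0), Node("CONST", arg=1.0))),))
body_b = Node("TUPLE", (Node("ADD", (Node("PARAM", arg=0), Node("CONST", arg=1.0))),))
assert key(body_a) == key(body_b)  # dedupes; a leaked per-call unique id would break this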

View File

@@ -96,8 +96,7 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
def transform_precompiled_call(c:UOp) -> UOp|None:
if not c.arg.precompile: return None
if c.src[0].op is Ops.SINK: return None
-assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled call, got {c.src[0].op}"
+assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled FUNCTION, got {c.src[0].op}"
input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:])
# add the outputs to the call
@@ -107,20 +106,20 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
targets = [o.param_like(len(c.src)-1+i).shrink_to(s.shape) for i,(o,s) in enumerate(zip(outs, srcs))]
fxn = UOp.sink(*[t.after(t.store(s)) for t,s in zip(targets, srcs)])
# create the new thing for the big graph
-new_call = c.replace(src=(fxn, *input_buffers, *outs), tag=None)
+# body switches from TUPLE to SINK, so the node becomes an opaque CALL (not FUNCTION)
+new_call = UOp(Ops.CALL, c.dtype, (fxn, *input_buffers, *outs), c.arg)
rets = tuple(o.after(new_call) for o in outs)
# if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape
-# NOTE: must use resolved shapes from the CALL (which substitutes PARAMs with external args), not raw body shapes
+# NOTE: must use resolved shapes from the FUNCTION (which substitutes PARAMs with external args), not raw body shapes
rets = tuple(r.shrink_to(rs.shape) for r,rs in zip(rets, resolved))
return UOp.maketuple(*rets)
# NOTE: adding rules to here is bad. these all need to run before the schedule cache
pm_early_transform_tensor_graph = PatternMatcher([
-# transform precompiled CALLs
-(UPat(Ops.CALL, name="c"), transform_precompiled_call),
+# transform precompiled FUNCTIONs into CALLs (body becomes SINK with stores)
+(UPat(Ops.FUNCTION, name="c"), transform_precompiled_call),
# resolve TUPLE+GETTUPLE (for precompiled calls)
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
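
In outline, transform_precompiled_call now takes a precompiled FUNCTION with a TUPLE body, appends one output buffer per TUPLE element, rewrites the body into a SINK of stores into those buffers, and demotes the node to an opaque CALL whose results are read back through AFTER. A toy version of that reshaping (hypothetical mini-IR, not tinygrad's UOp):

from dataclasses import dataclass

@dataclass(frozen=True)
class N:
  op: str
  src: tuple["N", ...] = ()

def transform(fn: N) -> N:
  body, ins = fn.src[0], fn.src[1:]
  assert fn.op == "FUNCTION" and body.op == "TUPLE"
  bufs = tuple(N("BUFFER") for _ in body.src)               # one output buffer per TUPLE element
  stores = tuple(N("STORE", (b, v)) for b, v in zip(bufs, body.src))
  call = N("CALL", (N("SINK", stores),) + ins + bufs)       # body is now a SINK: opaque CALL
  return N("TUPLE", tuple(N("AFTER", (b, call)) for b in bufs))  # results read after the call runs

out = transform(N("FUNCTION", (N("TUPLE", (N("ADD"), N("MUL"))), N("PARAM"))))
assert out.src[0].src[1].src[0].op == "SINK"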

View File

@@ -93,17 +93,18 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
grads: dict[UOp, UOp] = {root: root_grad}
for t0 in reversed(walk):
if t0 not in grads or grads[t0].op is Ops.NOOP: continue
-# GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL
+# GETTUPLE: accumulate gradient into a TUPLE UOp on the FUNCTION, process when we hit the FUNCTION
if t0.op is Ops.GETTUPLE:
-k = t0.src[0] # the CALL
-assert k.op is Ops.CALL and k.src[0].op is Ops.TUPLE
+k = t0.src[0] # the FUNCTION
+assert k.op is Ops.FUNCTION and k.src[0].op is Ops.TUPLE
n_outputs = len(k.src[0].src)
prev = grads[k].src if k in grads else tuple(UOp(Ops.NOOP) for _ in range(n_outputs))
grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg and prev[i].op is not Ops.NOOP else
grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs)))
continue
-# CALL: pass needed param set so backward only computes required gradients
-if t0.op is Ops.CALL:
+# FUNCTION/CALL: pass needed param set so backward only computes required gradients
+# (FUNCTION uses implicit TUPLE gradient or grad_fxn; CALL requires an explicit grad_fxn)
+if t0.op in {Ops.FUNCTION, Ops.CALL}:
needed = {i for i, arg in enumerate(t0.src[1:]) if arg in targets or in_target_path.get(arg, False)}
lgrads:tuple[UOp|None, ...]|None = call_gradient(grads[t0], t0, needed)
else:
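
The GETTUPLE branch is sparse accumulation: each GETTUPLE deposits its incoming gradient into one slot of a per-FUNCTION tuple, and slots never touched stay NOOP so the backward pass can skip them. The same bookkeeping in plain Python (a toy; None stands in for Ops.NOOP):

def accumulate(grads: dict, fn, n_outputs: int, out_idx: int, g: float) -> None:
  # None marks "no gradient has flowed to this output yet"
  prev = grads.get(fn, [None] * n_outputs)
  prev[out_idx] = g if prev[out_idx] is None else prev[out_idx] + g
  grads[fn] = prev

grads: dict = {}
fn = "FUNCTION#0"                      # stands in for the FUNCTION UOp
accumulate(grads, fn, 3, 0, 1.0)       # gradient from GETTUPLE(fn, 0)
accumulate(grads, fn, 3, 0, 2.0)       # second consumer of output 0: summed
accumulate(grads, fn, 3, 2, 5.0)       # gradient from GETTUPLE(fn, 2)
assert grads[fn] == [3.0, None, 5.0]   # output 1 got nothing: stays NOOP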

View File

@@ -9,7 +9,7 @@ from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
-Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL}
+Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.FUNCTION}
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
@@ -164,7 +164,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
# no ranges on kernels, they are internal
-if x.op in {Ops.CALL, Ops.LINEAR}: continue
+if x.op in {Ops.CALL, Ops.FUNCTION, Ops.LINEAR}: continue
# only STORE+AFTER has range
if x.op is Ops.AFTER and all(s.op is not Ops.STORE for s in x.src[1:]): continue

View File

@@ -1,5 +1,5 @@
from tinygrad.helpers import all_same, prod, getenv, ALLREDUCE_CAST
-from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite, should_resolve_call
+from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite
from tinygrad.dtype import dtypes
from tinygrad.schedule.allreduce import handle_allreduce
@@ -116,11 +116,11 @@ def store_after_multi(dest:UOp, src:UOp): return dest.after(dest.store(src.src[0
def passthrough_multi(root:UOp, multi:UOp):
return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)
-def rewrite_into_call(call:UOp):
-if not should_resolve_call(call): return None
+def rewrite_into_function(call:UOp):
+if call.arg.precompile: return None
new_body = graph_rewrite(call.src[0], multi_pm, name="subcall")
new_args = tuple(a.src[0] if a.op is Ops.MULTI else a for a in call.src[1:])
-# after multi resolution, TUPLE elements may be MULTI — strip MULTI from body, create per-shard CALL, wrap each GETTUPLE in its own MULTI
+# after multi resolution, TUPLE elements may be MULTI — strip MULTI from body, create per-shard FUNCTION, wrap each GETTUPLE in its own MULTI
assert new_body.op is Ops.TUPLE
if any(s.op is Ops.MULTI for s in new_body.src):
shard_call = call.replace(src=(UOp.maketuple(*[s.src[0] if s.op is Ops.MULTI else s for s in new_body.src]),)+new_args)
@@ -149,17 +149,16 @@ multi_pm = PatternMatcher([
# resolve TUPLE+GETTUPLE (needed in multi)
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
-# GETTUPLE on MULTI: passthrough MULTI (e.g. when CALL was replaced by MULTI(GETTUPLE(...)))
+# GETTUPLE on MULTI: passthrough MULTI (e.g. when FUNCTION was replaced by MULTI(GETTUPLE(...)))
(UPat(Ops.GETTUPLE, src=(UPat(Ops.MULTI, name="multi"),), name="g"),
-lambda g, multi: multi.src[0].gettuple(g.arg).multi(multi.axis) if multi.src[0].op in {Ops.CALL, Ops.TUPLE}
+lambda g, multi: multi.src[0].gettuple(g.arg).multi(multi.axis) if multi.src[0].op in {Ops.FUNCTION, Ops.TUPLE}
else multi),
-# rewrite into calls explicitly for MULTI
-(UPat(Ops.CALL, name="call"), rewrite_into_call),
-(UPat((Ops.CALL, Ops.AFTER), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
-# we just remove the MULTI from non-value-producing CALLs (custom kernels, etc.) — TUPLE body CALLs are handled by rewrite_into_call
+# rewrite into FUNCTION calls explicitly for MULTI (value-producing)
+(UPat(Ops.FUNCTION, name="call"), rewrite_into_function),
+(UPat((Ops.CALL, Ops.FUNCTION, Ops.AFTER), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
+# just strip the MULTI from non-value-producing CALLs (custom kernels, etc.) — FUNCTION is handled by rewrite_into_function
(UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
-UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)
-if root.src[0].op is not Ops.TUPLE else None),
+UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
# remove MULTI from STORE
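
Most of these MULTI rules are variations of one move: strip the MULTI wrapper off the sources, rebuild the node on the bare shards, and re-wrap the result in a MULTI on the same axis. A toy passthrough in that shape (hypothetical mini-IR; the real passthrough_multi keeps src[0] separate from the rest):

from dataclasses import dataclass

@dataclass(frozen=True)
class N:
  op: str
  src: tuple["N", ...] = ()
  arg: object = None

def passthrough_multi(root: N) -> N:
  axis = next(s.arg for s in root.src if s.op == "MULTI")
  shards = tuple(s.src[0] if s.op == "MULTI" else s for s in root.src)  # unwrap
  return N("MULTI", (N(root.op, shards, root.arg),), axis)              # rebuild, re-wrap on same axis

x = N("MULTI", (N("BUFFER"),), 0)
out = passthrough_multi(N("CAST", (x,), "half"))
assert out.op == "MULTI" and out.arg == 0 and out.src[0].op == "CAST"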

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, AddrSpace, Invalid
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
-from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call, identity_element
+from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, identity_element
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element
@@ -126,8 +126,8 @@ mop_cleanup = PatternMatcher([
])
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
-def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
-if not should_resolve_call(c): return None
+def resolve_function(c:UOp, allow_param_mismatch=True) -> UOp|None:
+if c.arg.precompile: return None
params: list[UOp] = []
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params, name="gather params")
params = sorted(params, key=lambda x: x.arg)
@@ -150,8 +150,8 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
(UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))),
lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None),
-# resolve calls
-(UPat(Ops.CALL, name="c"), resolve_call),
+# resolve FUNCTION calls (inline the body)
+(UPat(Ops.FUNCTION, name="c"), resolve_function),
# resolve TUPLE+GETTUPLE
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
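
resolve_function inlines a non-precompiled FUNCTION: gather the PARAM nodes in the body, order them by their positional slot (arg), and substitute each with the corresponding outer call argument. A toy inliner with the same structure (hypothetical mini-IR, not tinygrad's graph_rewrite):

from dataclasses import dataclass

@dataclass(frozen=True)
class N:
  op: str
  src: tuple["N", ...] = ()
  arg: object = None

def walk(n: N):
  yield n
  for s in n.src: yield from walk(s)

def substitute(n: N, table: dict) -> N:
  if n in table: return table[n]
  return N(n.op, tuple(substitute(s, table) for s in n.src), n.arg)

def resolve_function(call: N) -> N:
  body, args = call.src[0], call.src[1:]
  params = sorted({n for n in walk(body) if n.op == "PARAM"}, key=lambda p: p.arg)
  return substitute(body, dict(zip(params, args)))  # inline: PARAM i -> args[i]

p0, p1 = N("PARAM", arg=0), N("PARAM", arg=1)
body = N("TUPLE", (N("ADD", (p0, p1)),))
inlined = resolve_function(N("FUNCTION", (body, N("BUFFER", arg="a"), N("BUFFER", arg="b"))))
assert inlined.src[0].src == (N("BUFFER", arg="a"), N("BUFFER", arg="b"))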

View File

@@ -26,7 +26,8 @@ class Ops(FastEnum):
# uops that aren't rendered
NOOP = auto(); REWRITE_ERROR = auto()
-PARAM = auto(); CALL = auto()
+# FUNCTION has a TUPLE body and is gradient-able; CALL is an opaque kernel invocation
+PARAM = auto(); FUNCTION = auto(); CALL = auto()
# renderer
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled
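
That one-line comment is the core of the commit: UOp.call (changed in the ops.py hunks below) now dispatches on the body op. A stripped-down sketch of the dispatch, reading only from this diff:

# Toy of the dispatch UOp.call performs after this commit: opaque bodies stay
# CALL, value-producing bodies get wrapped in a TUPLE and become FUNCTION.
OPAQUE_BODIES = {"SINK", "PROGRAM", "LINEAR", "COPY", "BUFFER_VIEW", "CUSTOM_FUNCTION"}

def call(body_op: str) -> str:
  if body_op in OPAQUE_BODIES: return "CALL"  # opaque; gradients need an explicit grad_fxn
  return "FUNCTION"                           # TUPLE body, gradient-able

assert call("SINK") == "CALL"
assert call("ADD") == "FUNCTION"              # e.g. an elementwise value body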

View File

@@ -26,7 +26,8 @@ axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL:
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
-range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.COPY: 2, Ops.BUFFER_VIEW: 1}
+range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.FUNCTION: 1,
+Ops.COPY: 2, Ops.BUFFER_VIEW: 1}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> PyConst: return dt.const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op])
@@ -178,7 +179,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if not visited:
if gate is None or gate(node):
stack.append((node, True)) # push node back on stack to process after its srcs
-for s in reversed(node.src if enter_calls or node.op is not Ops.CALL else node.src[1:]):
+for s in reversed(node.src if enter_calls or node.op not in {Ops.CALL, Ops.FUNCTION} else node.src[1:]):
stack.append((s, False)) # push srcs on the stack
else: cache[node] = None # second time i'm seeing this node, add it to returned toposort
return cache
@@ -215,17 +216,17 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
-Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE:
+Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
return None
case Ops.GETTUPLE:
-# GETTUPLE extracts from a TUPLE (possibly through a CALL)
-in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0]
+# GETTUPLE extracts from a TUPLE (possibly through a FUNCTION)
+in_tuple = self.src[0].src[0] if self.src[0].op is Ops.FUNCTION else self.src[0]
assert in_tuple.op is Ops.TUPLE
inner_shape = in_tuple.src[self.arg]._shape
if inner_shape is None: return None
-# if through a CALL, substitute internal PARAMs in the shape with corresponding args
-if self.src[0].op is Ops.CALL:
+# if through a FUNCTION, substitute internal PARAMs in the shape with corresponding args
+if self.src[0].op is Ops.FUNCTION:
return tuple(graph_rewrite(s, _pm_resolve_params, self.src[0].src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape)
return inner_shape
@@ -262,8 +263,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
return self.src[0]._shape
-case Ops.CALL: return None
# TODO: disallow shape changing bitcast
case Ops.BITCAST:
ps = self.src[0]._shape
@@ -418,8 +417,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def maketuple(*srcs:UOp): # pylint: disable=no-self-argument
return UOp(Ops.TUPLE, dtypes.void, srcs)
def gettuple(self, idx:int) -> UOp:
-in_tuple = self.src[0] if self.op is Ops.CALL else self
-assert in_tuple.op is Ops.TUPLE, f"gettuple requires CALL or TUPLE source, got {self.op}"
+in_tuple = self.src[0] if self.op is Ops.FUNCTION else self
+assert in_tuple.op is Ops.TUPLE, f"gettuple requires FUNCTION or TUPLE source, got {self.op}"
return UOp(Ops.GETTUPLE, in_tuple.src[idx].dtype, (self,), idx)
def group(*srcs:UOp|None): # pylint: disable=no-self-argument
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
@@ -554,7 +553,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op is Ops.MULTI: return self.arg
# GETTUPLE: axis comes from the specific TUPLE element, not src[0]
if self.op is Ops.GETTUPLE:
-in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0]
+in_tuple = self.src[0].src[0] if self.src[0].op is Ops.FUNCTION else self.src[0]
return in_tuple.src[self.arg].axis if in_tuple.op is Ops.TUPLE else None
# PARAM: axis is stored as a MULTI source
if self.op is Ops.PARAM:
@@ -931,13 +930,16 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.axis is not None: p = p.replace(src=p.src + (UOp(Ops.MULTI, arg=self.axis),))
return p
-_NO_TUPLE_WRAP = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.BUFFER_VIEW, Ops.CUSTOM_FUNCTION, Ops.TUPLE}
+# opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE)
+_OPAQUE_CALL_BODIES = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.BUFFER_VIEW, Ops.CUSTOM_FUNCTION}
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(),
name:str|None=None, precompile:bool=False, precompile_backward:bool=False) -> UOp:
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
-# value-producing bodies are always wrapped in TUPLE so CALL dtype is always void
-body = self if self.op in UOp._NO_TUPLE_WRAP else UOp.maketuple(self)
-return UOp(Ops.CALL, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
+if self.op in UOp._OPAQUE_CALL_BODIES:
+return UOp(Ops.CALL, dtypes.void, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
+# value-producing bodies are always wrapped in TUPLE so FUNCTION dtype is always void
+body = self if self.op is Ops.TUPLE else UOp.maketuple(self)
+return UOp(Ops.FUNCTION, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
@@ -968,13 +970,6 @@ class CallInfo:
gf = id(self.grad_fxn) if self.grad_fxn else None
return f"CallInfo({gf}, {self.metadata}, {repr(self.name)}, {self.precompile}, {self.precompile_backward})"
-def should_resolve_call(c:UOp) -> bool:
-# don't resolve real kernel calls, sink or program
-if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return False
-if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.CUSTOM_FUNCTION}: return False
-if c.arg.precompile: return False
-return True
# ******** ops in python ********
def safe_exp2(x):
@@ -1357,7 +1352,7 @@ class RewriteContext:
continue
# no rewrite, process children then come back to rebuild
stack.append((n, True))
-if not self.enter_calls and n.op is Ops.CALL: self.replace[n.src[0]] = n.src[0]
+if not self.enter_calls and n.op in {Ops.CALL, Ops.FUNCTION}: self.replace[n.src[0]] = n.src[0]
for x in reversed(n.src):
if x not in self.replace: stack.append((x, False))
else:
@@ -1394,10 +1389,10 @@ class RewriteContext:
if n in waitlist: stack.extend(waitlist.pop(n))
continue
stack.append((n, 1, new_n))
-# NOTE: CALL is handled as a special case.
+# NOTE: CALL/FUNCTION are handled as a special case.
# The function that is called is not included in the graph_rewrite.
# If you want to graph_rewrite a call, you can set enter_calls=True
-if not self.enter_calls and new_n.op is Ops.CALL: self.replace[new_n.src[0]] = new_n.src[0]
+if not self.enter_calls and new_n.op in {Ops.CALL, Ops.FUNCTION}: self.replace[new_n.src[0]] = new_n.src[0]
for x in reversed(new_n.src):
if x in on_stack: continue
stack.append((x, 0, x))
@@ -1618,7 +1613,7 @@ def pyrender(ast:UOp) -> str:
cmap = consumer_map_from_toposort(lst)
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
-Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END}
+Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.FUNCTION, Ops.WHERE, Ops.END}
to_render: set[UOp] = {ast}
for u in lst:
@@ -1626,7 +1621,7 @@ def pyrender(ast:UOp) -> str:
for s in u.src: to_render.add(s)
if u.op is Ops.STORE: to_render.add(u.src[1])
if u.op in {Ops.REDUCE, Ops.REDUCE_AXIS}: to_render.add(u.src[0])
-if u.op is Ops.CALL: raise NotImplementedError("call can't be pyrendered")
+if u.op in {Ops.CALL, Ops.FUNCTION}: raise NotImplementedError("call can't be pyrendered")
if u.op in not_rendered: continue
# checking the consumers is not enough, you have to make sure it's not used twice by the one consumer
if len(cmap[u]) == 1 and len([x for x in list(cmap[u].keys())[0].src if x is u]) == 1 and u.op not in always_rendered: continue
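
A detail repeated through this file: graph walks treat a CALL/FUNCTION body (src[0]) as a black box unless enter_calls is set, so traversal sees the call node and its arguments but not the function internals. A toy toposort with the same skip (hypothetical mini-IR, not tinygrad's UOp.toposort):

from dataclasses import dataclass

@dataclass(frozen=True)
class N:
  op: str
  src: tuple["N", ...] = ()

def toposort(root: N, enter_calls: bool = False) -> list[N]:
  out: list[N] = []
  seen: set[int] = set()
  def visit(n: N) -> None:
    if id(n) in seen: return
    seen.add(id(n))
    # skip src[0] (the body) for CALL/FUNCTION unless we explicitly enter calls
    src = n.src if enter_calls or n.op not in {"CALL", "FUNCTION"} else n.src[1:]
    for s in src: visit(s)
    out.append(n)
  visit(root)
  return out

body, arg = N("TUPLE", (N("ADD"),)), N("BUFFER")
fn = N("FUNCTION", (body, arg))
assert body not in toposort(fn) and arg in toposort(fn)  # body skipped, args visited
assert body in toposort(fn, enter_calls=True)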

View File

@@ -95,7 +95,7 @@ _tensor_spec = PatternMatcher([
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),)), lambda: True),
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
-(UPat(Ops.CALL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
+(UPat((Ops.CALL, Ops.FUNCTION), src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
# MSELECT chooses one of the multi buffers
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
@@ -132,14 +132,16 @@ _tensor_spec = PatternMatcher([
# AFTER if things were kernelized
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
-# allow CALL/PARAM/CUSTOM_FUNCTION — CALL dtype is always void
+# allow CALL/FUNCTION/PARAM/CUSTOM_FUNCTION — both CALL and FUNCTION dtypes are always void
+# FUNCTION must have a TUPLE body in src[0] (invariant enforced by UOp.call); CALL bodies are opaque
(UPat(Ops.CALL, dtypes.void), lambda: True),
+(UPat(Ops.FUNCTION, dtypes.void, src=(UPat(Ops.TUPLE),), allow_any_len=True), lambda: True),
(UPat(Ops.PARAM), lambda: True),
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),
-# TUPLE must have void dtype, GETTUPLE can only appear on CALL or TUPLE
+# TUPLE must have void dtype, GETTUPLE can only appear on FUNCTION or TUPLE
(UPat(Ops.TUPLE, dtypes.void), lambda: True),
-(UPat(Ops.GETTUPLE, src=(UPat((Ops.CALL, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
+(UPat(Ops.GETTUPLE, src=(UPat((Ops.FUNCTION, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
# ** for custom kernels **
@@ -274,7 +276,7 @@ full_spec = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),
# linearizer: outputs + intermediate KERNELs
-(UPat(Ops.CALL, dtype=dtypes.void), lambda: True),
+(UPat((Ops.CALL, Ops.FUNCTION), dtype=dtypes.void), lambda: True),
# where on index in rhs position is fine
(UPat(Ops.WHERE, dtype=dtypes.weakint, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.weakint))), lambda: True),

View File

@@ -285,7 +285,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("x", dtypes.weakint) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
# only RANGE/IF/STORE/KERNEL have side effects
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
-tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR, Ops.BUFFERIZE}
+tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.FUNCTION, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR, Ops.BUFFERIZE}
else y.src for y in x.src[1:]])))),
# after with 1 src is just src[0]
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),

View File

@@ -54,7 +54,7 @@ const layoutUOp = (g, { graph, change }, opts) => {
width = Math.max(width, ctx.measureText(line).width);
height += lineHeight;
}
-const callNode = label.startsWith("CALL\n");
+const callNode = label.startsWith("CALL\n") || label.startsWith("FUNCTION\n");
if (callNode) callCount++;
g.setNode(k, {...rectDims(width, height), label, ref, id:k, color, tag, callNode});
// add edges

View File

@@ -50,7 +50,8 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040",
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
Ops.LINEAR: "#7DF4FF",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
@@ -136,7 +137,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
label += f"\n({multirange_str(rngs, color=True)})"
if u._shape is not None:
label += f"\n{shape_to_str(u.shape)}"
-if u.op is Ops.CALL:
+if u.op in {Ops.CALL, Ops.FUNCTION}:
label += f"\n{u.src[0].key.hex()[:8]}"
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
if len(u.toposort()) < 30: label += f"\n{u.render()}"
@@ -147,9 +148,9 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs])
except Exception:
label += "\n<ISSUE GETTING LABEL>"
-if (ref:=data.ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{data.ctxs[ref]['name']}"
+if (ref:=data.ref_map.get(u.src[0]) if u.op in {Ops.CALL, Ops.FUNCTION} else None) is not None: label += f"\ncodegen@{data.ctxs[ref]['name']}"
# NOTE: kernel already has metadata in arg
-if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata)
+if TRACEMETA >= 2 and u.metadata is not None and u.op not in {Ops.CALL, Ops.FUNCTION}: label += "\n"+str(u.metadata)
# limit SOURCE labels line count
if u.op is Ops.SOURCE and len(lines:=label.split("\n")) > 40:
label = "\n".join(lines[:30]) + "\n..."