mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
CALL with return value is FUNCTION (#15758)
* CALL with return value is FUNCTION (GPT try) * cleanups
This commit is contained in:
@@ -220,9 +220,9 @@ class TestCallSchedule(unittest.TestCase):
|
||||
a = Tensor.empty(4, 8)
|
||||
b = Tensor.empty(4, 8)
|
||||
r0, r1 = f(a), f(b)
|
||||
# find the CALL nodes
|
||||
c0 = next(u for u in r0.uop.toposort() if u.op is Ops.CALL)
|
||||
c1 = next(u for u in r1.uop.toposort() if u.op is Ops.CALL)
|
||||
# find the FUNCTION nodes
|
||||
c0 = next(u for u in r0.uop.toposort() if u.op is Ops.FUNCTION)
|
||||
c1 = next(u for u in r1.uop.toposort() if u.op is Ops.FUNCTION)
|
||||
# the function bodies (src[0]) should have identical keys — unique consts must not leak through
|
||||
self.assertEqual(c0.src[0].key, c1.src[0].key)
|
||||
|
||||
|
||||
@@ -96,8 +96,7 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
|
||||
|
||||
def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||
if not c.arg.precompile: return None
|
||||
if c.src[0].op is Ops.SINK: return None
|
||||
assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled call, got {c.src[0].op}"
|
||||
assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled FUNCTION, got {c.src[0].op}"
|
||||
input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:])
|
||||
|
||||
# add the outputs to the call
|
||||
@@ -107,20 +106,20 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||
targets = [o.param_like(len(c.src)-1+i).shrink_to(s.shape) for i,(o,s) in enumerate(zip(outs, srcs))]
|
||||
fxn = UOp.sink(*[t.after(t.store(s)) for t,s in zip(targets, srcs)])
|
||||
|
||||
# create the new thing for the big graph
|
||||
new_call = c.replace(src=(fxn, *input_buffers, *outs), tag=None)
|
||||
# body switches from TUPLE to SINK, so the node becomes an opaque CALL (not FUNCTION)
|
||||
new_call = UOp(Ops.CALL, c.dtype, (fxn, *input_buffers, *outs), c.arg)
|
||||
rets = tuple(o.after(new_call) for o in outs)
|
||||
|
||||
# if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape
|
||||
# NOTE: must use resolved shapes from the CALL (which substitutes PARAMs with external args), not raw body shapes
|
||||
# NOTE: must use resolved shapes from the FUNCTION (which substitutes PARAMs with external args), not raw body shapes
|
||||
rets = tuple(r.shrink_to(rs.shape) for r,rs in zip(rets, resolved))
|
||||
|
||||
return UOp.maketuple(*rets)
|
||||
|
||||
# NOTE: adding rules to here is bad. these all need to run before the schedule cache
|
||||
pm_early_transform_tensor_graph = PatternMatcher([
|
||||
# transform precompiled CALLs
|
||||
(UPat(Ops.CALL, name="c"), transform_precompiled_call),
|
||||
# transform precompiled FUNCTIONs into CALLs (body becomes SINK with stores)
|
||||
(UPat(Ops.FUNCTION, name="c"), transform_precompiled_call),
|
||||
|
||||
# resolve TUPLE+GETTUPLE (for precompiled calls)
|
||||
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
|
||||
|
||||
@@ -93,17 +93,18 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
|
||||
grads: dict[UOp, UOp] = {root: root_grad}
|
||||
for t0 in reversed(walk):
|
||||
if t0 not in grads or grads[t0].op is Ops.NOOP: continue
|
||||
# GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL
|
||||
# GETTUPLE: accumulate gradient into a TUPLE UOp on the FUNCTION, process when we hit the FUNCTION
|
||||
if t0.op is Ops.GETTUPLE:
|
||||
k = t0.src[0] # the CALL
|
||||
assert k.op is Ops.CALL and k.src[0].op is Ops.TUPLE
|
||||
k = t0.src[0] # the FUNCTION
|
||||
assert k.op is Ops.FUNCTION and k.src[0].op is Ops.TUPLE
|
||||
n_outputs = len(k.src[0].src)
|
||||
prev = grads[k].src if k in grads else tuple(UOp(Ops.NOOP) for _ in range(n_outputs))
|
||||
grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg and prev[i].op is not Ops.NOOP else
|
||||
grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs)))
|
||||
continue
|
||||
# CALL: pass needed param set so backward only computes required gradients
|
||||
if t0.op is Ops.CALL:
|
||||
# FUNCTION/CALL: pass needed param set so backward only computes required gradients
|
||||
# (FUNCTION uses implicit TUPLE gradient or grad_fxn; CALL requires an explicit grad_fxn)
|
||||
if t0.op in {Ops.FUNCTION, Ops.CALL}:
|
||||
needed = {i for i, arg in enumerate(t0.src[1:]) if arg in targets or in_target_path.get(arg, False)}
|
||||
lgrads:tuple[UOp|None, ...]|None = call_gradient(grads[t0], t0, needed)
|
||||
else:
|
||||
|
||||
@@ -9,7 +9,7 @@ from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL}
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.FUNCTION}
|
||||
|
||||
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
|
||||
|
||||
@@ -164,7 +164,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
|
||||
|
||||
# no ranges on kernels, they are internal
|
||||
if x.op in {Ops.CALL, Ops.LINEAR}: continue
|
||||
if x.op in {Ops.CALL, Ops.FUNCTION, Ops.LINEAR}: continue
|
||||
|
||||
# only STORE+AFTER has range
|
||||
if x.op is Ops.AFTER and all(s.op is not Ops.STORE for s in x.src[1:]): continue
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from tinygrad.helpers import all_same, prod, getenv, ALLREDUCE_CAST
|
||||
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite, should_resolve_call
|
||||
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.schedule.allreduce import handle_allreduce
|
||||
|
||||
@@ -116,11 +116,11 @@ def store_after_multi(dest:UOp, src:UOp): return dest.after(dest.store(src.src[0
|
||||
def passthrough_multi(root:UOp, multi:UOp):
|
||||
return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)
|
||||
|
||||
def rewrite_into_call(call:UOp):
|
||||
if not should_resolve_call(call): return None
|
||||
def rewrite_into_function(call:UOp):
|
||||
if call.arg.precompile: return None
|
||||
new_body = graph_rewrite(call.src[0], multi_pm, name="subcall")
|
||||
new_args = tuple(a.src[0] if a.op is Ops.MULTI else a for a in call.src[1:])
|
||||
# after multi resolution, TUPLE elements may be MULTI — strip MULTI from body, create per-shard CALL, wrap each GETTUPLE in its own MULTI
|
||||
# after multi resolution, TUPLE elements may be MULTI — strip MULTI from body, create per-shard FUNCTION, wrap each GETTUPLE in its own MULTI
|
||||
assert new_body.op is Ops.TUPLE
|
||||
if any(s.op is Ops.MULTI for s in new_body.src):
|
||||
shard_call = call.replace(src=(UOp.maketuple(*[s.src[0] if s.op is Ops.MULTI else s for s in new_body.src]),)+new_args)
|
||||
@@ -149,17 +149,16 @@ multi_pm = PatternMatcher([
|
||||
|
||||
# resolve TUPLE+GETTUPLE (needed in multi)
|
||||
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
|
||||
# GETTUPLE on MULTI: passthrough MULTI (e.g. when CALL was replaced by MULTI(GETTUPLE(...)))
|
||||
# GETTUPLE on MULTI: passthrough MULTI (e.g. when FUNCTION was replaced by MULTI(GETTUPLE(...)))
|
||||
(UPat(Ops.GETTUPLE, src=(UPat(Ops.MULTI, name="multi"),), name="g"),
|
||||
lambda g, multi: multi.src[0].gettuple(g.arg).multi(multi.axis) if multi.src[0].op in {Ops.CALL, Ops.TUPLE}
|
||||
lambda g, multi: multi.src[0].gettuple(g.arg).multi(multi.axis) if multi.src[0].op in {Ops.FUNCTION, Ops.TUPLE}
|
||||
else multi),
|
||||
# rewrite into calls explicitly for MULTI
|
||||
(UPat(Ops.CALL, name="call"), rewrite_into_call),
|
||||
(UPat((Ops.CALL, Ops.AFTER), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
|
||||
# we just remove the MULTI from non-value-producing CALLs (custom kernels, etc.) — TUPLE body CALLs are handled by rewrite_into_call
|
||||
# rewrite into FUNCTION calls explicitly for MULTI (value-producing)
|
||||
(UPat(Ops.FUNCTION, name="call"), rewrite_into_function),
|
||||
(UPat((Ops.CALL, Ops.FUNCTION, Ops.AFTER), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
|
||||
# just strip the MULTI from non-value-producing CALLs (custom kernels, etc.) — FUNCTION is handled by rewrite_into_function
|
||||
(UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
|
||||
UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)
|
||||
if root.src[0].op is not Ops.TUPLE else None),
|
||||
UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)),
|
||||
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
|
||||
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
|
||||
# remove MULTI from STORE
|
||||
|
||||
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, AddrSpace, Invalid
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call, identity_element
|
||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, identity_element
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
|
||||
from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element
|
||||
@@ -126,8 +126,8 @@ mop_cleanup = PatternMatcher([
|
||||
])
|
||||
|
||||
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
|
||||
def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
||||
if not should_resolve_call(c): return None
|
||||
def resolve_function(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
||||
if c.arg.precompile: return None
|
||||
params: list[UOp] = []
|
||||
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params, name="gather params")
|
||||
params = sorted(params, key=lambda x: x.arg)
|
||||
@@ -150,8 +150,8 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
(UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))),
|
||||
lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None),
|
||||
|
||||
# resolve calls
|
||||
(UPat(Ops.CALL, name="c"), resolve_call),
|
||||
# resolve FUNCTION calls (inline the body)
|
||||
(UPat(Ops.FUNCTION, name="c"), resolve_function),
|
||||
|
||||
# resolve TUPLE+GETTUPLE
|
||||
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
|
||||
|
||||
@@ -26,7 +26,8 @@ class Ops(FastEnum):
|
||||
|
||||
# uops that aren't rendered
|
||||
NOOP = auto(); REWRITE_ERROR = auto()
|
||||
PARAM = auto(); CALL = auto()
|
||||
# FUNCTION has a TUPLE body and is gradient-able; CALL is an opaque kernel invocation
|
||||
PARAM = auto(); FUNCTION = auto(); CALL = auto()
|
||||
|
||||
# renderer
|
||||
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled
|
||||
|
||||
@@ -26,7 +26,8 @@ axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL:
|
||||
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
||||
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.COPY: 2, Ops.BUFFER_VIEW: 1}
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.FUNCTION: 1,
|
||||
Ops.COPY: 2, Ops.BUFFER_VIEW: 1}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType) -> PyConst: return dt.const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op])
|
||||
@@ -178,7 +179,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
if not visited:
|
||||
if gate is None or gate(node):
|
||||
stack.append((node, True)) # push node back on stack to process after its srcs
|
||||
for s in reversed(node.src if enter_calls or node.op is not Ops.CALL else node.src[1:]):
|
||||
for s in reversed(node.src if enter_calls or node.op not in {Ops.CALL, Ops.FUNCTION} else node.src[1:]):
|
||||
stack.append((s, False)) # push srcs on the stack
|
||||
else: cache[node] = None # second time i'm seeing this node, add it to returned toposort
|
||||
return cache
|
||||
@@ -215,17 +216,17 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE:
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
|
||||
return None
|
||||
|
||||
case Ops.GETTUPLE:
|
||||
# GETTUPLE extracts from a TUPLE (possibly through a CALL)
|
||||
in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0]
|
||||
# GETTUPLE extracts from a TUPLE (possibly through a FUNCTION)
|
||||
in_tuple = self.src[0].src[0] if self.src[0].op is Ops.FUNCTION else self.src[0]
|
||||
assert in_tuple.op is Ops.TUPLE
|
||||
inner_shape = in_tuple.src[self.arg]._shape
|
||||
if inner_shape is None: return None
|
||||
# if through a CALL, substitute internal PARAMs in the shape with corresponding args
|
||||
if self.src[0].op is Ops.CALL:
|
||||
# if through a FUNCTION, substitute internal PARAMs in the shape with corresponding args
|
||||
if self.src[0].op is Ops.FUNCTION:
|
||||
return tuple(graph_rewrite(s, _pm_resolve_params, self.src[0].src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape)
|
||||
return inner_shape
|
||||
|
||||
@@ -262,8 +263,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
|
||||
return self.src[0]._shape
|
||||
|
||||
case Ops.CALL: return None
|
||||
|
||||
# TODO: disallow shape changing bitcast
|
||||
case Ops.BITCAST:
|
||||
ps = self.src[0]._shape
|
||||
@@ -418,8 +417,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def maketuple(*srcs:UOp): # pylint: disable=no-self-argument
|
||||
return UOp(Ops.TUPLE, dtypes.void, srcs)
|
||||
def gettuple(self, idx:int) -> UOp:
|
||||
in_tuple = self.src[0] if self.op is Ops.CALL else self
|
||||
assert in_tuple.op is Ops.TUPLE, f"gettuple requires CALL or TUPLE source, got {self.op}"
|
||||
in_tuple = self.src[0] if self.op is Ops.FUNCTION else self
|
||||
assert in_tuple.op is Ops.TUPLE, f"gettuple requires FUNCTION or TUPLE source, got {self.op}"
|
||||
return UOp(Ops.GETTUPLE, in_tuple.src[idx].dtype, (self,), idx)
|
||||
def group(*srcs:UOp|None): # pylint: disable=no-self-argument
|
||||
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
||||
@@ -554,7 +553,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.MULTI: return self.arg
|
||||
# GETTUPLE: axis comes from the specific TUPLE element, not src[0]
|
||||
if self.op is Ops.GETTUPLE:
|
||||
in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0]
|
||||
in_tuple = self.src[0].src[0] if self.src[0].op is Ops.FUNCTION else self.src[0]
|
||||
return in_tuple.src[self.arg].axis if in_tuple.op is Ops.TUPLE else None
|
||||
# PARAM: axis is stored as a MULTI source
|
||||
if self.op is Ops.PARAM:
|
||||
@@ -931,13 +930,16 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
if self.axis is not None: p = p.replace(src=p.src + (UOp(Ops.MULTI, arg=self.axis),))
|
||||
return p
|
||||
|
||||
_NO_TUPLE_WRAP = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.BUFFER_VIEW, Ops.CUSTOM_FUNCTION, Ops.TUPLE}
|
||||
# opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE)
|
||||
_OPAQUE_CALL_BODIES = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.BUFFER_VIEW, Ops.CUSTOM_FUNCTION}
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(),
|
||||
name:str|None=None, precompile:bool=False, precompile_backward:bool=False) -> UOp:
|
||||
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
|
||||
# value-producing bodies are always wrapped in TUPLE so CALL dtype is always void
|
||||
body = self if self.op in UOp._NO_TUPLE_WRAP else UOp.maketuple(self)
|
||||
return UOp(Ops.CALL, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
|
||||
if self.op in UOp._OPAQUE_CALL_BODIES:
|
||||
return UOp(Ops.CALL, dtypes.void, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
|
||||
# value-producing bodies are always wrapped in TUPLE so FUNCTION dtype is always void
|
||||
body = self if self.op is Ops.TUPLE else UOp.maketuple(self)
|
||||
return UOp(Ops.FUNCTION, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
|
||||
@@ -968,13 +970,6 @@ class CallInfo:
|
||||
gf = id(self.grad_fxn) if self.grad_fxn else None
|
||||
return f"CallInfo({gf}, {self.metadata}, {repr(self.name)}, {self.precompile}, {self.precompile_backward})"
|
||||
|
||||
def should_resolve_call(c:UOp) -> bool:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return False
|
||||
if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.CUSTOM_FUNCTION}: return False
|
||||
if c.arg.precompile: return False
|
||||
return True
|
||||
|
||||
# ******** ops in python ********
|
||||
|
||||
def safe_exp2(x):
|
||||
@@ -1357,7 +1352,7 @@ class RewriteContext:
|
||||
continue
|
||||
# no rewrite, process children then come back to rebuild
|
||||
stack.append((n, True))
|
||||
if not self.enter_calls and n.op is Ops.CALL: self.replace[n.src[0]] = n.src[0]
|
||||
if not self.enter_calls and n.op in {Ops.CALL, Ops.FUNCTION}: self.replace[n.src[0]] = n.src[0]
|
||||
for x in reversed(n.src):
|
||||
if x not in self.replace: stack.append((x, False))
|
||||
else:
|
||||
@@ -1394,10 +1389,10 @@ class RewriteContext:
|
||||
if n in waitlist: stack.extend(waitlist.pop(n))
|
||||
continue
|
||||
stack.append((n, 1, new_n))
|
||||
# NOTE: CALL is handled as a special case.
|
||||
# NOTE: CALL/FUNCTION are handled as a special case.
|
||||
# The function that is called is not included in the graph_rewrite.
|
||||
# If you want to graph_rewrite a call, you can
|
||||
if not self.enter_calls and new_n.op is Ops.CALL: self.replace[new_n.src[0]] = new_n.src[0]
|
||||
if not self.enter_calls and new_n.op in {Ops.CALL, Ops.FUNCTION}: self.replace[new_n.src[0]] = new_n.src[0]
|
||||
for x in reversed(new_n.src):
|
||||
if x in on_stack: continue
|
||||
stack.append((x, 0, x))
|
||||
@@ -1618,7 +1613,7 @@ def pyrender(ast:UOp) -> str:
|
||||
cmap = consumer_map_from_toposort(lst)
|
||||
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
|
||||
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
|
||||
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END}
|
||||
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.FUNCTION, Ops.WHERE, Ops.END}
|
||||
|
||||
to_render: set[UOp] = {ast}
|
||||
for u in lst:
|
||||
@@ -1626,7 +1621,7 @@ def pyrender(ast:UOp) -> str:
|
||||
for s in u.src: to_render.add(s)
|
||||
if u.op is Ops.STORE: to_render.add(u.src[1])
|
||||
if u.op in {Ops.REDUCE, Ops.REDUCE_AXIS}: to_render.add(u.src[0])
|
||||
if u.op is Ops.CALL: raise NotImplementedError("call can't be pyrendered")
|
||||
if u.op in {Ops.CALL, Ops.FUNCTION}: raise NotImplementedError("call can't be pyrendered")
|
||||
if u.op in not_rendered: continue
|
||||
# checking the consumers is not enough, you have to make sure it's not used twice by the one consumer
|
||||
if len(cmap[u]) == 1 and len([x for x in list(cmap[u].keys())[0].src if x is u]) == 1 and u.op not in always_rendered: continue
|
||||
|
||||
@@ -95,7 +95,7 @@ _tensor_spec = PatternMatcher([
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),)), lambda: True),
|
||||
|
||||
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
|
||||
(UPat(Ops.CALL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
|
||||
(UPat((Ops.CALL, Ops.FUNCTION), src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
|
||||
|
||||
# MSELECT chooses one of the multi buffers
|
||||
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
|
||||
@@ -132,14 +132,16 @@ _tensor_spec = PatternMatcher([
|
||||
# AFTER if things were kernelized
|
||||
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
|
||||
|
||||
# allow CALL/PARAM/CUSTOM_FUNCTION — CALL dtype is always void
|
||||
# allow CALL/FUNCTION/PARAM/CUSTOM_FUNCTION — both CALL and FUNCTION dtype is always void
|
||||
# FUNCTION must have a TUPLE body in src[0] (invariant enforced by UOp.call); CALL bodies are opaque
|
||||
(UPat(Ops.CALL, dtypes.void), lambda: True),
|
||||
(UPat(Ops.FUNCTION, dtypes.void, src=(UPat(Ops.TUPLE),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.PARAM), lambda: True),
|
||||
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),
|
||||
|
||||
# TUPLE must have void dtype, GETTUPLE can only appear on CALL or TUPLE
|
||||
# TUPLE must have void dtype, GETTUPLE can only appear on FUNCTION or TUPLE
|
||||
(UPat(Ops.TUPLE, dtypes.void), lambda: True),
|
||||
(UPat(Ops.GETTUPLE, src=(UPat((Ops.CALL, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
|
||||
(UPat(Ops.GETTUPLE, src=(UPat((Ops.FUNCTION, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
|
||||
|
||||
# ** for custom kernels **
|
||||
|
||||
@@ -274,7 +276,7 @@ full_spec = PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),
|
||||
|
||||
# linearizer: outputs + intermediate KERNELs
|
||||
(UPat(Ops.CALL, dtype=dtypes.void), lambda: True),
|
||||
(UPat((Ops.CALL, Ops.FUNCTION), dtype=dtypes.void), lambda: True),
|
||||
|
||||
# where on index in rhs position is fine
|
||||
(UPat(Ops.WHERE, dtype=dtypes.weakint, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.weakint))), lambda: True),
|
||||
|
||||
@@ -285,7 +285,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
((UPat.var("x", dtypes.weakint) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
|
||||
# only RANGE/IF/STORE/KERNEL have side effects
|
||||
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
|
||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR, Ops.BUFFERIZE}
|
||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.FUNCTION, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR, Ops.BUFFERIZE}
|
||||
else y.src for y in x.src[1:]])))),
|
||||
# after with 1 src is just src[0]
|
||||
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
||||
|
||||
@@ -54,7 +54,7 @@ const layoutUOp = (g, { graph, change }, opts) => {
|
||||
width = Math.max(width, ctx.measureText(line).width);
|
||||
height += lineHeight;
|
||||
}
|
||||
const callNode = label.startsWith("CALL\n");
|
||||
const callNode = label.startsWith("CALL\n") || label.startsWith("FUNCTION\n");
|
||||
if (callNode) callCount++;
|
||||
g.setNode(k, {...rectDims(width, height), label, ref, id:k, color, tag, callNode});
|
||||
// add edges
|
||||
|
||||
@@ -50,7 +50,8 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040",
|
||||
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
|
||||
Ops.LINEAR: "#7DF4FF",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
|
||||
|
||||
@@ -136,7 +137,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
||||
label += f"\n({multirange_str(rngs, color=True)})"
|
||||
if u._shape is not None:
|
||||
label += f"\n{shape_to_str(u.shape)}"
|
||||
if u.op is Ops.CALL:
|
||||
if u.op in {Ops.CALL, Ops.FUNCTION}:
|
||||
label += f"\n{u.src[0].key.hex()[:8]}"
|
||||
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
|
||||
if len(u.toposort()) < 30: label += f"\n{u.render()}"
|
||||
@@ -147,9 +148,9 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
||||
label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs])
|
||||
except Exception:
|
||||
label += "\n<ISSUE GETTING LABEL>"
|
||||
if (ref:=data.ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{data.ctxs[ref]['name']}"
|
||||
if (ref:=data.ref_map.get(u.src[0]) if u.op in {Ops.CALL, Ops.FUNCTION} else None) is not None: label += f"\ncodegen@{data.ctxs[ref]['name']}"
|
||||
# NOTE: kernel already has metadata in arg
|
||||
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata)
|
||||
if TRACEMETA >= 2 and u.metadata is not None and u.op not in {Ops.CALL, Ops.FUNCTION}: label += "\n"+str(u.metadata)
|
||||
# limit SOURCE labels line count
|
||||
if u.op is Ops.SOURCE and len(lines:=label.split("\n")) > 40:
|
||||
label = "\n".join(lines[:30]) + "\n..."
|
||||
|
||||
Reference in New Issue
Block a user