add call/param UOps (#14433)

* add call/param UOps

* resolve call

* skip that for now

* grad on call

* fix tests
George Hotz
2026-01-30 14:51:45 +08:00
committed by GitHub
parent 66d6a68016
commit 7a9dee4e50
8 changed files with 91 additions and 2 deletions

test/unit/test_call.py (new file, 54 lines)

@@ -0,0 +1,54 @@
import unittest
import numpy as np
from tinygrad import Tensor
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp
class TestCall(unittest.TestCase):
def test_call_plus(self):
a = Tensor.randn(10, 10)
b = Tensor.randn(10, 10)
Tensor.realize(a,b)
# we define a plus function
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
c = Tensor.call(a, b, fxn=plus_fxn)
np.testing.assert_equal(c.numpy(), (a+b).numpy())
def test_call_plus_backward(self):
a = Tensor.ones(10, 10, requires_grad=True)
b = Tensor.ones(10, 10, requires_grad=True)
(a+b).mean().backward()
gt_a_grad = a.grad.numpy()
gt_b_grad = b.grad.numpy()
a.grad, b.grad = None, None
# this is the gradient for +
def grad_fxn(grad:UOp, call:UOp): return (grad, grad)
# we define a plus function
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
c = Tensor.call(a, b, fxn=plus_fxn, arg=grad_fxn)
c.mean().backward()
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)
@unittest.skip("needs GEMM on mixins")
def test_call_gemm(self):
M, K, N = 4, 8, 4
a = Tensor.randn(M, K)
b = Tensor.randn(K, N)
Tensor.realize(a, b)
# we define a gemm function
x = UOp.param(0, dtypes.float, shape=(M, K))
y = UOp.param(1, dtypes.float, shape=(K, N))
c = Tensor.call(a, b, fxn=x@y)
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
if __name__ == '__main__':
unittest.main()

@@ -44,6 +44,8 @@ pm_gradient = PatternMatcher([
# NOTE: this is only correct when the KERNEL has a single output
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
(UPat(Ops.CUSTOM_KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
# gradient on CALL is a custom function
(UPat(Ops.CALL, name="k"), lambda ctx, k: (None,)+k.arg(ctx, k)),
# there's no gradient for bitcast
(UPat(Ops.BITCAST), lambda: (None,)),
])

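A hedged illustration of the convention above (not part of this diff): the gradient tuple has one slot per src, and src[0] is the function itself, hence the prepended None. A grad_fxn for a hypothetical elementwise multiply might look like this, assuming the CALL node still carries its argument UOps in src[1:] at backward time:

from tinygrad import Tensor
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp

# hypothetical gradient for z = a*b: dz/da = grad*b, dz/db = grad*a
def mul_grad(grad:UOp, call:UOp):
  a, b = call.src[1], call.src[2]  # call.src is (fxn, a, b)
  return (grad * b, grad * a)

mul_fxn = UOp.param(0, dtypes.float, (4,4)) * UOp.param(1, dtypes.float, (4,4))
x = Tensor.randn(4, 4, requires_grad=True)
y = Tensor.randn(4, 4, requires_grad=True)
Tensor.call(x, y, fxn=mul_fxn, arg=mul_grad).sum().backward()
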
@@ -68,10 +68,23 @@ def resolve_custom_kernel(ck:UOp) -> UOp:
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
def resolve_call(c:UOp) -> UOp:
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
args = c.src[1:]
if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}")
for i, (p, a) in enumerate(zip(params, args)):
if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
return c.src[0].substitute(dict(zip(params, args)))
earliest_rewrites = mop_cleanup+PatternMatcher([
# just removing it works...
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# resolve calls
(UPat(Ops.CALL, name="c"), resolve_call),
# resolve custom kernels
(UPat(Ops.CUSTOM_KERNEL, name="ck"), resolve_custom_kernel),

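Conceptually, resolve_call above is plain graph substitution: collect the PARAM leaves of the function body, validate count, shape, and dtype against the arguments, then splice the argument UOps in. A minimal sketch of just the substitution step (names are illustrative):

from tinygrad import Tensor
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops

fxn = UOp.param(0, dtypes.float, (2,2)) + UOp.param(1, dtypes.float, (2,2))
a, b = Tensor.ones(2, 2), Tensor.ones(2, 2)
# PARAM -> argument UOp, exactly what resolve_call does after its checks
params = sorted([x for x in fxn.toposort() if x.op is Ops.PARAM], key=lambda x: x.arg)
inlined = fxn.substitute(dict(zip(params, (a.uop, b.uop))))

Note that these checks fire when earliest_rewrites runs, so a shape or dtype mismatch raises TypeError at scheduling time, not at the Tensor.call site.
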
@@ -232,6 +232,9 @@ class Tensor(OpMixin):
# ***** data handlers ****
def call(self, *lst:Tensor, fxn:UOp, arg:Any=None) -> Tensor:
return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn, arg=arg))
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
"""
Call into a custom kernel written in UOps. Returns the Tensors after the Kernel has been applied.

@@ -26,6 +26,7 @@ class Ops(FastEnum):
# uops that aren't rendered
NOOP = auto(); REWRITE_ERROR = auto()
PARAM = auto(); CALL = auto()
# renderer
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled

@@ -220,9 +220,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.ENCDEC: return self.arg[0]
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
case Ops.PARAM:
# NOTE: copied from marg
if len(self.src) == 1: return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
return None
# passthrough ops
- case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
+ case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END | Ops.CALL:
return self.src[0]._shape
# ops with custom handling
@@ -820,6 +824,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
return self.src[0].after(self.store(val).end(*argfix(end)))
# TODO: this should replace placeholder
@staticmethod
def param(slot:int, dtype:DType, shape:tuple[int, ...]|None=None):
src = () if shape is None else (UOp.const(dtypes.index.vec(len(shape)), shape),)
return UOp(Ops.PARAM, dtype, src, arg=slot)
def call(*srcs:UOp, fxn:UOp, arg:Any|None) -> UOp: return UOp(Ops.CALL, fxn.dtype, (fxn,)+srcs, arg)
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))

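A small sketch of what UOp.param builds, assuming only the constructor above: the shape rides along as a single CONST of dtype index.vec(len(shape)), which the Ops.PARAM case in _shape unpacks with sgep.

from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops

p = UOp.param(0, dtypes.float, (10, 10))
assert p.op is Ops.PARAM and p.arg == 0  # arg is the slot number
# p.src[0] is a CONST with dtype dtypes.index.vec(2) carrying (10, 10)
q = UOp.param(1, dtypes.float)  # shapeless: src is empty, so _shape is None
assert q.src == ()
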
@@ -137,7 +137,11 @@ _tensor_spec = PatternMatcher([
# Tensor range bind / store
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat(Ops.RANGE)), arg=None), lambda: True),
- (UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True)
+ (UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True),
# allow CALL/PARAM
(UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
(UPat(Ops.PARAM), lambda: True),
])+movement_ops+shared_spec
tensor_spec = PatternMatcher([

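The dtype rule the spec checks here (c.dtype == f.dtype) holds by construction, since UOp.call creates the CALL with fxn.dtype; a quick hedged check:

from tinygrad import Tensor
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp

f = UOp.param(0, dtypes.float, (2,)) + 1           # function body, float dtype
c = UOp.call(Tensor.ones(2).uop, fxn=f, arg=None)
assert c.dtype == f.dtype                          # the invariant _tensor_spec enforces
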
@@ -48,6 +48,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.CUSTOM_KERNEL: "#3ebf55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}