From 7a9dee4e507db51400ce31b2b206e4c744ebb9bc Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Fri, 30 Jan 2026 14:51:45 +0800
Subject: [PATCH] add call/param UOps (#14433)

* add call/param UOps

* resolve call

* skip that for now

* grad on call

* fix tests
---
 test/unit/test_call.py        | 54 +++++++++++++++++++++++++++++++++++
 tinygrad/gradient.py          |  2 ++
 tinygrad/schedule/rangeify.py | 13 +++++++++
 tinygrad/tensor.py            |  3 ++
 tinygrad/uop/__init__.py      |  1 +
 tinygrad/uop/ops.py           | 13 ++++++++-
 tinygrad/uop/spec.py          |  6 +++-
 tinygrad/viz/serve.py         |  1 +
 8 files changed, 91 insertions(+), 2 deletions(-)
 create mode 100644 test/unit/test_call.py

diff --git a/test/unit/test_call.py b/test/unit/test_call.py
new file mode 100644
index 0000000000..c0321f4044
--- /dev/null
+++ b/test/unit/test_call.py
@@ -0,0 +1,54 @@
+import unittest
+import numpy as np
+from tinygrad import Tensor
+from tinygrad.dtype import dtypes
+from tinygrad.uop.ops import UOp
+
+class TestCall(unittest.TestCase):
+  def test_call_plus(self):
+    a = Tensor.randn(10, 10)
+    b = Tensor.randn(10, 10)
+    Tensor.realize(a,b)
+
+    # we define a plus function
+    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
+
+    c = Tensor.call(a, b, fxn=plus_fxn)
+    np.testing.assert_equal(c.numpy(), (a+b).numpy())
+
+  def test_call_plus_backward(self):
+    a = Tensor.ones(10, 10, requires_grad=True)
+    b = Tensor.ones(10, 10, requires_grad=True)
+
+    (a+b).mean().backward()
+    gt_a_grad = a.grad.numpy()
+    gt_b_grad = b.grad.numpy()
+    a.grad, b.grad = None, None
+
+    # this is the gradient for +
+    def grad_fxn(grad:UOp, call:UOp): return (grad, grad)
+
+    # we define a plus function
+    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
+    c = Tensor.call(a, b, fxn=plus_fxn, arg=grad_fxn)
+    c.mean().backward()
+
+    np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
+    np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)
+
+  @unittest.skip("needs GEMM on mixins")
+  def test_call_gemm(self):
+    M, K, N = 4, 8, 4
+    a = Tensor.randn(M, K)
+    b = Tensor.randn(K, N)
+    Tensor.realize(a, b)
+
+    # we define a gemm function
+    x = UOp.param(0, dtypes.float, shape=(M, K))
+    y = UOp.param(1, dtypes.float, shape=(K, N))
+    c = Tensor.call(a, b, fxn=x@y)
+
+    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py
index a1d44c6ce2..86bc99696c 100644
--- a/tinygrad/gradient.py
+++ b/tinygrad/gradient.py
@@ -44,6 +44,8 @@ pm_gradient = PatternMatcher([
   # NOTE: this is only correct when the KERNEL has a single output
   (UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
   (UPat(Ops.CUSTOM_KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
+  # gradient on CALL is a custom function
+  (UPat(Ops.CALL, name="k"), lambda ctx, k: (None,)+k.arg(ctx, k)),
   # there's no gradient for bitcast
   (UPat(Ops.BITCAST), lambda: (None,)),
 ])
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 9fc2204ae7..380370bb22 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -68,10 +68,23 @@ def resolve_custom_kernel(ck:UOp) -> UOp:
   placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
   return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
 
+def resolve_call(c:UOp) -> UOp:
+  params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
+  args = c.src[1:]
+  if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
+  if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}")
+  for i, (p, a) in enumerate(zip(params, args)):
+    if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
+    if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
+  return c.src[0].substitute(dict(zip(params, args)))
+
 earliest_rewrites = mop_cleanup+PatternMatcher([
   # just removing it works...
   (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
 
+  # resolve calls
+  (UPat(Ops.CALL, name="c"), resolve_call),
+
   # resolve custom kernels
   (UPat(Ops.CUSTOM_KERNEL, name="ck"), resolve_custom_kernel),
 
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 430a7f04a2..1e5a897a7e 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -232,6 +232,9 @@ class Tensor(OpMixin):
 
   # ***** data handlers ****
 
+  def call(self, *lst:Tensor, fxn:UOp, arg:Any=None) -> Tensor:
+    return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn, arg=arg))
+
   def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
     """
     Call into a custom kernel written in UOps. Returns the Tensors after the Kernel has been applied.
diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py
index ad6776a067..9641fd6382 100644
--- a/tinygrad/uop/__init__.py
+++ b/tinygrad/uop/__init__.py
@@ -26,6 +26,7 @@ class Ops(FastEnum):
 
   # uops that aren't rendered
   NOOP = auto(); REWRITE_ERROR = auto()
+  PARAM = auto(); CALL = auto()
 
   # renderer
   # LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 56d01b1ad6..7aa0820469 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -220,9 +220,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
       case Ops.ENCDEC: return self.arg[0]
       case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
       case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
+      case Ops.PARAM:
+        # NOTE: copied from marg
+        if len(self.src) == 1: return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
+        return None
 
       # passthrough ops
-      case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
+      case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END | Ops.CALL:
         return self.src[0]._shape
 
       # ops with custom handling
@@ -820,6 +824,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
     return self.src[0].after(self.store(val).end(*argfix(end)))
 
+  # TODO: this should replace placeholder
+  @staticmethod
+  def param(slot:int, dtype:DType, shape:tuple[int, ...]|None=None):
+    src = () if shape is None else (UOp.const(dtypes.index.vec(len(shape)), shape),)
+    return UOp(Ops.PARAM, dtype, src, arg=slot)
+
+  def call(*srcs:UOp, fxn:UOp, arg:Any|None) -> UOp: return UOp(Ops.CALL, fxn.dtype, (fxn,)+srcs, arg)
   def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
     contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
     kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index 1764ec4ef4..ce69e43718 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -137,7 +137,11 @@ _tensor_spec = PatternMatcher([
 
   # Tensor range bind / store
   (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat(Ops.RANGE)), arg=None), lambda: True),
-  (UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True)
+  (UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True),
+
+  # allow CALL/PARAM
+  (UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
+  (UPat(Ops.PARAM), lambda: True),
 ])+movement_ops+shared_spec
 
 tensor_spec = PatternMatcher([
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index ab06fd0be2..be211c981f 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -48,6 +48,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
                Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
                Ops.CUSTOM_KERNEL: "#3ebf55", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU},
                Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF",
                Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
+               Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F",
                Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
                Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
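
A minimal usage sketch of the API this patch adds, mirroring test_call_plus above (all names come from the diff; behavior assumes the tree this patch applies to). A graph over PARAM UOps serves as the function body, Tensor.call binds tensor arguments to PARAM slots by position, and resolve_call substitutes the bound sources into the graph during the earliest rangeify rewrites:

    from tinygrad import Tensor
    from tinygrad.dtype import dtypes
    from tinygrad.uop.ops import UOp

    # function body: elementwise add over two (10,10) float PARAMs (slots 0 and 1)
    f = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))

    a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
    c = Tensor.call(a, b, fxn=f)  # a fills slot 0, b fills slot 1
    print(c.numpy())              # matches (a+b).numpy()

Passing arg= a gradient function of the form grad_fxn(grad:UOp, call:UOp) -> tuple, returning one gradient per tensor source as in test_call_plus_backward, makes the call differentiable; resolve_call raises if slots are out of order or an argument's shape or dtype disagrees with its PARAM.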