mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add call/param UOps (#14433)
* add call/param UOps
* resolve call
* skip that for now
* grad on call
* fix tests
This commit is contained in:
54
test/unit/test_call.py
Normal file
54
test/unit/test_call.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
class TestCall(unittest.TestCase):
|
||||
def test_call_plus(self):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
Tensor.realize(a,b)
|
||||
|
||||
# we define a plus function
|
||||
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
|
||||
|
||||
c = Tensor.call(a, b, fxn=plus_fxn)
|
||||
np.testing.assert_equal(c.numpy(), (a+b).numpy())
|
||||
|
||||
def test_call_plus_backward(self):
|
||||
a = Tensor.ones(10, 10, requires_grad=True)
|
||||
b = Tensor.ones(10, 10, requires_grad=True)
|
||||
|
||||
(a+b).mean().backward()
|
||||
gt_a_grad = a.grad.numpy()
|
||||
gt_b_grad = b.grad.numpy()
|
||||
a.grad, b.grad = None, None
|
||||
|
||||
# this is the gradient for +
|
||||
def grad_fxn(grad:UOp, call:UOp): return (grad, grad)
|
||||
|
||||
# we define a plus function
|
||||
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
|
||||
c = Tensor.call(a, b, fxn=plus_fxn, arg=grad_fxn)
|
||||
c.mean().backward()
|
||||
|
||||
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
|
||||
np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)
|
||||
|
||||
@unittest.skip("needs GEMM on mixins")
|
||||
def test_call_gemm(self):
|
||||
M, K, N = 4, 8, 4
|
||||
a = Tensor.randn(M, K)
|
||||
b = Tensor.randn(K, N)
|
||||
Tensor.realize(a, b)
|
||||
|
||||
# we define a gemm function
|
||||
x = UOp.param(0, dtypes.float, shape=(M, K))
|
||||
y = UOp.param(1, dtypes.float, shape=(K, N))
|
||||
c = Tensor.call(a, b, fxn=x@y)
|
||||
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
|
||||
|
||||
# run the tests when executed directly
if __name__ == '__main__':
  unittest.main()
|
||||
@@ -44,6 +44,8 @@ pm_gradient = PatternMatcher([
|
||||
# NOTE: this is only correct when the KERNEL has a single output
|
||||
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
|
||||
(UPat(Ops.CUSTOM_KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
|
||||
# gradient on CALL is a custom function
|
||||
(UPat(Ops.CALL, name="k"), lambda ctx, k: (None,)+k.arg(ctx, k)),
|
||||
# there's no gradient for bitcast
|
||||
(UPat(Ops.BITCAST), lambda: (None,)),
|
||||
])
|
||||
|
||||
@@ -68,10 +68,23 @@ def resolve_custom_kernel(ck:UOp) -> UOp:
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
|
||||
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
|
||||
|
||||
def resolve_call(c:UOp) -> UOp:
  """Inline a CALL by substituting the caller's args for the PARAMs in the called graph.

  c.src[0] is the function graph (containing PARAM uops), c.src[1:] are the arguments.
  Raises RuntimeError if the PARAM slots aren't a contiguous 0..n-1 range, and
  TypeError on an arg count, shape, or dtype mismatch.
  """
  # PARAMs carry their slot index in arg; sort so they line up with positional args
  params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
  args = c.src[1:]
  if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
  if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}")
  for i, (p, a) in enumerate(zip(params, args)):
    if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
    if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
  # replace each PARAM with its matching arg, leaving a plain UOp graph
  return c.src[0].substitute(dict(zip(params, args)))
|
||||
|
||||
earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
# just removing it works...
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
|
||||
|
||||
# resolve calls
|
||||
(UPat(Ops.CALL, name="c"), resolve_call),
|
||||
|
||||
# resolve custom kernels
|
||||
(UPat(Ops.CUSTOM_KERNEL, name="ck"), resolve_custom_kernel),
|
||||
|
||||
|
||||
@@ -232,6 +232,9 @@ class Tensor(OpMixin):
|
||||
|
||||
# ***** data handlers ****
|
||||
|
||||
def call(self, *lst:Tensor, fxn:UOp, arg:Any=None) -> Tensor:
  """
  Call `fxn`, a function expressed as a UOp graph over `UOp.param` placeholders,
  with `self` and `lst` as the positional arguments (param slot 0 binds to `self`).
  `arg` optionally carries the custom gradient function used on backward
  (pm_gradient invokes it as `arg(grad, call_uop)`).
  """
  return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn, arg=arg))
|
||||
|
||||
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
||||
"""
|
||||
Call into a custom kernel written in UOps. Returns the Tensors after the Kernel has been applied.
|
||||
|
||||
@@ -26,6 +26,7 @@ class Ops(FastEnum):
|
||||
|
||||
# uops that aren't rendered
|
||||
NOOP = auto(); REWRITE_ERROR = auto()
|
||||
PARAM = auto(); CALL = auto()
|
||||
|
||||
# renderer
|
||||
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled
|
||||
|
||||
@@ -220,9 +220,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
case Ops.ENCDEC: return self.arg[0]
|
||||
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
|
||||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
|
||||
case Ops.PARAM:
|
||||
# NOTE: copied from marg
|
||||
if len(self.src) == 1: return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
|
||||
return None
|
||||
|
||||
# passthrough ops
|
||||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
|
||||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END | Ops.CALL:
|
||||
return self.src[0]._shape
|
||||
|
||||
# ops with custom handling
|
||||
@@ -820,6 +824,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
|
||||
return self.src[0].after(self.store(val).end(*argfix(end)))
|
||||
|
||||
# TODO: this should replace placeholder
@staticmethod
def param(slot:int, dtype:DType, shape:tuple[int, ...]|None=None) -> UOp:
  """Create a PARAM placeholder for positional slot `slot` of a callable UOp graph.

  When `shape` is given it is stored as a single index-vector CONST src, which the
  PARAM case of shape resolution reads back out; with no shape the PARAM has no srcs.
  """
  src = () if shape is None else (UOp.const(dtypes.index.vec(len(shape)), shape),)
  return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||
|
||||
def call(*srcs:UOp, fxn:UOp, arg:Any|None=None) -> UOp:
  """Build a CALL of `fxn` (a UOp graph over PARAMs) with `srcs` as the arguments.

  The result takes fxn's dtype, with fxn as src[0] and the args after it.
  `arg` optionally carries the custom gradient function; it now defaults to None
  for consistency with Tensor.call's `arg:Any=None` (backward-compatible: all
  existing callers pass arg explicitly).
  """
  return UOp(Ops.CALL, fxn.dtype, (fxn,)+srcs, arg)
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
|
||||
|
||||
@@ -137,7 +137,11 @@ _tensor_spec = PatternMatcher([
|
||||
|
||||
# Tensor range bind / store
|
||||
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat(Ops.RANGE)), arg=None), lambda: True),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True)
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True),
|
||||
|
||||
# allow CALL/PARAM
|
||||
(UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
|
||||
(UPat(Ops.PARAM), lambda: True),
|
||||
])+movement_ops+shared_spec
|
||||
|
||||
tensor_spec = PatternMatcher([
|
||||
|
||||
@@ -48,6 +48,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.CUSTOM_KERNEL: "#3ebf55",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user