From 476276f4b4fae1d106ecdd22f434b76f4e1d38f0 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 16 Mar 2026 17:39:34 +0800
Subject: [PATCH] support grads on tuples (#15287)

* support grads on tuples

* simpler

* grad_fxn works

* cleanups

* unused
---
 test/unit/test_function.py | 26 ++++++++++++++++++++++++++
 tinygrad/function.py       | 17 ++++++++++-------
 tinygrad/gradient.py       | 26 ++++++++++++++++++--------
 3 files changed, 54 insertions(+), 15 deletions(-)

diff --git a/test/unit/test_function.py b/test/unit/test_function.py
index df540852d1..5f0b4b4c1f 100644
--- a/test/unit/test_function.py
+++ b/test/unit/test_function.py
@@ -347,5 +347,31 @@ class TestFunctionTuple(unittest.TestCase):
     assert t2.tolist() == [3,3,3]
   def test_tuple_precompile(self): self.test_tuple(True)
 
+  def test_grad_tuple(self, precompile=False):
+    x = Tensor.ones(3, requires_grad=True).contiguous()
+    y = Tensor.ones(3, requires_grad=True).contiguous()
+    @function(precompile=precompile)
+    def f(u1:Tensor, u2:Tensor): return (u1+1, u2+2)
+    t1, t2 = f(x,y)
+    (t1+t2).sum().backward()
+    x.grad.realize(y.grad)
+  def test_grad_tuple_precompile(self): self.test_grad_tuple(True)
+
+  def test_grad_fxn_tuple(self):
+    # grad_fxn for tuple: ctx is a TUPLE UOp with one element per output
+    def grad_fxn(ctx:UOp, call:UOp):
+      # f(u1, u2) = (u1+1, u2+2), ctx.src = (d_out0, d_out1)
+      # df/du1 = d_out0, df/du2 = d_out1
+      return (ctx.src[0], ctx.src[1])
+
+    x = Tensor.ones(3, requires_grad=True).contiguous()
+    y = Tensor.ones(3, requires_grad=True).contiguous()
+    @function(grad_fxn=grad_fxn)
+    def f(u1:Tensor, u2:Tensor): return (u1+1, u2+2)
+    t1, t2 = f(x, y)
+    (t1+t2).sum().backward()
+    np.testing.assert_allclose(x.grad.numpy(), [1., 1., 1.])
+    np.testing.assert_allclose(y.grad.numpy(), [1., 1., 1.])
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/function.py b/tinygrad/function.py
index 936a445933..79d9743288 100644
--- a/tinygrad/function.py
+++ b/tinygrad/function.py
@@ -20,10 +20,11 @@ pm_ctx = PatternMatcher([
 ReturnType = TypeVar('ReturnType')
 
 class _function(Generic[ReturnType]):
-  def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True):
+  def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None):
     self.fxn = fxn
     self.precompile = precompile
     self.allow_implicit = allow_implicit
+    self.grad_fxn = grad_fxn
 
   def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self
 
@@ -69,7 +70,7 @@ class _function(Generic[ReturnType]):
     #call = assigned.call(*call_uops, buffer, name=name)
     #ret = buffer.after(call)
 
-    fret = uret.call(*call_uops, name=name, precompile=self.precompile)
+    fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile)
     if isinstance(ret, tuple):
       return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret))))
     else:
@@ -77,9 +78,11 @@
 
 # overload signatures support both @function and @function(precompile=True) syntax
 @overload
-def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True) -> _function[ReturnType]: ...
+def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True,
+             grad_fxn:Callable|None=None) -> _function[ReturnType]: ...
 @overload
-def function(fxn:None=None, *, precompile:bool=False, allow_implicit:bool=True) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
+def function(fxn:None=None, *, precompile:bool=False, allow_implicit:bool=True,
+             grad_fxn:Callable|None=None) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
-def function(fxn=None, *, precompile:bool=False, allow_implicit:bool=True):
-  if fxn is None: return lambda f: _function(f, precompile=precompile, allow_implicit=allow_implicit)
-  return _function(fxn, precompile=precompile, allow_implicit=allow_implicit)
+def function(fxn=None, *, precompile:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None):
+  if fxn is None: return lambda f: _function(f, precompile=precompile, allow_implicit=allow_implicit, grad_fxn=grad_fxn)
+  return _function(fxn, precompile=precompile, allow_implicit=allow_implicit, grad_fxn=grad_fxn)
diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py
index af975bd81f..6772b391fa 100644
--- a/tinygrad/gradient.py
+++ b/tinygrad/gradient.py
@@ -15,16 +15,21 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
 
 def call_gradient(ctx:UOp, k:UOp) -> tuple[UOp|None, ...]:
   if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k)
-  # auto-differentiate the function
   fxn, args = k.src[0], k.src[1:]
   params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
-  grads = compute_gradient(fxn, ctx.param_like(len(args)), set(params.values()))
+  if fxn.op is Ops.TUPLE:
+    grad_args = ctx.src
+    root_grad = UOp(Ops.TUPLE, src=tuple(g.param_like(len(args) + i) for i, g in enumerate(grad_args)))
+  else:
+    grad_args = (ctx,)
+    root_grad = ctx.param_like(len(args))
+  grads = compute_gradient(fxn, root_grad, set(params.values()))
   ret: list[UOp|None] = [None]
   for i in range(len(args)):
     if (p:=params.get(i, None)) is not None and p in grads:
       # TODO: compact the args and remove unused ones
       assert not grads[p].op_in_backward_slice_with_self(Ops.BUFFER), "BUG: BUFFER in backward slice of grad"
-      ret.append(grads[p].call(*args, ctx, name=(k.arg.name or "")+f"_backward_{i}", precompile=k.arg.precompile_backward))
+      ret.append(grads[p].call(*args, *grad_args, name=(k.arg.name or "")+f"_backward_{i}", precompile=k.arg.precompile_backward))
     else:
       ret.append(None)
   return tuple(ret)
@@ -56,6 +61,7 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
   (UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),
   (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
+  (UPat(Ops.TUPLE), lambda ctx: ctx.src),
   # NOTE: this is only correct when the KERNEL has a single output
   (UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
   # gradient on CALL: use provided grad_fxn or auto-differentiate
@@ -72,9 +78,17 @@ def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
   return list(root.toposort(lambda node: node.op is not Ops.DETACH and in_target_path[node]))
 
 def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
-  grads = {root: root_grad}
+  grads: dict[UOp, UOp] = {root: root_grad}
   for t0 in reversed(_deepwalk(root, targets)):
     if t0 not in grads: continue
+    # GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL
+    if t0.op is Ops.GETTUPLE:
+      k = t0.src[0] # the CALL
+      assert k.op is Ops.CALL and k.src[0].op is Ops.TUPLE
+      n_outputs = len(k.src[0].src)
+      prev: tuple[UOp, ...] = grads[k].src if k in grads else tuple(grads[t0].const_like(0) for _ in range(n_outputs))
+      grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs)))
+      continue
     lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
     if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
     assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
@@ -87,8 +101,4 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
       # we add the backward metadata to everything new in the graph
       for bw_uop in v.toposort(lambda x: x not in (t0, *t0.src, grads[t0])): all_metadata[bw_uop] = all_metadata.get(bw_uop, ())+backward_metadata
 
-  # end any ranges on grads with a reduce sum
-  for k,v in grads.items():
-    if len(v.ranges):
-      grads[k] = v.reduce(*v.ranges, arg=Ops.ADD)
   return grads
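
Usage sketch of the grad_fxn hook added here, condensed from the tests in this patch: when the wrapped function returns a tuple, the ctx handed to grad_fxn is a TUPLE UOp whose i-th source is the upstream gradient of output i, and grad_fxn returns one gradient UOp per argument of the wrapped function. The *2/*3 scaling below is illustrative (not from the patch), chosen only to make the custom gradient visible; import paths may differ by tinygrad version.

  from tinygrad import Tensor
  from tinygrad.uop.ops import UOp   # UOp import path is an assumption, match the test file's import
  from tinygrad.function import function

  # custom backward for f(u1, u2) = (u1*2, u2*3): ctx.src = (d_out0, d_out1)
  def grad_fxn(ctx:UOp, call:UOp):
    return (ctx.src[0]*2, ctx.src[1]*3)  # df/du1 = 2*d_out0, df/du2 = 3*d_out1

  @function(grad_fxn=grad_fxn)
  def f(u1:Tensor, u2:Tensor): return (u1*2, u2*3)

  x = Tensor.ones(3, requires_grad=True).contiguous()
  y = Tensor.ones(3, requires_grad=True).contiguous()
  t1, t2 = f(x, y)
  (t1+t2).sum().backward()
  print(x.grad.tolist(), y.grad.tolist())  # expected: [2.0, 2.0, 2.0] [3.0, 3.0, 3.0]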