support grads on tuples (#15287)

* support grads on tuples

* simpler

* grad_fxn works

* cleanups

* unused
This commit is contained in:
George Hotz
2026-03-16 17:39:34 +08:00
committed by GitHub
parent 20799df10b
commit 476276f4b4
3 changed files with 54 additions and 15 deletions

View File

@@ -347,5 +347,31 @@ class TestFunctionTuple(unittest.TestCase):
assert t2.tolist() == [3,3,3]
def test_tuple_precompile(self):
  """Re-run the tuple test with the precompile flag switched on."""
  self.test_tuple(True)
def test_grad_tuple(self, precompile=False):
  """Backward through a tuple-returning @function using auto-differentiation.

  The loss is sum((u1+1) + (u2+2)), so the gradient w.r.t. each input is all ones.
  `precompile` is forwarded to the @function decorator.
  """
  x = Tensor.ones(3, requires_grad=True).contiguous()
  y = Tensor.ones(3, requires_grad=True).contiguous()
  @function(precompile=precompile)
  def f(u1:Tensor, u2:Tensor): return (u1+1, u2+2)
  t1, t2 = f(x,y)
  (t1+t2).sum().backward()
  # both inputs must have received a gradient before they can be realized together
  assert x.grad is not None and y.grad is not None
  x.grad.realize(y.grad)
  # pin the values as well, not only that realize runs without raising
  np.testing.assert_allclose(x.grad.numpy(), [1., 1., 1.])
  np.testing.assert_allclose(y.grad.numpy(), [1., 1., 1.])
def test_grad_tuple_precompile(self):
  """Re-run the tuple gradient test with the precompile flag switched on."""
  self.test_grad_tuple(True)
def test_grad_fxn_tuple(self):
  """A custom grad_fxn on a tuple-returning @function gets a TUPLE ctx, one element per output."""
  def passthrough(ctx:UOp, call:UOp):
    # ctx.src = (d_out0, d_out1): the upstream gradient of each output of f
    # for f(u1, u2) = (u1+1, u2+2) we have df/du1 = d_out0 and df/du2 = d_out1
    d_out0, d_out1 = ctx.src[0], ctx.src[1]
    return (d_out0, d_out1)
  a = Tensor.ones(3, requires_grad=True).contiguous()
  b = Tensor.ones(3, requires_grad=True).contiguous()
  @function(grad_fxn=passthrough)
  def f(u1:Tensor, u2:Tensor): return (u1+1, u2+2)
  out0, out1 = f(a, b)
  loss = (out0+out1).sum()
  loss.backward()
  # each input feeds exactly one output with unit slope, so both grads are all ones
  expected = [1., 1., 1.]
  for inp in (a, b):
    np.testing.assert_allclose(inp.grad.numpy(), expected)
# Run the unittest test runner when this file is executed directly as a script.
if __name__ == '__main__':
unittest.main()

View File

@@ -20,10 +20,11 @@ pm_ctx = PatternMatcher([
ReturnType = TypeVar('ReturnType')
class _function(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True):
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None):
  """Remember the wrapped callable and the options given to the decorator.

  All four values are stored unchanged on the instance; `precompile` and `grad_fxn` are
  later handed to the CALL that is built when the wrapper is invoked.
  """
  self.fxn, self.precompile = fxn, precompile
  self.allow_implicit, self.grad_fxn = allow_implicit, grad_fxn
def __get__(self, obj, objtype=None):
  """Descriptor hook so the wrapper also works when used on methods.

  Looked up on the class (obj is None) the wrapper itself is returned; looked up on an
  instance, the instance is bound as the first argument of __call__.
  """
  if obj is None: return self
  return functools.partial(self.__call__, obj)
@@ -69,7 +70,7 @@ class _function(Generic[ReturnType]):
#call = assigned.call(*call_uops, buffer, name=name)
#ret = buffer.after(call)
fret = uret.call(*call_uops, name=name, precompile=self.precompile)
fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile)
if isinstance(ret, tuple):
return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret))))
else:
@@ -77,9 +78,11 @@ class _function(Generic[ReturnType]):
# overload signatures support both @function and @function(precompile=True) syntax
@overload
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True) -> _function[ReturnType]: ...
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True,
grad_fxn:Callable|None=None) -> _function[ReturnType]: ...
@overload
def function(fxn:None=None, *, precompile:bool=False, allow_implicit:bool=True) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
def function(fxn=None, *, precompile:bool=False, allow_implicit:bool=True):
if fxn is None: return lambda f: _function(f, precompile=precompile, allow_implicit=allow_implicit)
return _function(fxn, precompile=precompile, allow_implicit=allow_implicit)
def function(fxn:None=None, *, precompile:bool=False, allow_implicit:bool=True,
grad_fxn:Callable|None=None) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
def function(fxn=None, *, precompile:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None):
  """Wrap `fxn` in a _function, usable both as `@function` and as `@function(option=...)`.

  Called without `fxn` it returns a decorator that applies the given options;
  called with `fxn` it wraps it immediately.
  """
  def wrap(f): return _function(f, precompile=precompile, allow_implicit=allow_implicit, grad_fxn=grad_fxn)
  return wrap if fxn is None else wrap(fxn)

View File

@@ -15,16 +15,21 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
def call_gradient(ctx:UOp, k:UOp) -> tuple[UOp|None, ...]:
  """Compute the gradients for the sources of the CALL `k`.

  ctx is the upstream gradient of the call's result. When the called function's root is a
  TUPLE, ctx.src is read as one upstream gradient per output.

  Returns one entry per source of `k`: entry 0 (the function body) is always None, followed
  by one entry per argument, which is None when no gradient reaches that argument.
  """
  # a user supplied grad_fxn wins; tuple() lets it return any sequence of per-argument grads
  if k.arg.grad_fxn is not None: return (None,) + tuple(k.arg.grad_fxn(ctx, k))
  # auto-differentiate the function
  fxn, args = k.src[0], k.src[1:]
  params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
  if fxn.op is Ops.TUPLE:
    # one incoming gradient per output; each gets its own slot after the original args
    # (slot len(args)+i matches the position of grad_args[i] in the backward call below)
    grad_args = ctx.src
    root_grad = UOp(Ops.TUPLE, src=tuple(g.param_like(len(args) + i) for i, g in enumerate(grad_args)))
  else:
    # single output: the incoming gradient is the one extra argument, in slot len(args)
    grad_args = (ctx,)
    root_grad = ctx.param_like(len(args))
  grads = compute_gradient(fxn, root_grad, set(params.values()))
  ret: list[UOp|None] = [None]
  for i in range(len(args)):
    if (p:=params.get(i, None)) is not None and p in grads:
      # TODO: compact the args and remove unused ones
      assert not grads[p].op_in_backward_slice_with_self(Ops.BUFFER), "BUG: BUFFER in backward slice of grad"
      # exactly one backward CALL per argument, fed the original args plus the incoming gradient(s)
      ret.append(grads[p].call(*args, *grad_args, name=(k.arg.name or "")+f"_backward_{i}", precompile=k.arg.precompile_backward))
    else:
      ret.append(None)
  return tuple(ret)
@@ -56,6 +61,7 @@ pm_gradient = PatternMatcher([
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
(UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
(UPat(Ops.TUPLE), lambda ctx: ctx.src),
# NOTE: this is only correct when the KERNEL has a single output
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
# gradient on CALL: use provided grad_fxn or auto-differentiate
@@ -72,9 +78,17 @@ def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
return list(root.toposort(lambda node: node.op is not Ops.DETACH and in_target_path[node]))
def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
grads = {root: root_grad}
grads: dict[UOp, UOp] = {root: root_grad}
for t0 in reversed(_deepwalk(root, targets)):
if t0 not in grads: continue
# GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL
if t0.op is Ops.GETTUPLE:
k = t0.src[0] # the CALL
assert k.op is Ops.CALL and k.src[0].op is Ops.TUPLE
n_outputs = len(k.src[0].src)
prev: tuple[UOp, ...] = grads[k].src if k in grads else tuple(grads[t0].const_like(0) for _ in range(n_outputs))
grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs)))
continue
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
@@ -87,8 +101,4 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
# we add the backward metadata to everything new in the graph
for bw_uop in v.toposort(lambda x: x not in (t0, *t0.src, grads[t0])):
all_metadata[bw_uop] = all_metadata.get(bw_uop, ())+backward_metadata
# end any ranges on grads with a reduce sum
for k,v in grads.items():
if len(v.ranges):
grads[k] = v.reduce(*v.ranges, arg=Ops.ADD)
return grads