support grads on tuples (#15287)
* support grads on tuples
* simpler
* grad_fxn works
* cleanups
* unused
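From user code, the change means a @function-decorated function can return a tuple of Tensors and still take part in backward. A minimal sketch of that usage, mirroring the new tests in the diff below and assuming Tensor and function are imported the same way as in the test file:

x = Tensor.ones(3, requires_grad=True).contiguous()
y = Tensor.ones(3, requires_grad=True).contiguous()

@function
def f(u1:Tensor, u2:Tensor): return (u1+1, u2+2)

t1, t2 = f(x, y)            # each tuple element comes back as its own Tensor
(t1 + t2).sum().backward()  # gradients now flow back through the tuple-returning call
print(x.grad.numpy(), y.grad.numpy())  # both are [1. 1. 1.]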
@@ -347,5 +347,31 @@ class TestFunctionTuple(unittest.TestCase):
     assert t2.tolist() == [3,3,3]
   def test_tuple_precompile(self): self.test_tuple(True)
 
+  def test_grad_tuple(self, precompile=False):
+    x = Tensor.ones(3, requires_grad=True).contiguous()
+    y = Tensor.ones(3, requires_grad=True).contiguous()
+    @function(precompile=precompile)
+    def f(u1:Tensor, u2:Tensor): return (u1+1, u2+2)
+    t1, t2 = f(x,y)
+    (t1+t2).sum().backward()
+    x.grad.realize(y.grad)
+  def test_grad_tuple_precompile(self): self.test_grad_tuple(True)
+
+  def test_grad_fxn_tuple(self):
+    # grad_fxn for tuple: ctx is a TUPLE UOp with one element per output
+    def grad_fxn(ctx:UOp, call:UOp):
+      # f(u1, u2) = (u1+1, u2+2), ctx.src = (d_out0, d_out1)
+      # df/du1 = d_out0, df/du2 = d_out1
+      return (ctx.src[0], ctx.src[1])
+
+    x = Tensor.ones(3, requires_grad=True).contiguous()
+    y = Tensor.ones(3, requires_grad=True).contiguous()
+    @function(grad_fxn=grad_fxn)
+    def f(u1:Tensor, u2:Tensor): return (u1+1, u2+2)
+    t1, t2 = f(x, y)
+    (t1+t2).sum().backward()
+    np.testing.assert_allclose(x.grad.numpy(), [1., 1., 1.])
+    np.testing.assert_allclose(y.grad.numpy(), [1., 1., 1.])
+
 if __name__ == '__main__':
   unittest.main()
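The expected gradients in these tests follow from the chain rule: with t1 = x+1 and t2 = y+2, the derivative of (t1+t2).sum() with respect to each element of x and y is 1. A quick numerical cross-check of that expectation with plain numpy finite differences, independent of tinygrad:

import numpy as np

def loss(x, y):  # the scalar loss used in the tests: sum((x+1) + (y+2))
  return ((x + 1) + (y + 2)).sum()

x, y, eps = np.ones(3), np.ones(3), 1e-6
gx = np.array([(loss(x + eps*np.eye(3)[i], y) - loss(x, y)) / eps for i in range(3)])
gy = np.array([(loss(x, y + eps*np.eye(3)[i]) - loss(x, y)) / eps for i in range(3)])
np.testing.assert_allclose(gx, [1., 1., 1.], rtol=1e-4)
np.testing.assert_allclose(gy, [1., 1., 1.], rtol=1e-4)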
@@ -20,10 +20,11 @@ pm_ctx = PatternMatcher([
 
 ReturnType = TypeVar('ReturnType')
 class _function(Generic[ReturnType]):
-  def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True):
+  def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None):
     self.fxn = fxn
     self.precompile = precompile
     self.allow_implicit = allow_implicit
+    self.grad_fxn = grad_fxn
 
   def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self
@@ -69,7 +70,7 @@ class _function(Generic[ReturnType]):
     #call = assigned.call(*call_uops, buffer, name=name)
     #ret = buffer.after(call)
 
-    fret = uret.call(*call_uops, name=name, precompile=self.precompile)
+    fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile)
     if isinstance(ret, tuple):
       return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret))))
     else:
@@ -77,9 +78,11 @@ class _function(Generic[ReturnType]):
 
 # overload signatures support both @function and @function(precompile=True) syntax
 @overload
-def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True) -> _function[ReturnType]: ...
+def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True,
+             grad_fxn:Callable|None=None) -> _function[ReturnType]: ...
 @overload
-def function(fxn:None=None, *, precompile:bool=False, allow_implicit:bool=True) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
-def function(fxn=None, *, precompile:bool=False, allow_implicit:bool=True):
-  if fxn is None: return lambda f: _function(f, precompile=precompile, allow_implicit=allow_implicit)
-  return _function(fxn, precompile=precompile, allow_implicit=allow_implicit)
+def function(fxn:None=None, *, precompile:bool=False, allow_implicit:bool=True,
+             grad_fxn:Callable|None=None) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
+def function(fxn=None, *, precompile:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None):
+  if fxn is None: return lambda f: _function(f, precompile=precompile, allow_implicit=allow_implicit, grad_fxn=grad_fxn)
+  return _function(fxn, precompile=precompile, allow_implicit=allow_implicit, grad_fxn=grad_fxn)
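The @overload pair plus the fxn-is-None check is the usual way to let one decorator work both bare (@function) and parameterized (@function(grad_fxn=...)). A standalone sketch of the same pattern; the names here (wrap, _wrapped) are illustrative only, not tinygrad API:

from typing import Callable

class _wrapped:
  def __init__(self, fxn:Callable, *, grad_fxn:Callable|None=None):
    self.fxn, self.grad_fxn = fxn, grad_fxn
  def __call__(self, *args): return self.fxn(*args)

def wrap(fxn=None, *, grad_fxn=None):
  # @wrap(grad_fxn=...) calls wrap with fxn=None, so return a decorator
  if fxn is None: return lambda f: _wrapped(f, grad_fxn=grad_fxn)
  # bare @wrap lands here with the function already in fxn
  return _wrapped(fxn, grad_fxn=grad_fxn)

@wrap
def a(x): return x + 1

@wrap(grad_fxn=lambda ctx, call: (ctx,))
def b(x): return x + 2

assert a(1) == 2 and b(1) == 3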
@@ -15,16 +15,21 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
 
 def call_gradient(ctx:UOp, k:UOp) -> tuple[UOp|None, ...]:
+  if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k)
   # auto-differentiate the function
   fxn, args = k.src[0], k.src[1:]
   params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
-  grads = compute_gradient(fxn, ctx.param_like(len(args)), set(params.values()))
+  if fxn.op is Ops.TUPLE:
+    grad_args = ctx.src
+    root_grad = UOp(Ops.TUPLE, src=tuple(g.param_like(len(args) + i) for i, g in enumerate(grad_args)))
+  else:
+    grad_args = (ctx,)
+    root_grad = ctx.param_like(len(args))
+  grads = compute_gradient(fxn, root_grad, set(params.values()))
   ret: list[UOp|None] = [None]
   for i in range(len(args)):
     if (p:=params.get(i, None)) is not None and p in grads:
       # TODO: compact the args and remove unused ones
       assert not grads[p].op_in_backward_slice_with_self(Ops.BUFFER), "BUG: BUFFER in backward slice of grad"
-      ret.append(grads[p].call(*args, ctx, name=(k.arg.name or "")+f"_backward_{i}", precompile=k.arg.precompile_backward))
+      ret.append(grads[p].call(*args, *grad_args, name=(k.arg.name or "")+f"_backward_{i}", precompile=k.arg.precompile_backward))
     else:
       ret.append(None)
   return tuple(ret)
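The net effect of grad_args is a new calling convention for the generated backward functions: with a tuple-returning forward f(*args) -> (o0, o1, ...), the backward for each input now receives the original args followed by one upstream gradient per output, instead of a single ctx. A rough plain-Python analogue of that data flow (not UOp code):

# forward: f(x, y) -> (o0, o1) with o0 = x + 1, o1 = y + 2
def backward_for_x(x, y, d_o0, d_o1): return d_o0  # do0/dx = 1, and o1 does not depend on x
def backward_for_y(x, y, d_o0, d_o1): return d_o1  # do1/dy = 1, and o0 does not depend on y

args = (1.0, 2.0)       # the original call arguments
grad_args = (1.0, 1.0)  # one upstream gradient per tuple output
assert backward_for_x(*args, *grad_args) == 1.0
assert backward_for_y(*args, *grad_args) == 1.0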
@@ -56,6 +61,7 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
   (UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),
   (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
+  (UPat(Ops.TUPLE), lambda ctx: ctx.src),
   # NOTE: this is only correct when the KERNEL has a single output
   (UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
   # gradient on CALL: use provided grad_fxn or auto-differentiate
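The new Ops.TUPLE entry simply fans the incoming TUPLE gradient out to the tuple's sources, one slot per element. A plain-Python analogue of that rule:

def tuple_gradient(ctx):  # ctx plays the role of the upstream gradient, one entry per element
  return tuple(ctx)       # mirrors (UPat(Ops.TUPLE), lambda ctx: ctx.src)

assert tuple_gradient((1.0, 2.0)) == (1.0, 2.0)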
@@ -72,9 +78,17 @@ def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
   return list(root.toposort(lambda node: node.op is not Ops.DETACH and in_target_path[node]))
 
 def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
-  grads = {root: root_grad}
+  grads: dict[UOp, UOp] = {root: root_grad}
   for t0 in reversed(_deepwalk(root, targets)):
     if t0 not in grads: continue
+    # GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL
+    if t0.op is Ops.GETTUPLE:
+      k = t0.src[0]  # the CALL
+      assert k.op is Ops.CALL and k.src[0].op is Ops.TUPLE
+      n_outputs = len(k.src[0].src)
+      prev: tuple[UOp, ...] = grads[k].src if k in grads else tuple(grads[t0].const_like(0) for _ in range(n_outputs))
+      grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs)))
+      continue
     lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
     if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
     assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
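The GETTUPLE branch gathers the gradient of each extracted output into one TUPLE-shaped gradient keyed by the CALL, adding into the slot for that output index so repeated uses accumulate. A small dict-based sketch of the same bookkeeping in plain Python (not UOp code):

def accumulate_tuple_grad(grads:dict, call, index:int, g:float, n_outputs:int):
  # grads maps a node to its accumulated gradient; for a tuple-returning call,
  # that gradient is itself a tuple with one slot per output
  prev = grads.get(call, (0.0,) * n_outputs)
  grads[call] = tuple(prev[i] + g if i == index else prev[i] for i in range(n_outputs))

grads = {}
accumulate_tuple_grad(grads, "call0", 0, 1.0, 2)  # gradient reaching output 0
accumulate_tuple_grad(grads, "call0", 1, 3.0, 2)  # gradient reaching output 1
accumulate_tuple_grad(grads, "call0", 0, 2.0, 2)  # output 0 reached again: accumulate
assert grads["call0"] == (3.0, 3.0)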
@@ -87,8 +101,4 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
       # we add the backward metadata to everything new in the graph
       for bw_uop in v.toposort(lambda x: x not in (t0, *t0.src, grads[t0])):
         all_metadata[bw_uop] = all_metadata.get(bw_uop, ())+backward_metadata
-  # end any ranges on grads with a reduce sum
-  for k,v in grads.items():
-    if len(v.ranges):
-      grads[k] = v.reduce(*v.ranges, arg=Ops.ADD)
   return grads