mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
compute gradient [pr] (#8237)
* compute gradient [pr] * schedule_step_with_grads * second deriv works
This commit is contained in:
@@ -77,38 +77,23 @@ class TestTinygrad(unittest.TestCase):
|
||||
np.testing.assert_allclose(xgrad2, xgrad * 2., atol=1e-6)
|
||||
np.testing.assert_allclose(wgrad2, wgrad * 2., atol=1e-6)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_second_order_backward_pass(self):
|
||||
def test_pytorch():
|
||||
x = torch.tensor(x_init)
|
||||
m = torch.tensor(m_init, requires_grad=True)
|
||||
out = x.mul(m).sum()
|
||||
# use retain graph so we can compute second order derivatives later
|
||||
out.backward(retain_graph=True)
|
||||
# save first-order gradient (dO/dm). they still contain graph information on how they were constructed wrt x and W
|
||||
grad_m = m.grad
|
||||
# zero gradients so second-order gradients are correct
|
||||
m.grad = None
|
||||
# compute second-order gradients
|
||||
grad_m.sum().backward(retain_graph=True)
|
||||
|
||||
# d2O/dm2
|
||||
second_grad_m = m.grad
|
||||
return second_grad_m.numpy()
|
||||
x_val = torch.tensor([2.0], requires_grad=True)
|
||||
f = x_val**3
|
||||
first_derivative = torch.autograd.grad(outputs=f, inputs=x_val, create_graph=True)[0]
|
||||
second_derivative = torch.autograd.grad(outputs=first_derivative, inputs=x_val)[0]
|
||||
# d^2f/dx^2 = 6x = 6*2 = 12
|
||||
return second_derivative.numpy()
|
||||
|
||||
def test_tinygrad():
|
||||
x = Tensor(x_init)
|
||||
m = Tensor(m_init, requires_grad=True)
|
||||
out = x.mul(m).sum()
|
||||
out.backward()
|
||||
grad_m = m.grad
|
||||
m.grad = None
|
||||
grad_m.sum().backward()
|
||||
second_grad_m = m.grad # currently, this will be None (incorrect)
|
||||
return second_grad_m.numpy()
|
||||
x_val = Tensor(2.0)
|
||||
f = x_val**3
|
||||
first_derivative = f.gradient(x_val)[0]
|
||||
second_derivative = first_derivative.gradient(x_val)[0]
|
||||
return second_derivative.numpy()
|
||||
|
||||
for x,y in zip(test_tinygrad(), test_pytorch()):
|
||||
np.testing.assert_allclose(x, y, atol=1e-5)
|
||||
np.testing.assert_allclose(test_tinygrad(), test_pytorch(), atol=1e-5)
|
||||
|
||||
# passing `gradient` to backward
|
||||
def test_backward_pass_vjp(self):
|
||||
|
||||
@@ -5,7 +5,7 @@ import jax.numpy as jnp
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp
|
||||
from tinygrad.gradient import gradient
|
||||
from tinygrad.gradient import compute_gradient
|
||||
|
||||
class TestGradient(unittest.TestCase):
|
||||
def _cmp_nan_okay(self, x, y):
|
||||
@@ -14,7 +14,7 @@ class TestGradient(unittest.TestCase):
|
||||
|
||||
def _test_one_input_function(self, f:Callable, jf:Callable|None=None):
|
||||
x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
|
||||
gx = gradient(f(x), [x])[0]
|
||||
gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), [x])[x]
|
||||
gf = jax.grad(f if jf is None else jf)
|
||||
|
||||
for val in [-5., -2.0, 0.0, 2.0, 5.]:
|
||||
@@ -24,7 +24,8 @@ class TestGradient(unittest.TestCase):
|
||||
def _test_two_input_function(self, f:Callable, jf:Callable|None=None):
|
||||
x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
|
||||
y = UOp.variable('y', -math.inf, math.inf, dtype=dtypes.float)
|
||||
gx, gy = gradient(f(x, y), [x, y])
|
||||
grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), [x, y])
|
||||
gx, gy = grads[x], grads[y]
|
||||
gf = jax.grad(f if jf is None else jf, argnums=(0, 1))
|
||||
|
||||
for valx in [-5., -2.0, 0.0, 2.0, 5.]:
|
||||
@@ -54,9 +55,7 @@ class TestGradient(unittest.TestCase):
|
||||
def test_chain_binop(self): self._test_two_input_function(lambda x,y: (x*y)+x*y)
|
||||
def test_big_add_sin(self): self._test_two_input_function(lambda x,y: x.sin()+3.0/y, lambda x,y: jnp.sin(x)+3.0/y)
|
||||
def test_big_chain(self): self._test_two_input_function(lambda x,y: (1.0/x*y)+x*y)
|
||||
|
||||
# TODO: this isn't working
|
||||
#def test_where(self): self._test_two_input_function(lambda x,y: (x<y).where(x,y), lambda x,y: jnp.where(x<y, x, y))
|
||||
def test_where(self): self._test_two_input_function(lambda x,y: (x<y).where(x,y), lambda x,y: jnp.where(x<y, x, y))
|
||||
|
||||
class TestTensorGradient(unittest.TestCase):
|
||||
def test_example(self):
|
||||
@@ -72,5 +71,11 @@ class TestTensorGradient(unittest.TestCase):
|
||||
w = Tensor.randn((3,))
|
||||
with self.assertRaises(RuntimeError): x.sum().gradient(w)
|
||||
|
||||
def test_with_custom_gradient(self):
|
||||
x = Tensor([1.0, 2.0, 3.0])
|
||||
z = (x * x).sum()
|
||||
dx = z.gradient(x, gradient=Tensor([3.0]))[0]
|
||||
self.assertListEqual(dx.tolist(), [6.0, 12.0, 18.0])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -54,8 +54,8 @@ def _deepwalk(root:UOp, targets:list[UOp]):
|
||||
yield node
|
||||
return list(_walk(root, set()))
|
||||
|
||||
def gradient(root:UOp, targets:list[UOp]) -> list[UOp]:
|
||||
grads = {root: root.const_like(1.0)}
|
||||
def compute_gradient(root:UOp, root_grad:UOp, targets:list[UOp]) -> dict[UOp, UOp]:
|
||||
grads = {root: root_grad}
|
||||
for t0 in reversed(_deepwalk(root, targets)):
|
||||
if t0 not in grads: continue
|
||||
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
|
||||
@@ -65,8 +65,4 @@ def gradient(root:UOp, targets:list[UOp]) -> list[UOp]:
|
||||
if v is None: continue
|
||||
if k in grads: grads[k] = grads[k] + v
|
||||
else: grads[k] = v
|
||||
ret = [grads.get(x, None) for x in targets]
|
||||
for i,x in enumerate(ret):
|
||||
if x is None: raise RuntimeError(f"{targets[i]}\n\nnot found in\n\n{root}")
|
||||
return ret
|
||||
|
||||
return grads
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# sorted in order of increasing complexity
|
||||
from typing import List
|
||||
from tinygrad.helpers import dedup, flatten, getenv
|
||||
from tinygrad.helpers import dedup, flatten, getenv, unwrap
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.dtype import dtypes, least_upper_dtype
|
||||
|
||||
@@ -39,8 +39,8 @@ class Optimizer:
|
||||
assert Tensor.training, (
|
||||
f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
|
||||
- help: Consider setting Tensor.training=True before calling Optimizer.step().""")
|
||||
return self._step()+self.params+self.buffers
|
||||
def _step(self) -> List[Tensor]: raise NotImplementedError
|
||||
return self.schedule_step_with_grads([unwrap(t.grad) for t in self.params])+self.params+self.buffers
|
||||
def schedule_step_with_grads(self, grads:List[Tensor]) -> List[Tensor]: raise NotImplementedError
|
||||
|
||||
class OptimizerGroup(Optimizer):
|
||||
"""
|
||||
@@ -76,12 +76,11 @@ class LARS(Optimizer):
|
||||
self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
|
||||
self.b = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
|
||||
|
||||
def _step(self) -> List[Tensor]:
|
||||
for i, t in enumerate(self.params):
|
||||
assert t.grad is not None
|
||||
def schedule_step_with_grads(self, grads:List[Tensor]) -> List[Tensor]:
|
||||
for i, (t, g) in enumerate(zip(self.params, grads)):
|
||||
# contiguous is needed since the grads can allegedly form a "diamond"
|
||||
# TODO: fix this in lazy.py
|
||||
g = t.grad.contiguous()
|
||||
g = g.contiguous()
|
||||
if self.tcoef != 0:
|
||||
r1 = t.detach().square().sum().sqrt()
|
||||
r2 = g.square().sum().sqrt()
|
||||
@@ -130,13 +129,12 @@ class LAMB(Optimizer):
|
||||
self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
|
||||
self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
|
||||
|
||||
def _step(self) -> List[Tensor]:
|
||||
def schedule_step_with_grads(self, grads:List[Tensor]) -> List[Tensor]:
|
||||
self.b1_t *= self.b1
|
||||
self.b2_t *= self.b2
|
||||
for i, t in enumerate(self.params):
|
||||
assert t.grad is not None
|
||||
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
|
||||
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
|
||||
for i, (t, g) in enumerate(zip(self.params, grads)):
|
||||
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g)
|
||||
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g))
|
||||
m_hat = self.m[i] / (1.0 - self.b1_t)
|
||||
v_hat = self.v[i] / (1.0 - self.b2_t)
|
||||
up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
|
||||
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
|
||||
from tinygrad.multi import MultiLazyBuffer
|
||||
from tinygrad.gradient import gradient
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
|
||||
from tinygrad.device import Device, Buffer, BufferSpec
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
@@ -865,7 +865,7 @@ class Tensor(SimpleMathTrait):
|
||||
|
||||
# ***** toposort and backward pass *****
|
||||
|
||||
def gradient(self, *targets:Tensor) -> list[Tensor]:
|
||||
def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None) -> list[Tensor]:
|
||||
"""
|
||||
Compute the gradient of the targets with respect to self.
|
||||
|
||||
@@ -881,7 +881,13 @@ class Tensor(SimpleMathTrait):
|
||||
"""
|
||||
assert isinstance(self.lazydata, UOp), "multi isn't supported yet"
|
||||
target_uops: List[UOp] = [x.lazydata for x in targets if isinstance(x.lazydata, UOp)]
|
||||
return [Tensor(y) for y in gradient(self.lazydata, target_uops)]
|
||||
assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
|
||||
grads = compute_gradient(self.lazydata, self.lazydata.const_like(1) if gradient is None else cast(UOp, gradient.lazydata), target_uops)
|
||||
ret = []
|
||||
for x in target_uops:
|
||||
if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
|
||||
ret.append(Tensor(y, device=x.device))
|
||||
return ret
|
||||
|
||||
def _deepwalk(self):
|
||||
def _walk(node, visited):
|
||||
|
||||
Reference in New Issue
Block a user