compute gradient [pr] (#8237)

* compute gradient [pr]

* schedule_step_with_grads

* second deriv works
This commit is contained in:
George Hotz
2024-12-13 20:46:01 -08:00
committed by GitHub
parent 0708a169dd
commit 734f2c5344
5 changed files with 45 additions and 55 deletions

View File

@@ -77,38 +77,23 @@ class TestTinygrad(unittest.TestCase):
np.testing.assert_allclose(xgrad2, xgrad * 2., atol=1e-6)
np.testing.assert_allclose(wgrad2, wgrad * 2., atol=1e-6)
@unittest.expectedFailure
def test_second_order_backward_pass(self):
def test_pytorch():
x = torch.tensor(x_init)
m = torch.tensor(m_init, requires_grad=True)
out = x.mul(m).sum()
# use retain graph so we can compute second order derivatives later
out.backward(retain_graph=True)
# save first-order gradient (dO/dm). they still contain graph information on how they were constructed wrt x and W
grad_m = m.grad
# zero gradients so second-order gradients are correct
m.grad = None
# compute second-order gradients
grad_m.sum().backward(retain_graph=True)
# d2O/dm2
second_grad_m = m.grad
return second_grad_m.numpy()
x_val = torch.tensor([2.0], requires_grad=True)
f = x_val**3
first_derivative = torch.autograd.grad(outputs=f, inputs=x_val, create_graph=True)[0]
second_derivative = torch.autograd.grad(outputs=first_derivative, inputs=x_val)[0]
# d^2f/dx^2 = 6x = 6*2 = 12
return second_derivative.numpy()
def test_tinygrad():
x = Tensor(x_init)
m = Tensor(m_init, requires_grad=True)
out = x.mul(m).sum()
out.backward()
grad_m = m.grad
m.grad = None
grad_m.sum().backward()
second_grad_m = m.grad # currently, this will be None (incorrect)
return second_grad_m.numpy()
x_val = Tensor(2.0)
f = x_val**3
first_derivative = f.gradient(x_val)[0]
second_derivative = first_derivative.gradient(x_val)[0]
return second_derivative.numpy()
for x,y in zip(test_tinygrad(), test_pytorch()):
np.testing.assert_allclose(x, y, atol=1e-5)
np.testing.assert_allclose(test_tinygrad(), test_pytorch(), atol=1e-5)
# passing `gradient` to backward
def test_backward_pass_vjp(self):

View File

@@ -5,7 +5,7 @@ import jax.numpy as jnp
from tinygrad import Tensor
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp
from tinygrad.gradient import gradient
from tinygrad.gradient import compute_gradient
class TestGradient(unittest.TestCase):
def _cmp_nan_okay(self, x, y):
@@ -14,7 +14,7 @@ class TestGradient(unittest.TestCase):
def _test_one_input_function(self, f:Callable, jf:Callable|None=None):
x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
gx = gradient(f(x), [x])[0]
gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), [x])[x]
gf = jax.grad(f if jf is None else jf)
for val in [-5., -2.0, 0.0, 2.0, 5.]:
@@ -24,7 +24,8 @@ class TestGradient(unittest.TestCase):
def _test_two_input_function(self, f:Callable, jf:Callable|None=None):
x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
y = UOp.variable('y', -math.inf, math.inf, dtype=dtypes.float)
gx, gy = gradient(f(x, y), [x, y])
grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), [x, y])
gx, gy = grads[x], grads[y]
gf = jax.grad(f if jf is None else jf, argnums=(0, 1))
for valx in [-5., -2.0, 0.0, 2.0, 5.]:
@@ -54,9 +55,7 @@ class TestGradient(unittest.TestCase):
def test_chain_binop(self): self._test_two_input_function(lambda x,y: (x*y)+x*y)
def test_big_add_sin(self): self._test_two_input_function(lambda x,y: x.sin()+3.0/y, lambda x,y: jnp.sin(x)+3.0/y)
def test_big_chain(self): self._test_two_input_function(lambda x,y: (1.0/x*y)+x*y)
# TODO: this isn't working
#def test_where(self): self._test_two_input_function(lambda x,y: (x<y).where(x,y), lambda x,y: jnp.where(x<y, x, y))
def test_where(self): self._test_two_input_function(lambda x,y: (x<y).where(x,y), lambda x,y: jnp.where(x<y, x, y))
class TestTensorGradient(unittest.TestCase):
def test_example(self):
@@ -72,5 +71,11 @@ class TestTensorGradient(unittest.TestCase):
w = Tensor.randn((3,))
with self.assertRaises(RuntimeError): x.sum().gradient(w)
def test_with_custom_gradient(self):
x = Tensor([1.0, 2.0, 3.0])
z = (x * x).sum()
dx = z.gradient(x, gradient=Tensor([3.0]))[0]
self.assertListEqual(dx.tolist(), [6.0, 12.0, 18.0])
if __name__ == '__main__':
unittest.main()

View File

@@ -54,8 +54,8 @@ def _deepwalk(root:UOp, targets:list[UOp]):
yield node
return list(_walk(root, set()))
def gradient(root:UOp, targets:list[UOp]) -> list[UOp]:
grads = {root: root.const_like(1.0)}
def compute_gradient(root:UOp, root_grad:UOp, targets:list[UOp]) -> dict[UOp, UOp]:
grads = {root: root_grad}
for t0 in reversed(_deepwalk(root, targets)):
if t0 not in grads: continue
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
@@ -65,8 +65,4 @@ def gradient(root:UOp, targets:list[UOp]) -> list[UOp]:
if v is None: continue
if k in grads: grads[k] = grads[k] + v
else: grads[k] = v
ret = [grads.get(x, None) for x in targets]
for i,x in enumerate(ret):
if x is None: raise RuntimeError(f"{targets[i]}\n\nnot found in\n\n{root}")
return ret
return grads

View File

@@ -1,6 +1,6 @@
# sorted in order of increasing complexity
from typing import List
from tinygrad.helpers import dedup, flatten, getenv
from tinygrad.helpers import dedup, flatten, getenv, unwrap
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes, least_upper_dtype
@@ -39,8 +39,8 @@ class Optimizer:
assert Tensor.training, (
f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
- help: Consider setting Tensor.training=True before calling Optimizer.step().""")
return self._step()+self.params+self.buffers
def _step(self) -> List[Tensor]: raise NotImplementedError
return self.schedule_step_with_grads([unwrap(t.grad) for t in self.params])+self.params+self.buffers
def schedule_step_with_grads(self, grads:List[Tensor]) -> List[Tensor]: raise NotImplementedError
class OptimizerGroup(Optimizer):
"""
@@ -76,12 +76,11 @@ class LARS(Optimizer):
self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
self.b = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
def _step(self) -> List[Tensor]:
for i, t in enumerate(self.params):
assert t.grad is not None
def schedule_step_with_grads(self, grads:List[Tensor]) -> List[Tensor]:
for i, (t, g) in enumerate(zip(self.params, grads)):
# contiguous is needed since the grads can allegedly form a "diamond"
# TODO: fix this in lazy.py
g = t.grad.contiguous()
g = g.contiguous()
if self.tcoef != 0:
r1 = t.detach().square().sum().sqrt()
r2 = g.square().sum().sqrt()
@@ -130,13 +129,12 @@ class LAMB(Optimizer):
self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
def _step(self) -> List[Tensor]:
def schedule_step_with_grads(self, grads:List[Tensor]) -> List[Tensor]:
self.b1_t *= self.b1
self.b2_t *= self.b2
for i, t in enumerate(self.params):
assert t.grad is not None
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
for i, (t, g) in enumerate(zip(self.params, grads)):
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g)
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g))
m_hat = self.m[i] / (1.0 - self.b1_t)
v_hat = self.v[i] / (1.0 - self.b2_t)
up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()

View File

@@ -7,7 +7,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
from tinygrad.multi import MultiLazyBuffer
from tinygrad.gradient import gradient
from tinygrad.gradient import compute_gradient
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
from tinygrad.device import Device, Buffer, BufferSpec
from tinygrad.engine.realize import run_schedule
@@ -865,7 +865,7 @@ class Tensor(SimpleMathTrait):
# ***** toposort and backward pass *****
def gradient(self, *targets:Tensor) -> list[Tensor]:
def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None) -> list[Tensor]:
"""
Compute the gradient of the targets with respect to self.
@@ -881,7 +881,13 @@ class Tensor(SimpleMathTrait):
"""
assert isinstance(self.lazydata, UOp), "multi isn't supported yet"
target_uops: List[UOp] = [x.lazydata for x in targets if isinstance(x.lazydata, UOp)]
return [Tensor(y) for y in gradient(self.lazydata, target_uops)]
assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
grads = compute_gradient(self.lazydata, self.lazydata.const_like(1) if gradient is None else cast(UOp, gradient.lazydata), target_uops)
ret = []
for x in target_uops:
if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
ret.append(Tensor(y, device=x.device))
return ret
def _deepwalk(self):
def _walk(node, visited):