Revert "switch beautiful_mnist to use new optimizer [pr] (#8231)" (#8233)

This reverts commit e9ee39df22.
This commit is contained in:
George Hotz
2024-12-13 19:07:09 -08:00
committed by GitHub
parent e9ee39df22
commit 37fa38d272
7 changed files with 29 additions and 61 deletions

View File

@@ -26,11 +26,10 @@ l1n, l2n = l1.numpy(), l2.numpy()
from tinygrad.nn.optim import SGD
optim = SGD([l1, l2])
Tensor.training = True
X, Y = X_train[(samples:=Tensor.randint(128, high=X_train.shape[0]))], Y_train[samples]
optim.zero_grad()
model(X).sparse_categorical_crossentropy(Y).backward()
optim.schedule_step() # this will step the optimizer without running realize
optim._step() # this will step the optimizer without running realize
# *****
# 3. Create a schedule.

View File

@@ -31,8 +31,4 @@
::: tinygrad.Tensor.shard_
::: tinygrad.Tensor.contiguous
::: tinygrad.Tensor.contiguous_backward
## Gradient
::: tinygrad.Tensor.gradient
::: tinygrad.Tensor.backward

View File

@@ -26,9 +26,11 @@ if __name__ == "__main__":
@TinyJit
@Tensor.train()
def train_step() -> Tensor:
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
opt.step(loss:=model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]))
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
opt.step()
return loss
@TinyJit

View File

@@ -1,13 +1,11 @@
from typing import Callable
import unittest, math
import numpy as np
import jax
import jax.numpy as jnp
from tinygrad import Tensor
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp
from tinygrad.gradient import gradient
from tinygrad.nn.optim import SGD, Adam
class TestGradient(unittest.TestCase):
def _cmp_nan_okay(self, x, y):
@@ -62,29 +60,15 @@ class TestGradient(unittest.TestCase):
class TestTensorGradient(unittest.TestCase):
def test_example(self):
x = Tensor.eye(3)
# NOTE: this contiguous shouldn't be needed. gradient should go to base
x = Tensor.eye(3).contiguous()
y = Tensor([[2.0,0,-2.0]])
z = y.matmul(x).sum()
dx, dy = z.gradient(x, y)
print(dx.tolist())
print(dy.tolist())
self.assertListEqual(dx.tolist(), [[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -2.0, -2.0]])
self.assertListEqual(dy.tolist(), [[1.0, 1.0, 1.0]])
def test_raises(self):
x = Tensor([1.0, 2.0, 3.0])
w = Tensor.randn((3,))
with self.assertRaises(RuntimeError): x.sum().gradient(w)
def test_optim(self):
with Tensor.train():
w = Tensor([1.0, 2.0, 3.0])
SGD([w], lr=0.1).step(w.sum())
np.testing.assert_almost_equal(w.tolist(), [0.9, 1.9, 2.9])
def test_optim_rng(self):
with Tensor.train():
x = Tensor([1.0, 2.0, 3.0])
w = Tensor.randn((3,))
Adam([w], lr=0.1).step((x*w).sum())
if __name__ == '__main__':
unittest.main()

View File

@@ -36,9 +36,6 @@ pm_gradient = PatternMatcher([
# TODO: this cast can be removed by putting the casts around the EXPAND
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
(ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
# there's no gradient for...is this ASSIGN?
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.BUFFER_VIEW))), lambda: (None, None)),
])
# copied from tensor.py, get relevant toposort of gradients
@@ -59,8 +56,8 @@ def gradient(root:UOp, targets:list[UOp]) -> list[UOp]:
for t0 in reversed(_deepwalk(root, targets)):
if t0 not in grads: continue
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}")
assert len(lgrads) == len(t0.src)
for k,v in zip(t0.src, lgrads):
if v is None: continue
if k in grads: grads[k] = grads[k] + v

View File

@@ -1,6 +1,6 @@
# sorted in order of increasing complexity
from typing import List, Optional
from tinygrad.helpers import dedup, flatten, getenv, unwrap
from typing import List
from tinygrad.helpers import dedup, flatten, getenv
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes, least_upper_dtype
@@ -27,20 +27,20 @@ class Optimizer:
"""
for param in self.params: param.grad = None
def step(self, loss:Optional[Tensor]=None):
def step(self):
"""
Performs a single optimization step.
"""
Tensor.realize(*self.schedule_step(loss))
def schedule_step(self, loss:Optional[Tensor]=None) -> List[Tensor]:
Tensor.realize(*self.schedule_step())
def schedule_step(self) -> List[Tensor]:
"""
Returns the tensors that need to be realized to perform a single optimization step.
"""
assert Tensor.training, (
f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
- help: Consider setting Tensor.training=True before calling Optimizer.step().""")
return self._step([unwrap(t.grad) for t in self.params] if loss is None else loss.gradient(*self.params))+self.params+self.buffers
def _step(self, grads:List[Tensor]) -> List[Tensor]: raise NotImplementedError
return self._step()+self.params+self.buffers
def _step(self) -> List[Tensor]: raise NotImplementedError
class OptimizerGroup(Optimizer):
"""
@@ -51,7 +51,7 @@ class OptimizerGroup(Optimizer):
self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
def __getitem__(self, i): return self.optimizers[i]
def zero_grad(self): [o.zero_grad() for o in self.optimizers]
def schedule_step(self, loss:Optional[Tensor]=None) -> List[Tensor]: return [x for o in self.optimizers for x in o.schedule_step(loss)]
def _step(self) -> List[Tensor]: return [x for o in self.optimizers for x in o._step()]
# LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 its just standard SGD.
def SGD(params: List[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
@@ -76,10 +76,12 @@ class LARS(Optimizer):
self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
self.b = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
def _step(self, grads:List[Tensor]) -> List[Tensor]:
for i, (t, g) in enumerate(zip(self.params, grads)):
# contiguous is needed since the grads can allegedly form a "diamond". TODO: is this fixed?
g = g.contiguous()
def _step(self) -> List[Tensor]:
for i, t in enumerate(self.params):
assert t.grad is not None
# contiguous is needed since the grads can allegedly form a "diamond"
# TODO: fix this in lazy.py
g = t.grad.contiguous()
if self.tcoef != 0:
r1 = t.detach().square().sum().sqrt()
r2 = g.square().sum().sqrt()
@@ -128,12 +130,13 @@ class LAMB(Optimizer):
self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
def _step(self, grads:List[Tensor]) -> List[Tensor]:
def _step(self) -> List[Tensor]:
self.b1_t *= self.b1
self.b2_t *= self.b2
for i, (t, g) in enumerate(zip(self.params, grads)):
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g)
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g))
for i, t in enumerate(self.params):
assert t.grad is not None
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
m_hat = self.m[i] / (1.0 - self.b1_t)
v_hat = self.v[i] / (1.0 - self.b2_t)
up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()

View File

@@ -866,19 +866,6 @@ class Tensor(SimpleMathTrait):
# ***** toposort and backward pass *****
def gradient(self, *targets:Tensor) -> list[Tensor]:
"""
Compute the gradient of the targets with respect to self.
```python exec="true" source="above" session="tensor" result="python"
x = Tensor.eye(3)
y = Tensor([[2.0,0,-2.0]])
z = y.matmul(x).sum()
dx, dy = z.gradient(x, y)
print(dx.tolist()) # dz/dx
print(dy.tolist()) # dz/dy
```
"""
assert isinstance(self.lazydata, UOp), "multi isn't supported yet"
target_uops: List[UOp] = [x.lazydata for x in targets if isinstance(x.lazydata, UOp)]
return [Tensor(y) for y in gradient(self.lazydata, target_uops)]