Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
remove realize from optimizer (#2880)

* remove realize from optimizer
* one still needed
* opt realize
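
In short: the optimizers no longer call .realize() on every intermediate inside their update loops. The updates are left lazy and handled by the batched realize the optimizers already do at the end of step() (e.g. self.realize(self.b) in SGD), with one grad realize kept ("one still needed") because of the diamond issue noted in the diff below. User-facing training code is unchanged; a minimal sketch of a step, assuming the import paths of tinygrad around this commit:

from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD

x = Tensor.randn(1, 4, requires_grad=True)
W = Tensor.randn(4, 4, requires_grad=True)
opt = SGD([x, W], lr=0.001)

loss = x.matmul(W).sum()   # builds the graph lazily
opt.zero_grad()
loss.backward()
opt.step()                 # realizes the updates; no per-op realize inside the loop anymore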

@@ -10,6 +10,13 @@ x_init = np.random.randn(1,4).astype(np.float32)
 W_init = np.random.randn(4,4).astype(np.float32)
 m_init = np.random.randn(1,4).astype(np.float32)
 
+class TeenyNet:
+  def __init__(self, tensor):
+    self.x = tensor(x_init.copy(), requires_grad=True)
+    self.W = tensor(W_init.copy(), requires_grad=True)
+  def forward(self):
+    return (self.x * self.W).sum()
+
 class TinyNet:
   def __init__(self, tensor):
     self.x = tensor(x_init.copy(), requires_grad=True)

@@ -23,8 +30,8 @@ class TinyNet:
     out = out.mul(self.m).add(self.m).sum()
     return out
 
-def step(tensor, optim, steps=1, kwargs={}):
-  net = TinyNet(tensor)
+def step(tensor, optim, steps=1, teeny=False, **kwargs):
+  net = TeenyNet(tensor) if teeny else TinyNet(tensor)
   optim = optim([net.x, net.W], **kwargs)
   for _ in range(steps):
     out = net.forward()

@@ -37,14 +44,17 @@ def step(tensor, optim, steps=1, kwargs={}):
 class TestOptim(unittest.TestCase):
 
   def _test_optim(self, tinygrad_optim, torch_optim, steps, opts, atol, rtol):
-    for x,y in zip(step(Tensor, tinygrad_optim, steps, kwargs=opts),
-                   step(torch.tensor, torch_optim, steps, kwargs=opts)):
+    for x,y in zip(step(Tensor, tinygrad_optim, steps, **opts),
+                   step(torch.tensor, torch_optim, steps, **opts)):
       np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
 
   def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)
   def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)
   def _test_adamw(self, steps, opts, atol, rtol): self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol)
 
+  def test_multistep_sgd_high_lr_teeny(self): self._test_sgd(2, {'lr': 1.1, 'teeny': True}, 1e-6, 1e-5)
+  def test_multistep_adam_high_lr_teeny(self): self._test_adam(2, {'lr': 1.1, 'teeny': True}, 2e-4, 5e-4)
+
   def test_sgd(self): self._test_sgd(1, {'lr': 0.001}, 1e-6, 0)
   def test_sgd_high_lr(self): self._test_sgd(1, {'lr': 10}, 1e-6, 1e-5)
   def test_sgd_wd(self): self._test_sgd(1, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
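
A note on the harness change: replacing kwargs={} with **kwargs (and dropping the mutable default dict) lets one options dict carry both the new teeny flag, which step consumes itself, and the optimizer hyperparameters, which are forwarded on. A standalone illustration of the routing, independent of tinygrad (demo_step is a made-up name):

def demo_step(tensor, optim, steps=1, teeny=False, **kwargs):
  # 'teeny' is bound to the named parameter; everything else lands in kwargs
  return teeny, kwargs

assert demo_step(None, None, 2, **{'lr': 1.1, 'teeny': True}) == (True, {'lr': 1.1})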

@@ -36,6 +36,14 @@ def nm(x):
     node_count += 1
   return x.node_id
 
+buf_count = 0
+def bm(x):
+  global buf_count
+  if not hasattr(x, 'buf_id'):
+    setattr(x, 'buf_id', buf_count)
+    buf_count += 1
+  return x.buf_id
+
 def get_sop(op: List[Op]):
   op = [x for x in op if x not in BufferOps]
   if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
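
bm mirrors the existing nm node tagger: the first time a buffer is seen it is stamped with a monotonically increasing buf_id, so the same realized buffer always gets the same label in the graph. A self-contained copy with a tiny check (FakeBuf is just a stand-in object):

buf_count = 0
def bm(x):
  global buf_count
  if not hasattr(x, 'buf_id'):
    setattr(x, 'buf_id', buf_count)   # assign an id only on first sight
    buf_count += 1
  return x.buf_id

class FakeBuf: pass
a, b = FakeBuf(), FakeBuf()
assert (bm(a), bm(b), bm(a)) == (0, 1, 0)   # ids are stable across repeated calls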

@@ -51,7 +59,7 @@ def realized_lazybuffer(lb, num):
   init_graph()
   G.nodes[nm(lb)]['style'] = '"filled,bold"'
   G.nodes[nm(lb)]['fillcolor'] = G.nodes[nm(lb)]['fillcolor'][:-2]
-  G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num}"'
+  G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num} b:{bm(lb.realized)}"'
 
 def log_lazybuffer(lb, scheduled=False):
   top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0",

@@ -66,8 +74,7 @@ def log_lazybuffer(lb, scheduled=False):
     lb = lb.base
   if lb.realized is None:
     for x in lb.srcs:
-      if nm(x) not in G.nodes:
-        G.add_node(nm(x), label=f'"{str(x.base.realized)[5:-1].replace(" ", chr(10))}"', style='filled', fillcolor="#f0c08080")
+      log_lazybuffer(x)
       G.add_edge(nm(x), nm(lb), color='#a0a0a0')
     label = '"' + \
       (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \

@@ -75,7 +82,10 @@ def log_lazybuffer(lb, scheduled=False):
       (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + '"'
     G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
     if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
-
+  else:
+    if nm(lb) not in G.nodes:
+      # realized but unseen?
+      G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{bm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
 def _tree(lazydata, prefix=""):
   if type(lazydata).__name__ == "LazyBuffer":
     return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ")

@@ -1,6 +1,6 @@
 # sorted in order of increasing complexity
 from typing import List
-from tinygrad.helpers import dedup
+from tinygrad.helpers import dedup, getenv
 from tinygrad.tensor import Tensor
 
 class Optimizer:

@@ -13,7 +13,7 @@ class Optimizer:
     assert len(self.params) != 0, "optimizer must have at least one param"
     self.device = self.params[0].device
     self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad])  # buffers are still realized
-    self.lr = Tensor([lr], requires_grad=False, device=self.device).contiguous()
+    self.lr = lr if getenv("CONST_LR") else Tensor([lr], requires_grad=False, device=self.device).contiguous()
 
   def zero_grad(self):
     for param in self.params: param.grad = None
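
The new CONST_LR switch keeps the learning rate as a plain Python float instead of a one-element device Tensor (getenv, from tinygrad.helpers, reads the variable as an int and treats unset as 0). A standalone sketch of the toggle, with a list standing in for the Tensor:

import os

def make_lr(lr: float):
  # mirrors: lr if getenv("CONST_LR") else Tensor([lr], requires_grad=False, device=...).contiguous()
  return lr if int(os.getenv("CONST_LR", "0")) else [lr]

print(make_lr(0.001))   # [0.001] by default; 0.001 when run with CONST_LR=1 set in the environment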

@@ -32,9 +32,12 @@ class SGD(Optimizer):
   def step(self) -> None:
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      g = t.grad.realize() + self.wd * t.detach()
+      # this is needed since the grads can form a "diamond"
+      # TODO: fix this in lazy.py
+      t.grad.realize()
+      g = t.grad + self.wd * t.detach()
       if self.momentum:
-        self.b[i].assign(self.momentum * self.b[i] + g).realize()  # NOTE: self.b[i] is zero on the first run, no if required
+        self.b[i].assign(self.momentum * self.b[i] + g)  # NOTE: self.b[i] is zero on the first run, no if required
         g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
       t.assign(t.detach() - g * self.lr)
     self.realize(self.b)
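
The math is untouched by the change above: per parameter, the loop still computes plain SGD with weight decay and optional (Nesterov) momentum; only where realization happens moved. A reference version that works on floats or NumPy arrays, for comparison:

def sgd_update(theta, grad, b, lr, wd=0.0, momentum=0.0, nesterov=False):
  g = grad + wd * theta
  if momentum:
    b = momentum * b + g
    g = (g + momentum * b) if nesterov else b
  return theta - lr * g, b   # new parameter value, new momentum buffer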

@@ -51,12 +54,11 @@ class LAMB(Optimizer):
     self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
 
   def step(self) -> None:
-    self.t.assign(self.t + 1).realize()
+    self.t.assign(self.t + 1)
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      g = t.grad.realize()
-      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize()
-      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize()
+      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
+      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
       m_hat = self.m[i] / (1.0 - self.b1**self.t)
       v_hat = self.v[i] / (1.0 - self.b2**self.t)
       up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
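
Likewise for LAMB: only the explicit realizes are gone; the moment updates remain the standard bias-corrected Adam-style ones. A reference version, omitting the self.wd * t.detach() term shown above and the trust-ratio scaling that follows in the full LAMB implementation:

def lamb_moments(m, v, grad, t, b1, b2, eps):
  m = b1 * m + (1.0 - b1) * grad
  v = b2 * v + (1.0 - b2) * (grad * grad)
  m_hat = m / (1.0 - b1**t)   # bias-corrected first moment
  v_hat = v / (1.0 - b2**t)   # bias-corrected second moment
  return m, v, m_hat / (v_hat**0.5 + eps)   # new moments and the unscaled update direction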