remove realize from optimizer (#2880)

* remove realize from optimizer

* one still needed

* opt realize

Author: George Hotz
Date:   2023-12-20 16:42:41 -08:00
Committed by: GitHub
Parent: 1765849937
Commit: e1861ab65e
3 changed files with 38 additions and 16 deletions


@@ -10,6 +10,13 @@ x_init = np.random.randn(1,4).astype(np.float32)
 W_init = np.random.randn(4,4).astype(np.float32)
 m_init = np.random.randn(1,4).astype(np.float32)
 
+class TeenyNet:
+  def __init__(self, tensor):
+    self.x = tensor(x_init.copy(), requires_grad=True)
+    self.W = tensor(W_init.copy(), requires_grad=True)
+  def forward(self):
+    return (self.x * self.W).sum()
+
 class TinyNet:
   def __init__(self, tensor):
     self.x = tensor(x_init.copy(), requires_grad=True)
@@ -23,8 +30,8 @@ class TinyNet:
     out = out.mul(self.m).add(self.m).sum()
     return out
 
-def step(tensor, optim, steps=1, kwargs={}):
-  net = TinyNet(tensor)
+def step(tensor, optim, steps=1, teeny=False, **kwargs):
+  net = TeenyNet(tensor) if teeny else TinyNet(tensor)
   optim = optim([net.x, net.W], **kwargs)
   for _ in range(steps):
     out = net.forward()
@@ -37,14 +44,17 @@ def step(tensor, optim, steps=1, kwargs={}):
 
 class TestOptim(unittest.TestCase):
   def _test_optim(self, tinygrad_optim, torch_optim, steps, opts, atol, rtol):
-    for x,y in zip(step(Tensor, tinygrad_optim, steps, kwargs=opts),
-                   step(torch.tensor, torch_optim, steps, kwargs=opts)):
+    for x,y in zip(step(Tensor, tinygrad_optim, steps, **opts),
+                   step(torch.tensor, torch_optim, steps, **opts)):
       np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
 
   def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)
   def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)
   def _test_adamw(self, steps, opts, atol, rtol): self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol)
 
+  def test_multistep_sgd_high_lr_teeny(self): self._test_sgd(2, {'lr': 1.1, 'teeny': True}, 1e-6, 1e-5)
+  def test_multistep_adam_high_lr_teeny(self): self._test_adam(2, {'lr': 1.1, 'teeny': True}, 2e-4, 5e-4)
   def test_sgd(self): self._test_sgd(1, {'lr': 0.001}, 1e-6, 0)
   def test_sgd_high_lr(self): self._test_sgd(1, {'lr': 10}, 1e-6, 1e-5)
   def test_sgd_wd(self): self._test_sgd(1, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
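
The signature change above is the interesting part: the old step() took an explicit kwargs dict, while the new one takes keyword arguments directly, peels off the teeny flag for itself, and forwards everything else (lr, weight_decay, ...) to the optimizer constructor. A minimal standalone sketch of that flow, using a stand-in optimizer since the real one isn't needed to show it (names here are illustrative, not from the diff):

def fake_optim(params, lr=0.001): return ("optim", lr)  # stand-in for SGD/Adam

def step(tensor, optim, steps=1, teeny=False, **kwargs):
  # 'teeny' is consumed here; the remaining kwargs (e.g. lr) reach the optimizer untouched
  net = "TeenyNet" if teeny else "TinyNet"
  return net, optim(None, **kwargs)

opts = {'lr': 1.1, 'teeny': True}
print(step(None, fake_optim, 2, **opts))  # -> ('TeenyNet', ('optim', 1.1))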


@@ -36,6 +36,14 @@ def nm(x):
     node_count += 1
   return x.node_id
 
+buf_count = 0
+def bm(x):
+  global buf_count
+  if not hasattr(x, 'buf_id'):
+    setattr(x, 'buf_id', buf_count)
+    buf_count += 1
+  return x.buf_id
+
 def get_sop(op: List[Op]):
   op = [x for x in op if x not in BufferOps]
   if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
@@ -51,7 +59,7 @@ def realized_lazybuffer(lb, num):
   init_graph()
   G.nodes[nm(lb)]['style'] = '"filled,bold"'
   G.nodes[nm(lb)]['fillcolor'] = G.nodes[nm(lb)]['fillcolor'][:-2]
-  G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num}"'
+  G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num} b:{bm(lb.realized)}"'
 
 def log_lazybuffer(lb, scheduled=False):
   top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0",
@@ -66,8 +74,7 @@ def log_lazybuffer(lb, scheduled=False):
     lb = lb.base
   if lb.realized is None:
     for x in lb.srcs:
-      if nm(x) not in G.nodes:
-        G.add_node(nm(x), label=f'"{str(x.base.realized)[5:-1].replace(" ", chr(10))}"', style='filled', fillcolor="#f0c08080")
+      log_lazybuffer(x)
       G.add_edge(nm(x), nm(lb), color='#a0a0a0')
     label = '"' + \
       (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
@@ -75,7 +82,10 @@ def log_lazybuffer(lb, scheduled=False):
       (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + '"'
     G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
     if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
+  else:
+    if nm(lb) not in G.nodes:
+      # realized but unseen?
+      G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{bm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
 
 def _tree(lazydata, prefix=""):
   if type(lazydata).__name__ == "LazyBuffer":
     return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ")


@@ -1,6 +1,6 @@
 # sorted in order of increasing complexity
 from typing import List
-from tinygrad.helpers import dedup
+from tinygrad.helpers import dedup, getenv
 from tinygrad.tensor import Tensor
 
 class Optimizer:
class Optimizer:
@@ -13,7 +13,7 @@ class Optimizer:
     assert len(self.params) != 0, "optimizer must have at least one param"
     self.device = self.params[0].device
     self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
-    self.lr = Tensor([lr], requires_grad=False, device=self.device).contiguous()
+    self.lr = lr if getenv("CONST_LR") else Tensor([lr], requires_grad=False, device=self.device).contiguous()
 
   def zero_grad(self):
     for param in self.params: param.grad = None
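
self.lr is now either a 1-element Tensor on the optimizer's device (the default, which keeps the learning rate adjustable without rebuilding anything) or, when the CONST_LR environment variable is set, a plain Python number, presumably so it can be folded into the update expression as a constant. A hedged sketch of toggling it (the surrounding setup is illustrative):

import os
os.environ["CONST_LR"] = "1"  # must be set before the optimizer is constructed

from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD

w = Tensor.ones(4, requires_grad=True)
opt = SGD([w], lr=0.1)
print(type(opt.lr))  # a plain float here; a Tensor when CONST_LR is unset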
@@ -32,9 +32,12 @@ class SGD(Optimizer):
   def step(self) -> None:
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      g = t.grad.realize() + self.wd * t.detach()
+      # this is needed since the grads can form a "diamond"
+      # TODO: fix this in lazy.py
+      t.grad.realize()
+      g = t.grad + self.wd * t.detach()
       if self.momentum:
-        self.b[i].assign(self.momentum * self.b[i] + g).realize() # NOTE: self.b[i] is zero on the first run, no if required
+        self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
         g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
       t.assign(t.detach() - g * self.lr)
     self.realize(self.b)
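
The "diamond" the new comment refers to is a node consumed along two paths that merge again downstream (e.g. with nesterov momentum, g reaches the update both directly and through self.b[i]); the single remaining t.grad.realize() pins that shared node down before it is reused. A rough standalone illustration of the shape, not the exact failing case from the commit:

from tinygrad.tensor import Tensor

x = Tensor.randn(4, requires_grad=True)
(x * x).sum().backward()

g = x.grad      # one shared lazy node ...
a = g + 1.0     # ... consumed on one path
b = g * 2.0     # ... and on another
out = a + b     # the two paths merge again: a diamond over g
out.realize()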
@@ -51,12 +54,11 @@ class LAMB(Optimizer):
     self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
 
   def step(self) -> None:
-    self.t.assign(self.t + 1).realize()
+    self.t.assign(self.t + 1)
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      g = t.grad.realize()
-      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize()
-      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize()
+      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
+      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
       m_hat = self.m[i] / (1.0 - self.b1**self.t)
       v_hat = self.v[i] / (1.0 - self.b2**self.t)
       up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
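
For reference, the two assign() calls above are the standard Adam/LAMB moment updates, m = b1*m + (1-b1)*grad and v = b2*v + (1-b2)*grad**2, with the usual bias corrections applied for m_hat and v_hat; the change only drops the intermediate realize() calls and uses t.grad directly instead of a pre-realized g. A quick numeric sanity check of the bias correction in plain Python (illustrative only):

b1, b2, t = 0.9, 0.999, 1
grad = 0.5

m = (1 - b1) * grad        # first step: m and v start at zero
v = (1 - b2) * grad * grad
m_hat = m / (1 - b1 ** t)
v_hat = v / (1 - b2 ** t)
print(round(m_hat, 6), round(v_hat, 6))  # 0.5 0.25 -- step 1 bias correction recovers grad and grad**2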