Dedup params in Optimizer (#1047)

* Dedup params in optimizer

  * Passing the same tensor multiple times in the set of learnable parameters given to an optimizer can cause the model to completely fail to learn, without any error being produced. This change dedups the tensors to avoid the problem.

* Fix types

* Use new variable to satisfy linter

* Use `helpers.dedup` instead of `set()` to dedup params (a sketch of the idea follows this list)

* Add test for duped params in optimizers
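
For context, here is a minimal sketch of what the optimizer-side dedup amounts to; the optimizer diff itself is not shown below, only the test file. The `lr` argument, the default filter-free handling of `params`, and the body of `dedup` are illustrative assumptions, with `dedup` standing in for tinygrad's `helpers.dedup` (an order-preserving dedup):

```python
# A sketch of the idea, not the verbatim change to the optimizer.
def dedup(x): return list(dict.fromkeys(x))  # keeps first occurrence; preserves order, unlike set()

class Optimizer:
  def __init__(self, params, lr=0.001):
    # Opt([w, w]) must behave exactly like Opt([w]); without deduping, the same
    # tensor would receive its update (and any weight decay) twice per step.
    self.params = dedup(list(params))
    self.lr = lr
```

Deduping by identity while keeping the original order is likely why `helpers.dedup` was preferred over `set()`: a set would also drop duplicates, but it does not preserve the order in which parameters were passed.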
Author: Casey Primozic
Date: 2023-06-26 00:49:23 -07:00
Committed by: GitHub
Parent: 5d3310ce56
Commit: 52b7105f87
2 changed files with 23 additions and 3 deletions


@@ -1,4 +1,6 @@
import numpy as np
from tinygrad.helpers import dtypes
from tinygrad.nn import Linear
import torch
import unittest
from tinygrad.tensor import Tensor
@@ -42,7 +44,7 @@ class TestOptim(unittest.TestCase):
  def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)
  def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)
  def _test_adamw(self, steps, opts, atol, rtol): self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol)
  #
  def test_sgd(self): self._test_sgd(1, {'lr': 0.001}, 1e-6, 0)
  def test_sgd_high_lr(self): self._test_sgd(1, {'lr': 10}, 1e-6, 1e-5)
  def test_sgd_wd(self): self._test_sgd(1, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
@@ -74,5 +76,22 @@ class TestOptim(unittest.TestCase):
  def test_multistep_adamw(self): self._test_adamw(10, {'lr': 0.001}, 1e-5, 0)
  def test_multistep_adamw_high_lr(self): self._test_adamw(10, {'lr': 10}, 1e-5, 3e-4)
  def test_duped_weights(self):
    for Opt in [Adam, AdamW, SGD]:
      losses = []
      for i in range(2):
        w = Tensor(x_init.copy())
        opt = Opt([w], lr=0.1) if i == 0 else Opt([w, w], lr=0.1)

        loss = None
        for _ in range(3):
          loss = w.sum()
          opt.zero_grad()
          loss.backward()
          opt.step()

        losses.append(loss.numpy())

      np.testing.assert_allclose(losses[0], losses[1], atol=1e-4, rtol=0)

if __name__ == '__main__':
  unittest.main()