diff --git a/test/test_assign.py b/test/test_assign.py
index f72f0694fe..44a45b383c 100644
--- a/test/test_assign.py
+++ b/test/test_assign.py
@@ -2,7 +2,7 @@
 import unittest
 import numpy as np
 from tinygrad.tensor import Tensor
-from tinygrad import dtypes
+from tinygrad import dtypes, TinyJit
 
 N = 200  # has to be bigger than the cache to fail
 
@@ -20,6 +20,55 @@ class TestAssign(unittest.TestCase):
     assert ba1 == ba2 and ba1 != bb1
     np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))
 
+  def test_assign_add(self):
+    def f(x):
+      x += 1
+      x.realize()
+    x = Tensor([0])
+    f(x)
+    assert x.item() == 1
+
+  def test_assign_add_twice(self):
+    # NOTE: this has two kernels
+    def f(x):
+      x += 1
+      x += 1
+      x.realize()
+    x = Tensor([0])
+    f(x)
+    assert x.item() == 2
+
+  def test_assign_add_double(self):
+    def f(x):
+      x += 1
+      x.realize()
+    x = Tensor([0])
+    f(x)
+    assert (out:=x.item()) == 1, f"expected 1, got {out}"
+    x = Tensor([0])
+    f(x)
+    assert (out:=x.item()) == 1, f"expected 1, got {out}"
+
+  def test_assign_add_jit(self):
+    @TinyJit
+    def f(x):
+      x += 1
+      x.realize()
+    x = Tensor([0])
+    for _ in range(5): f(x)
+    assert x.item() == 5
+
+  def test_assign_add_jit_other(self):
+    @TinyJit
+    def f(x):
+      x += 1
+      x.realize()
+    x = Tensor([0])
+    for _ in range(5): f(x)
+    y = Tensor([0])
+    for _ in range(4): f(y)
+    assert y.item() == 4
+
   def test_permuted_assignment(self):
     a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
     b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
diff --git a/test/testextra/test_lr_scheduler.py b/test/testextra/test_lr_scheduler.py
index 0d2db9319f..24d09989e1 100644
--- a/test/testextra/test_lr_scheduler.py
+++ b/test/testextra/test_lr_scheduler.py
@@ -3,7 +3,8 @@ import torch
 import unittest
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import get_parameters
-from tinygrad.nn.optim import Adam
+from tinygrad.nn.optim import Adam, SGD
+from tinygrad.helpers import DEBUG
 from extra.lr_scheduler import MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR
 from extra.training import train, evaluate
 from extra.datasets import fetch_mnist
@@ -52,11 +53,14 @@ def get_lrs(optim, sched, epochs, steps=1, accs=None):
   return lrs
 
 class TestLrScheduler(unittest.TestCase):
-  def _test_lr_scheduler(self, tinygrad_sched, torch_sched, epochs, opts, atol, rtol):
+  def _test_lr_scheduler(self, tinygrad_sched, torch_sched, epochs, opts, atol, rtol, adam=True):
     accs = opts.pop('accs', None)
     test_tensor = Tensor([0], requires_grad=True)  # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr]
     test_tensor.mean().backward()
-    tinygrad_optim, torch_optim = Adam([test_tensor], lr=0.01), torch.optim.Adam([torch.tensor([0.], requires_grad=True)], lr=0.01)
+    if adam:
+      tinygrad_optim, torch_optim = Adam([test_tensor], lr=0.01), torch.optim.Adam([torch.tensor([0.], requires_grad=True)], lr=0.01)
+    else:
+      tinygrad_optim, torch_optim = SGD([test_tensor], lr=0.01), torch.optim.SGD([torch.tensor([0.], requires_grad=True)], lr=0.01)
     tinygrad_sched, torch_sched = tinygrad_sched(tinygrad_optim, **opts), torch_sched(torch_optim, **opts)
 
     tinygrad_lrs = get_lrs(tinygrad_optim, tinygrad_sched, epochs, accs=accs)
@@ -64,8 +68,8 @@ class TestLrScheduler(unittest.TestCase):
 
     np.testing.assert_allclose(tinygrad_lrs, torch_lrs, atol=atol, rtol=rtol)
 
-  def _test_multisteplr(self, epochs, opts, atol, rtol):
-    self._test_lr_scheduler(MultiStepLR, torch.optim.lr_scheduler.MultiStepLR, epochs, opts, atol, rtol)
+  def _test_multisteplr(self, epochs, opts, atol, rtol, adam=True):
+    self._test_lr_scheduler(MultiStepLR, torch.optim.lr_scheduler.MultiStepLR, epochs, opts, atol, rtol, adam=adam)
   def _test_reducelronplateau(self, epochs, opts, atol, rtol):
     opts['accs'] = np.random.randn(epochs)
     self._test_lr_scheduler(ReduceLROnPlateau, torch.optim.lr_scheduler.ReduceLROnPlateau, epochs, opts, atol, rtol)
@@ -89,6 +93,14 @@ class TestLrScheduler(unittest.TestCase):
   def test_cosineannealinglr(self): self._test_cosineannealinglr(100, {}, 1e-6, 1e-6)
   def test_cosineannealinglr_eta_min(self): self._test_cosineannealinglr(100, {'eta_min': 0.001}, 1e-6, 1e-6)
 
+  def test_multistep_2step(self):
+    # was making this fail with LRU=1, some issue with epoch_counter
+    if DEBUG>=2: print("first")
+    self._test_multisteplr(1, {'milestones': [1]}, 1e-6, 1e-6, adam=False)
+    if DEBUG>=2: print("second")
+    self._test_multisteplr(1, {'milestones': [1], 'gamma': 0.133}, 1e-6, 1e-6, adam=False)
+    if DEBUG>=2: print("third")
+
   def test_onecyclelr(self): self._test_onecyclelr(1000, {'pct_start': 0.3, 'anneal_strategy': 'linear', 'cycle_momentum': False, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'max_lr':1e-5}, 1e-6, 1e-6)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 7fd413475b..e94a898a3e 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -144,11 +144,12 @@ class Tensor:
       self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
       return self
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
+    if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
+    if self.lazydata is x.lazydata: return self  # a self assign is a NOOP
     # NOTE: we allow cross device assign
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
     assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
     assert not x.requires_grad  # self requires_grad is okay?
-    if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
     if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
       if isinstance(self.lazydata, MultiLazyBuffer):
         for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized