From 0183a05f0ac2772c07fdf749e531c75f493c74e1 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 18 Mar 2024 08:58:04 -0700
Subject: [PATCH] test assign (#3798)

* Reapply "add failing assign test (#3796)" (#3797)

This reverts commit 1e1beb888cf3c336f384f3ce11e8043cfaea02b6.

* no realized check
---
 test/test_assign.py | 23 +++++++++++++++++++++++
 tinygrad/lazy.py    |  5 +----
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/test/test_assign.py b/test/test_assign.py
index 2319b92b53..1eb7ff42b0 100644
--- a/test/test_assign.py
+++ b/test/test_assign.py
@@ -84,6 +84,29 @@ class TestAssign(unittest.TestCase):
     for _ in range(4): f(y)
     assert y.item() == 4

+  def test_assign_changes(self):
+    a = Tensor.ones(4).contiguous().realize()
+    old_a = a
+    a.assign(Tensor.full((4,), 2.).contiguous())
+    # NOTE: old_a is now 2, and this would match the behavior of pytorch
+    new = a + old_a
+    np.testing.assert_allclose(new.numpy(), 4)
+
+  @unittest.expectedFailure
+  def test_assign_diamond(self):
+    a = Tensor.ones(4).contiguous().realize()
+    times_a = a*3
+    a.assign(Tensor.full((4,), 2.).contiguous())
+    new = a + times_a
+    np.testing.assert_allclose(new.numpy(), 5)
+
+  def test_assign_diamond_alt(self):
+    a = Tensor.ones(4).contiguous().realize()
+    a.assign(Tensor.full((4,), 2.).contiguous())
+    times_a = a*3
+    new = a + times_a
+    np.testing.assert_allclose(new.numpy(), 8)
+
   def test_assign_kv_cache(self):
     bsz, max_context = 2, 8

diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 807199146a..e91332d102 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -59,10 +59,7 @@ class LazyBuffer:
     shape = self.shape if shape is None else shape
     return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=cast_scalar(val, self.dtype)).reshape((1,)*len(shape)).expand(shape)

-  def assign(self, x:LazyBuffer) -> LazyBuffer:
-    if self.base.realized is not None or self is not self.base: new_self = self
-    else: new_self = create_lazybuffer(self.device, self.st, self.dtype, self.op, self.arg, self.srcs, enable_cache=False)
-    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, new_self))
+  def assign(self, x:LazyBuffer) -> LazyBuffer: return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, self))

   def contiguous(self):
     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const(): ret = self.e(LoadOps.CONTIGUOUS)
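
The lazy.py change drops the "no realized check": assign now always builds the LoadOps.ASSIGN loadop with (x, self) as sources, instead of sometimes substituting an uncached copy of self. As a minimal sketch (not part of the patch) of the aliasing semantics the new tests pin down, here is the PyTorch analog the NOTE in test_assign_changes refers to, assuming torch is installed and using copy_ as the stand-in for tinygrad's assign:

    # PyTorch analog of the new tests; copy_ plays the role of Tensor.assign.
    import torch

    # test_assign_changes: an alias captured before the assign sees the new value.
    a = torch.ones(4)
    old_a = a                          # old_a aliases a's storage
    a.copy_(torch.full((4,), 2.))      # in-place write, like assign
    assert (a + old_a).tolist() == [4.0] * 4   # 2 + 2, since old_a is now 2

    # test_assign_diamond: an op computed *before* the assign keeps the old value,
    # because torch evaluates eagerly.
    a = torch.ones(4)
    times_a = a * 3                    # evaluated now: [3, 3, 3, 3]
    a.copy_(torch.full((4,), 2.))
    assert (a + times_a).tolist() == [5.0] * 4  # 2 + 3

Since tinygrad is lazy rather than eager, the diamond case is the one it still gets wrong, hence the expectedFailure on test_assign_diamond; test_assign_diamond_alt, which builds times_a after the assign, already passes.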