From 1e1beb888cf3c336f384f3ce11e8043cfaea02b6 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 18 Mar 2024 08:55:36 -0700
Subject: [PATCH] Revert "add failing assign test (#3796)" (#3797)

This reverts commit 2dea12832c1a2d5e94b9fc72e50170da9fdc9ff1.
---
 test/test_assign.py | 23 -----------------------
 tinygrad/lazy.py    |  5 ++++-
 tinygrad/realize.py |  1 -
 3 files changed, 4 insertions(+), 25 deletions(-)

diff --git a/test/test_assign.py b/test/test_assign.py
index 1eb7ff42b0..2319b92b53 100644
--- a/test/test_assign.py
+++ b/test/test_assign.py
@@ -84,29 +84,6 @@ class TestAssign(unittest.TestCase):
     for _ in range(4): f(y)
     assert y.item() == 4
 
-  def test_assign_changes(self):
-    a = Tensor.ones(4).contiguous().realize()
-    old_a = a
-    a.assign(Tensor.full((4,), 2.).contiguous())
-    # NOTE: old_a is now 2, and this would match the behavior of pytorch
-    new = a + old_a
-    np.testing.assert_allclose(new.numpy(), 4)
-
-  @unittest.expectedFailure
-  def test_assign_diamond(self):
-    a = Tensor.ones(4).contiguous().realize()
-    times_a = a*3
-    a.assign(Tensor.full((4,), 2.).contiguous())
-    new = a + times_a
-    np.testing.assert_allclose(new.numpy(), 5)
-
-  def test_assign_diamond_alt(self):
-    a = Tensor.ones(4).contiguous().realize()
-    a.assign(Tensor.full((4,), 2.).contiguous())
-    times_a = a*3
-    new = a + times_a
-    np.testing.assert_allclose(new.numpy(), 8)
-
   def test_assign_kv_cache(self):
     bsz, max_context = 2, 8
 
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index e91332d102..807199146a 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -59,7 +59,10 @@ class LazyBuffer:
     shape = self.shape if shape is None else shape
     return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=cast_scalar(val, self.dtype)).reshape((1,)*len(shape)).expand(shape)
 
-  def assign(self, x:LazyBuffer) -> LazyBuffer: return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, self))
+  def assign(self, x:LazyBuffer) -> LazyBuffer:
+    if self.base.realized is not None or self is not self.base: new_self = self
+    else: new_self = create_lazybuffer(self.device, self.st, self.dtype, self.op, self.arg, self.srcs, enable_cache=False)
+    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, new_self))
   def contiguous(self):
     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
       ret = self.e(LoadOps.CONTIGUOUS)
diff --git a/tinygrad/realize.py b/tinygrad/realize.py
index 9656d38fbe..303acb3f81 100644
--- a/tinygrad/realize.py
+++ b/tinygrad/realize.py
@@ -109,7 +109,6 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Var
   # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
   if buf.op in {LoadOps.CONTIGUOUS, LoadOps.ASSIGN}:
     assert first
-    assert buf.op is not LoadOps.ASSIGN or buf.srcs[1].base.realized is not None, "assign must be already realized to schedule"
    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1].base if buf.op is LoadOps.ASSIGN else None)
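
For context, the removed test_assign_diamond captured the hazard this revert is about: a lazy value derived from a tensor before an assign shares that tensor's underlying buffer, so it can observe the post-assign contents rather than a snapshot. Below is a minimal standalone sketch of that diamond, reconstructed from the removed test; it assumes the tinygrad Tensor API as of this commit.

    import numpy as np
    from tinygrad.tensor import Tensor

    a = Tensor.ones(4).contiguous().realize()
    times_a = a * 3                               # lazy; still reads a's buffer
    a.assign(Tensor.full((4,), 2.).contiguous())  # in-place write to that buffer
    new = a + times_a
    # Snapshot (PyTorch-like) semantics would give 2 + 1*3 == 5; if times_a is
    # evaluated after the in-place write, the result is 2 + 2*3 == 8. The removed
    # test asserted 5 and was marked @unittest.expectedFailure.
    print(new.numpy())

The restored assign in tinygrad/lazy.py handles a related aliasing case: when self is an unrealized base, it is rewrapped via create_lazybuffer(..., enable_cache=False), so the ASSIGN node targets a distinct uncached buffer instead of one shared with earlier references.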