test assign (#3798)

* Reapply "add failing assign test (#3796)" (#3797)

This reverts commit 1e1beb888c.

* no realized check
This commit is contained in:
George Hotz
2024-03-18 08:58:04 -07:00
committed by GitHub
parent 1e1beb888c
commit 0183a05f0a
2 changed files with 24 additions and 4 deletions

View File

@@ -84,6 +84,29 @@ class TestAssign(unittest.TestCase):
for _ in range(4): f(y)
assert y.item() == 4
def test_assign_changes(self):
a = Tensor.ones(4).contiguous().realize()
old_a = a
a.assign(Tensor.full((4,), 2.).contiguous())
# NOTE: old_a is now 2, and this would match the behavior of pytorch
new = a + old_a
np.testing.assert_allclose(new.numpy(), 4)
@unittest.expectedFailure
def test_assign_diamond(self):
a = Tensor.ones(4).contiguous().realize()
times_a = a*3
a.assign(Tensor.full((4,), 2.).contiguous())
new = a + times_a
np.testing.assert_allclose(new.numpy(), 5)
def test_assign_diamond_alt(self):
a = Tensor.ones(4).contiguous().realize()
a.assign(Tensor.full((4,), 2.).contiguous())
times_a = a*3
new = a + times_a
np.testing.assert_allclose(new.numpy(), 8)
def test_assign_kv_cache(self):
bsz, max_context = 2, 8

View File

@@ -59,10 +59,7 @@ class LazyBuffer:
shape = self.shape if shape is None else shape
return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=cast_scalar(val, self.dtype)).reshape((1,)*len(shape)).expand(shape)
def assign(self, x:LazyBuffer) -> LazyBuffer:
if self.base.realized is not None or self is not self.base: new_self = self
else: new_self = create_lazybuffer(self.device, self.st, self.dtype, self.op, self.arg, self.srcs, enable_cache=False)
return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, new_self))
def assign(self, x:LazyBuffer) -> LazyBuffer: return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, src=(x, self))
def contiguous(self):
if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
ret = self.e(LoadOps.CONTIGUOUS)