mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
test assign (#3798)
* Reapply "add failing assign test (#3796)" (#3797)
This reverts commit 1e1beb888c.
* no realized check
This commit is contained in:
@@ -84,6 +84,29 @@ class TestAssign(unittest.TestCase):
|
||||
for _ in range(4): f(y)
|
||||
assert y.item() == 4
|
||||
|
||||
def test_assign_changes(self):
|
||||
a = Tensor.ones(4).contiguous().realize()
|
||||
old_a = a
|
||||
a.assign(Tensor.full((4,), 2.).contiguous())
|
||||
# NOTE: old_a is now 2, and this would match the behavior of pytorch
|
||||
new = a + old_a
|
||||
np.testing.assert_allclose(new.numpy(), 4)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_assign_diamond(self):
|
||||
a = Tensor.ones(4).contiguous().realize()
|
||||
times_a = a*3
|
||||
a.assign(Tensor.full((4,), 2.).contiguous())
|
||||
new = a + times_a
|
||||
np.testing.assert_allclose(new.numpy(), 5)
|
||||
|
||||
def test_assign_diamond_alt(self):
|
||||
a = Tensor.ones(4).contiguous().realize()
|
||||
a.assign(Tensor.full((4,), 2.).contiguous())
|
||||
times_a = a*3
|
||||
new = a + times_a
|
||||
np.testing.assert_allclose(new.numpy(), 8)
|
||||
|
||||
def test_assign_kv_cache(self):
|
||||
bsz, max_context = 2, 8
|
||||
|
||||
|
||||
Reference in New Issue
Block a user