simple assign tests (#3807)

This commit is contained in:
George Hotz
2024-03-18 13:57:01 -07:00
committed by GitHub
parent a0ab755317
commit d8296d4a3f

View File

@@ -107,6 +107,32 @@ class TestAssign(unittest.TestCase):
new = a + times_a
np.testing.assert_allclose(new.numpy(), 8)
def test_double_assign(self):
a = Tensor.ones(4).contiguous().realize()
a += 1
a += 1
np.testing.assert_allclose(a.numpy(), 3)
def test_crossover_assign(self):
a = Tensor.full((4,), 2).contiguous().realize()
b = Tensor.full((4,), 3).contiguous().realize()
a += b
b += a
Tensor.corealize([a,b])
np.testing.assert_allclose(a.numpy(), 5)
np.testing.assert_allclose(b.numpy(), 8)
@unittest.expectedFailure
def test_crossunder_assign(self):
a = Tensor.full((4,), 2).contiguous().realize()
b = Tensor.full((4,), 3).contiguous().realize()
c = a+9
a += b
b += c
Tensor.corealize([a,b])
np.testing.assert_allclose(a.numpy(), 2+3)
np.testing.assert_allclose(b.numpy(), 3+2+9)
def test_assign_kv_cache(self):
bsz, max_context = 2, 8