DAG cycle asserts (#3955)

* assert cycles

* these are cycle errors

* flip to positive
This commit is contained in:
qazal
2024-03-27 20:09:59 +02:00
committed by GitHub
parent bd3a7d068c
commit 9fb573d73c
2 changed files with 16 additions and 15 deletions

View File

@@ -92,13 +92,13 @@ class TestAssign(unittest.TestCase):
new = a + old_a
np.testing.assert_allclose(new.numpy(), 4)
@unittest.expectedFailure
def test_assign_diamond(self):
a = Tensor.ones(4).contiguous().realize()
times_a = a*3
a.assign(Tensor.full((4,), 2.).contiguous())
new = a + times_a
np.testing.assert_allclose(new.numpy(), 5)
with self.assertRaises(AssertionError):
a = Tensor.ones(4).contiguous().realize()
times_a = a*3
a.assign(Tensor.full((4,), 2.).contiguous())
new = a + times_a
np.testing.assert_allclose(new.numpy(), 5)
def test_assign_diamond_possible(self):
a = Tensor.ones(4).contiguous().realize()
@@ -136,16 +136,16 @@ class TestAssign(unittest.TestCase):
np.testing.assert_allclose(a.numpy(), 5)
np.testing.assert_allclose(b.numpy(), 8)
@unittest.expectedFailure
def test_crossunder_assign(self):
a = Tensor.full((4,), 2).contiguous().realize()
b = Tensor.full((4,), 3).contiguous().realize()
c = a+9
a += b
b += c
Tensor.corealize([a,b])
np.testing.assert_allclose(a.numpy(), 2+3)
np.testing.assert_allclose(b.numpy(), 3+2+9)
with self.assertRaises(AssertionError):
a = Tensor.full((4,), 2).contiguous().realize()
b = Tensor.full((4,), 3).contiguous().realize()
c = a+9
a += b
b += c
Tensor.corealize([a,b])
np.testing.assert_allclose(a.numpy(), 2+3)
np.testing.assert_allclose(b.numpy(), 3+2+9)
def test_assign_kv_cache(self):
bsz, max_context = 2, 8