From 9fb573d73c8da26047e1626812fa73cc84352505 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 27 Mar 2024 20:09:59 +0200 Subject: [PATCH] DAG cycle asserts (#3955) * assert cycles * these are cycle errors * flip to positive --- test/test_assign.py | 30 +++++++++++++++--------------- tinygrad/engine/schedule.py | 1 + 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/test/test_assign.py b/test/test_assign.py index 6e3ce3692d..88fc8522b3 100644 --- a/test/test_assign.py +++ b/test/test_assign.py @@ -92,13 +92,13 @@ class TestAssign(unittest.TestCase): new = a + old_a np.testing.assert_allclose(new.numpy(), 4) - @unittest.expectedFailure def test_assign_diamond(self): - a = Tensor.ones(4).contiguous().realize() - times_a = a*3 - a.assign(Tensor.full((4,), 2.).contiguous()) - new = a + times_a - np.testing.assert_allclose(new.numpy(), 5) + with self.assertRaises(AssertionError): + a = Tensor.ones(4).contiguous().realize() + times_a = a*3 + a.assign(Tensor.full((4,), 2.).contiguous()) + new = a + times_a + np.testing.assert_allclose(new.numpy(), 5) def test_assign_diamond_possible(self): a = Tensor.ones(4).contiguous().realize() @@ -136,16 +136,16 @@ class TestAssign(unittest.TestCase): np.testing.assert_allclose(a.numpy(), 5) np.testing.assert_allclose(b.numpy(), 8) - @unittest.expectedFailure def test_crossunder_assign(self): - a = Tensor.full((4,), 2).contiguous().realize() - b = Tensor.full((4,), 3).contiguous().realize() - c = a+9 - a += b - b += c - Tensor.corealize([a,b]) - np.testing.assert_allclose(a.numpy(), 2+3) - np.testing.assert_allclose(b.numpy(), 3+2+9) + with self.assertRaises(AssertionError): + a = Tensor.full((4,), 2).contiguous().realize() + b = Tensor.full((4,), 3).contiguous().realize() + c = a+9 + a += b + b += c + Tensor.corealize([a,b]) + np.testing.assert_allclose(a.numpy(), 2+3) + np.testing.assert_allclose(b.numpy(), 3+2+9) def test_assign_kv_cache(self): bsz, max_context = 2, 8 diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 2f988e3341..b01a7ca376 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -202,6 +202,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) in_degree[x] -= 1 if in_degree[x] == 0: queue.append(x) + assert all(degree == 0 for degree in in_degree.values()), "Cycle detected in the graph" # confirm everything was scheduled assert len(prescheduled) == len(schedule), f"prescheduled {len(prescheduled)} but only scheduled {len(schedule)}" return schedule