mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
DAG cycle asserts (#3955)
* assert cycles * these are cycle errors * flip to positive
This commit is contained in:
@@ -92,13 +92,13 @@ class TestAssign(unittest.TestCase):
|
|||||||
new = a + old_a
|
new = a + old_a
|
||||||
np.testing.assert_allclose(new.numpy(), 4)
|
np.testing.assert_allclose(new.numpy(), 4)
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
def test_assign_diamond(self):
|
def test_assign_diamond(self):
|
||||||
a = Tensor.ones(4).contiguous().realize()
|
with self.assertRaises(AssertionError):
|
||||||
times_a = a*3
|
a = Tensor.ones(4).contiguous().realize()
|
||||||
a.assign(Tensor.full((4,), 2.).contiguous())
|
times_a = a*3
|
||||||
new = a + times_a
|
a.assign(Tensor.full((4,), 2.).contiguous())
|
||||||
np.testing.assert_allclose(new.numpy(), 5)
|
new = a + times_a
|
||||||
|
np.testing.assert_allclose(new.numpy(), 5)
|
||||||
|
|
||||||
def test_assign_diamond_possible(self):
|
def test_assign_diamond_possible(self):
|
||||||
a = Tensor.ones(4).contiguous().realize()
|
a = Tensor.ones(4).contiguous().realize()
|
||||||
@@ -136,16 +136,16 @@ class TestAssign(unittest.TestCase):
|
|||||||
np.testing.assert_allclose(a.numpy(), 5)
|
np.testing.assert_allclose(a.numpy(), 5)
|
||||||
np.testing.assert_allclose(b.numpy(), 8)
|
np.testing.assert_allclose(b.numpy(), 8)
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
def test_crossunder_assign(self):
|
def test_crossunder_assign(self):
|
||||||
a = Tensor.full((4,), 2).contiguous().realize()
|
with self.assertRaises(AssertionError):
|
||||||
b = Tensor.full((4,), 3).contiguous().realize()
|
a = Tensor.full((4,), 2).contiguous().realize()
|
||||||
c = a+9
|
b = Tensor.full((4,), 3).contiguous().realize()
|
||||||
a += b
|
c = a+9
|
||||||
b += c
|
a += b
|
||||||
Tensor.corealize([a,b])
|
b += c
|
||||||
np.testing.assert_allclose(a.numpy(), 2+3)
|
Tensor.corealize([a,b])
|
||||||
np.testing.assert_allclose(b.numpy(), 3+2+9)
|
np.testing.assert_allclose(a.numpy(), 2+3)
|
||||||
|
np.testing.assert_allclose(b.numpy(), 3+2+9)
|
||||||
|
|
||||||
def test_assign_kv_cache(self):
|
def test_assign_kv_cache(self):
|
||||||
bsz, max_context = 2, 8
|
bsz, max_context = 2, 8
|
||||||
|
|||||||
@@ -202,6 +202,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
|
|||||||
in_degree[x] -= 1
|
in_degree[x] -= 1
|
||||||
if in_degree[x] == 0: queue.append(x)
|
if in_degree[x] == 0: queue.append(x)
|
||||||
|
|
||||||
|
assert all(degree == 0 for degree in in_degree.values()), "Cycle detected in the graph"
|
||||||
# confirm everything was scheduled
|
# confirm everything was scheduled
|
||||||
assert len(prescheduled) == len(schedule), f"prescheduled {len(prescheduled)} but only scheduled {len(schedule)}"
|
assert len(prescheduled) == len(schedule), f"prescheduled {len(prescheduled)} but only scheduled {len(schedule)}"
|
||||||
return schedule
|
return schedule
|
||||||
|
|||||||
Reference in New Issue
Block a user