hotfix: RuntimeError for assign

This commit is contained in:
George Hotz
2024-03-27 11:18:48 -07:00
parent 9fb573d73c
commit 60639cccac
2 changed files with 7 additions and 5 deletions

View File

@@ -93,7 +93,8 @@ class TestAssign(unittest.TestCase):
np.testing.assert_allclose(new.numpy(), 4)
def test_assign_diamond(self):
with self.assertRaises(AssertionError):
# NOTE: should *not* raise AssertionError from numpy
with self.assertRaises(RuntimeError):
a = Tensor.ones(4).contiguous().realize()
times_a = a*3
a.assign(Tensor.full((4,), 2.).contiguous())
@@ -137,7 +138,8 @@ class TestAssign(unittest.TestCase):
np.testing.assert_allclose(b.numpy(), 8)
def test_crossunder_assign(self):
with self.assertRaises(AssertionError):
# NOTE: should *not* raise AssertionError from numpy
with self.assertRaises(RuntimeError):
a = Tensor.full((4,), 2).contiguous().realize()
b = Tensor.full((4,), 3).contiguous().realize()
c = a+9

View File

@@ -202,7 +202,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
assert all(degree == 0 for degree in in_degree.values()), "Cycle detected in the graph"
# confirm everything was scheduled
assert len(prescheduled) == len(schedule), f"prescheduled {len(prescheduled)} but only scheduled {len(schedule)}"
# confirm everything was scheduled correctly
if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
return schedule