diff --git a/test/backend/test_stunning.py b/test/backend/test_stunning.py index 4d9e966a77..28c4499946 100644 --- a/test/backend/test_stunning.py +++ b/test/backend/test_stunning.py @@ -25,7 +25,7 @@ class TestStunning(unittest.TestCase): nv = a[12].cat(a[76]).tolist() vi = Variable('i', 0, a.shape[0]-1) - with self.assertRaisesRegex(AssertionError, "bind mismatch on"): + with self.assertRaisesRegex(RuntimeError, "bind mismatch on"): wv = a[vi.bind(12)].cat(a[vi.bind(76)]).tolist() self.assertListEqual(nv, wv) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index e2ae4dfbe3..cbaa9f0848 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -148,7 +148,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[ExecItem], di nm = b.src[0].expr if nm not in used_vars: continue val = b.src[1].arg - assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}" + if var_vals.get(nm, val) != val: raise RuntimeError(f"bind mismatch on {nm}, {var_vals[nm]} != {val}") var_vals[nm] = val # convert LINEAR to ExecItems