diff --git a/test/unit/test_assign.py b/test/unit/test_assign.py index 45591d94f0..c83158cf61 100644 --- a/test/unit/test_assign.py +++ b/test/unit/test_assign.py @@ -958,10 +958,27 @@ class TestAfterCachePatterns(unittest.TestCase): a_store = a.uop.store(c.uop) b_store = b.uop.store(c.uop) - a = Tensor(a.uop.after(a_store, b_store)) - a.realize() - np.testing.assert_array_equal(a.numpy(), 1) - np.testing.assert_array_equal(b.numpy(), 1) + with self.assertRaises(AssertionError): + a = Tensor(a.uop.after(a_store, b_store)) + a.realize() + np.testing.assert_array_equal(a.numpy(), 1) + np.testing.assert_array_equal(b.numpy(), 1) + + def test_double_store_after_different_sizes(self): + full = Tensor.zeros(2).contiguous() + head = Tensor.zeros(1).contiguous() + full_src = Tensor([1, 2], dtype=dtypes.float).contiguous() + head_src = Tensor([3], dtype=dtypes.float).contiguous() + Tensor.realize(full, head, full_src, head_src) + + full_store = full.uop.store(full_src.uop) + head_store = head.uop.store(head_src.uop) + + with self.assertRaises(AssertionError): + head = Tensor(head.uop.after(head_store, full_store)) + head.realize() + np.testing.assert_array_equal(head.numpy(), [3]) + np.testing.assert_array_equal(full.numpy(), [1, 2]) if __name__ == "__main__": unittest.main() diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 38e57b0080..c086747e5d 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -391,6 +391,9 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True): if (after:=x.src[0]).op is Ops.AFTER: buf = after.src[0].buf_uop.base if not (stores := [s for s in after.src[1:] if s.op is Ops.STORE and s.src[0].op is Ops.INDEX]): return buf + # the ranges are created on the AFTER, and the stores might be different sizes + # so we block all multi store AFTERs + assert len(stores) <= 1, "rangeify doesn't support multiple stores on one after" # BUFFERIZE(INDEX(...)); store through the underlying global index instead. ended_stores = [] for store in stores: