block multistore, it's not supported (#15708)

This commit is contained in:
George Hotz
2026-04-13 20:57:59 +08:00
committed by GitHub
parent 84d64b5835
commit 7610bdc59e
2 changed files with 24 additions and 4 deletions

View File

@@ -958,10 +958,27 @@ class TestAfterCachePatterns(unittest.TestCase):
a_store = a.uop.store(c.uop)
b_store = b.uop.store(c.uop)
a = Tensor(a.uop.after(a_store, b_store))
a.realize()
np.testing.assert_array_equal(a.numpy(), 1)
np.testing.assert_array_equal(b.numpy(), 1)
with self.assertRaises(AssertionError):
a = Tensor(a.uop.after(a_store, b_store))
a.realize()
np.testing.assert_array_equal(a.numpy(), 1)
np.testing.assert_array_equal(b.numpy(), 1)
def test_double_store_after_different_sizes(self):
full = Tensor.zeros(2).contiguous()
head = Tensor.zeros(1).contiguous()
full_src = Tensor([1, 2], dtype=dtypes.float).contiguous()
head_src = Tensor([3], dtype=dtypes.float).contiguous()
Tensor.realize(full, head, full_src, head_src)
full_store = full.uop.store(full_src.uop)
head_store = head.uop.store(head_src.uop)
with self.assertRaises(AssertionError):
head = Tensor(head.uop.after(head_store, full_store))
head.realize()
np.testing.assert_array_equal(head.numpy(), [3])
np.testing.assert_array_equal(full.numpy(), [1, 2])
if __name__ == "__main__":
unittest.main()

View File

@@ -391,6 +391,9 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
if (after:=x.src[0]).op is Ops.AFTER:
buf = after.src[0].buf_uop.base
if not (stores := [s for s in after.src[1:] if s.op is Ops.STORE and s.src[0].op is Ops.INDEX]): return buf
# the ranges are created on the AFTER, and the stores might be different sizes
# so we block all multi store AFTERs
assert len(stores) <= 1, "rangeify doesn't support multiple stores on one after"
# BUFFERIZE(INDEX(...)); store through the underlying global index instead.
ended_stores = []
for store in stores: