fix TestPartialAssignToSharedBuffer (#15202)

Fixes an issue in bufferize_to_store when handling assign.
This commit is contained in:
chenyu
2026-03-09 23:14:23 -04:00
committed by GitHub
parent 525a178966
commit a53187eef7
2 changed files with 40 additions and 4 deletions

View File

@@ -911,5 +911,36 @@ class TestAssignToUnrealizedView(unittest.TestCase):
# TODO: broken now, silently dropped
self.assertEqual(c.tolist(), [[5,5],[5,5]])
class TestPartialAssignToSharedBuffer(unittest.TestCase):
  """Assign through several disjoint, reshaped slices of one realized buffer.

  Every test builds a zero-filled 1-D tensor, carves non-overlapping slices out
  of it, reshapes each slice to 2-D, assigns ``slice + 1`` back into the slice,
  realizes all slices in a single ``Tensor.realize`` call, and then checks that
  each slice reads back as all ones.
  """

  def test_five_slices(self):
    # five consecutive 10-element windows of a 50-element buffer, each seen as (2, 5)
    base = Tensor.zeros(50).contiguous().realize()
    chunks = []
    for k in range(5):
      start = k * 10
      chunks.append(base[start:start + 10].reshape(2, 5))
    for chunk in chunks:
      chunk.assign(chunk + 1)
    Tensor.realize(*chunks)
    for chunk in chunks:
      np.testing.assert_allclose(chunk.numpy(), np.ones((2, 5)))

  def test_many_slices(self):
    # ten consecutive 12-element windows, each seen as (3, 4)
    num_slices, slice_len = 10, 12
    total = num_slices * slice_len
    storage = Tensor.zeros(total).contiguous().realize()
    windows = [storage[off:off + slice_len].reshape(3, 4) for off in range(0, total, slice_len)]
    for window in windows:
      window.assign(window + 1)
    Tensor.realize(*windows)
    for window in windows:
      np.testing.assert_allclose(window.numpy(), np.ones((3, 4)))

  def test_mixed_shapes(self):
    # back-to-back windows of differing sizes; together they use 82 of the 100 elements
    dims = [(3, 4), (4, 6), (6, 4), (2, 5), (4, 3)]
    backing = Tensor.zeros(100).contiguous().realize()
    segments, cursor = [], 0
    for rows, cols in dims:
      stop = cursor + rows * cols
      segments.append(backing[cursor:stop].reshape(rows, cols))
      cursor = stop
    for segment in segments:
      segment.assign(segment + 1)
    Tensor.realize(*segments)
    for segment, dim in zip(segments, dims):
      np.testing.assert_allclose(segment.numpy(), np.ones(dim))
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()

View File

@@ -359,10 +359,15 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
assign_target, assign_src = assign.src[0], assign.src[1]
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
while assign_src.op is Ops.NOOP: assign_src = assign_src.src[0]
# skip self-assign from same-device copy, otherwise create the store
# in assign, this is the buffer size, not the bufferize size
if assign_src is assign_target: ret = assign_target.src[0]
else: ret = assign_target.src[0].after(assign_target.replace(dtype=sdtype).store(assign_src).end(*rngs))
store_target = assign_target
if assign.arg and assign_target.src[0].op is Ops.BUFFERIZE and assign_target.src[0].src[0].op is Ops.INDEX:
# BUFFERIZE(INDEX(...)); store through the underlying global index instead.
store_target = assign_target.src[0].src[0]
end_rngs = sorted(dedup(tuple(store_target.ranges) + tuple(rngs)), key=lambda x: x.arg)
ret = store_target.buf_uop.base
if assign_src is not store_target: ret = ret.after(store_target.replace(dtype=sdtype).store(assign_src).end(*end_rngs))
for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg)
return ret