From a53187eef7e33ddfc81bac3035343a8a13e70c73 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Mon, 9 Mar 2026 23:14:23 -0400
Subject: [PATCH] fix TestPartialAssignToSharedBuffer (#15202)

bufferize_to_store issue with assign
---

Reviewer notes (this section sits after the three-dash line, so it is
commentary only and is not part of the commit message):

* test/unit/test_assign.py: adds TestPartialAssignToSharedBuffer. Each test
  slices one realized, zero-filled 1-D tensor into disjoint views, reshapes
  them to 2-D, assigns `v + 1` to every view, realizes all views with a
  single Tensor.realize(...) call, and checks that every view reads back as
  all ones. The three tests vary the number of slices (5, 10) and use mixed
  per-view shapes.

* tinygrad/schedule/rangeify.py (ASSIGN path of bufferize_to_store): when
  assign.arg is truthy and the assign target indexes a BUFFERIZE whose own
  source is an INDEX, the store is now issued through that inner INDEX
  (store_target) instead of the outer one. The store is ended over the
  deduplicated union of store_target.ranges and rngs, sorted by `.arg`, and
  the returned value is rooted at store_target.buf_uop.base.

* NOTE(review): the "skip self-assign" check used to compare assign_src with
  assign_target and now compares it with store_target. Please confirm the
  same-device-copy self-assign case is still skipped when store_target has
  been redirected to the inner INDEX.

* NOTE(review): the `key=lambda x: x.arg` in the new end_rngs line shadows
  the enclosing function's `x` parameter inside the lambda only; harmless,
  but easy to misread.

* NOTE(review): line breaks and indentation in this copy were reconstructed
  from a whitespace-collapsed version of the patch (2-space Python
  indentation assumed, blank lines placed to satisfy the hunk counts).
  Verify context-line whitespace against the target tree before applying,
  or apply with whitespace-tolerant options.

 test/unit/test_assign.py      | 31 +++++++++++++++++++++++++++++++
 tinygrad/schedule/rangeify.py | 13 +++++++++----
 2 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/test/unit/test_assign.py b/test/unit/test_assign.py
index 32bdd52030..3e1228c6d2 100644
--- a/test/unit/test_assign.py
+++ b/test/unit/test_assign.py
@@ -911,5 +911,36 @@ class TestAssignToUnrealizedView(unittest.TestCase):
     # TODO: broken now, silently dropped
     self.assertEqual(c.tolist(), [[5,5],[5,5]])
 
+class TestPartialAssignToSharedBuffer(unittest.TestCase):
+  def test_five_slices(self):
+    big = Tensor.zeros(50).contiguous().realize()
+    views = [big[i*10:(i+1)*10].reshape(2, 5) for i in range(5)]
+    for v in views: v.assign(v + 1)
+    Tensor.realize(*views)
+    for v in views:
+      np.testing.assert_allclose(v.numpy(), np.ones((2, 5)))
+
+  def test_many_slices(self):
+    n_params = 10
+    big = Tensor.zeros(n_params * 12).contiguous().realize()
+    grads = [big[i*12:(i+1)*12].reshape(3, 4) for i in range(n_params)]
+    for g in grads: g.assign(g + 1)
+    Tensor.realize(*grads)
+    for g in grads:
+      np.testing.assert_allclose(g.numpy(), np.ones((3, 4)))
+
+  def test_mixed_shapes(self):
+    big = Tensor.zeros(100).contiguous().realize()
+    shapes = [(3, 4), (4, 6), (6, 4), (2, 5), (4, 3)]
+    pos, views = 0, []
+    for s in shapes:
+      n = s[0] * s[1]
+      views.append(big[pos:pos+n].reshape(*s))
+      pos += n
+    for v in views: v.assign(v + 1)
+    Tensor.realize(*views)
+    for v, s in zip(views, shapes):
+      np.testing.assert_allclose(v.numpy(), np.ones(s))
+
 if __name__ == "__main__":
   unittest.main()
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 9a4dbd39ba..9be14b4573 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -359,10 +359,15 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
     assign_target, assign_src = assign.src[0], assign.src[1]
     assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
     while assign_src.op is Ops.NOOP: assign_src = assign_src.src[0]
-    # skip self-assign from same-device copy, otherwise create the store
-    # in assign, this is the buffer size, not the bufferize size
-    if assign_src is assign_target: ret = assign_target.src[0]
-    else: ret = assign_target.src[0].after(assign_target.replace(dtype=sdtype).store(assign_src).end(*rngs))
+
+    store_target = assign_target
+    if assign.arg and assign_target.src[0].op is Ops.BUFFERIZE and assign_target.src[0].src[0].op is Ops.INDEX:
+      # BUFFERIZE(INDEX(...)); store through the underlying global index instead.
+      store_target = assign_target.src[0].src[0]
+
+    end_rngs = sorted(dedup(tuple(store_target.ranges) + tuple(rngs)), key=lambda x: x.arg)
+    ret = store_target.buf_uop.base
+    if assign_src is not store_target: ret = ret.after(store_target.replace(dtype=sdtype).store(assign_src).end(*end_rngs))
     for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg)
     return ret
 