mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix TestPartialAssignToSharedBuffer (#15202): resolves a bufferize_to_store issue with assign
This commit is contained in:
@@ -911,5 +911,36 @@ class TestAssignToUnrealizedView(unittest.TestCase):
|
||||
# TODO: broken now, silently dropped
|
||||
self.assertEqual(c.tolist(), [[5,5],[5,5]])
|
||||
|
||||
class TestPartialAssignToSharedBuffer(unittest.TestCase):
  """Assigning in place to several disjoint, reshaped views of one backing buffer.

  Each test slices a single realized buffer into non-overlapping windows,
  does v.assign(v + 1) on every view, realizes them all together, and then
  checks every view reads back as all-ones.
  """

  def test_five_slices(self):
    # five disjoint 10-element windows of a 50-element buffer, each seen as (2, 5)
    backing = Tensor.zeros(50).contiguous().realize()
    slices = []
    for idx in range(5):
      slices.append(backing[idx*10:(idx+1)*10].reshape(2, 5))
    for sl in slices:
      sl.assign(sl + 1)
    Tensor.realize(*slices)
    expected = np.ones((2, 5))
    for sl in slices:
      np.testing.assert_allclose(sl.numpy(), expected)

  def test_many_slices(self):
    # ten disjoint 12-element windows, each seen as (3, 4)
    count = 10
    backing = Tensor.zeros(count * 12).contiguous().realize()
    grads = [backing[k*12:(k+1)*12].reshape(3, 4) for k in range(count)]
    for grad in grads:
      grad.assign(grad + 1)
    Tensor.realize(*grads)
    expected = np.ones((3, 4))
    for grad in grads:
      np.testing.assert_allclose(grad.numpy(), expected)

  def test_mixed_shapes(self):
    # windows of differing sizes/shapes packed back-to-back into one buffer
    backing = Tensor.zeros(100).contiguous().realize()
    shapes = [(3, 4), (4, 6), (6, 4), (2, 5), (4, 3)]
    views, offset = [], 0
    for rows, cols in shapes:
      numel = rows * cols
      views.append(backing[offset:offset + numel].reshape(rows, cols))
      offset += numel
    for view in views:
      view.assign(view + 1)
    Tensor.realize(*views)
    for view, shape in zip(views, shapes):
      np.testing.assert_allclose(view.numpy(), np.ones(shape))
|
||||
|
||||
# Allow running this test file directly with `python <file>`.
if __name__ == "__main__":
  unittest.main()
|
||||
|
||||
@@ -359,10 +359,15 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
|
||||
assign_target, assign_src = assign.src[0], assign.src[1]
|
||||
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
|
||||
while assign_src.op is Ops.NOOP: assign_src = assign_src.src[0]
|
||||
# skip self-assign from same-device copy, otherwise create the store
|
||||
# in assign, this is the buffer size, not the bufferize size
|
||||
if assign_src is assign_target: ret = assign_target.src[0]
|
||||
else: ret = assign_target.src[0].after(assign_target.replace(dtype=sdtype).store(assign_src).end(*rngs))
|
||||
|
||||
store_target = assign_target
|
||||
if assign.arg and assign_target.src[0].op is Ops.BUFFERIZE and assign_target.src[0].src[0].op is Ops.INDEX:
|
||||
# BUFFERIZE(INDEX(...)); store through the underlying global index instead.
|
||||
store_target = assign_target.src[0].src[0]
|
||||
|
||||
end_rngs = sorted(dedup(tuple(store_target.ranges) + tuple(rngs)), key=lambda x: x.arg)
|
||||
ret = store_target.buf_uop.base
|
||||
if assign_src is not store_target: ret = ret.after(store_target.replace(dtype=sdtype).store(assign_src).end(*end_rngs))
|
||||
for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg)
|
||||
return ret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user