Fix clone for multi-device tensors (#14919)

Also update empty_like so that the tensor it returns for a sharded input is backed by buffers.
This commit is contained in:
chenyu
2026-02-20 17:21:09 -05:00
committed by GitHub
parent 1fc1508f67
commit 24286c5593
3 changed files with 11 additions and 8 deletions

View File

@@ -840,13 +840,15 @@ class TestMultiTensor(unittest.TestCase):
t.shard_(devices, axis=0).realize()
assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.uop.src])
@unittest.skip("this is unreliable on OSX")
def test_clone(self):
t = Tensor.rand(16, 16).shard(devices_2, axis=None)
np.testing.assert_allclose(t.numpy(), t.clone().numpy())
t = Tensor.rand(16, 16).shard(devices_2, axis=0)
np.testing.assert_allclose(t.numpy(), t.clone().numpy())
for axis in (None, 0):
t = Tensor.arange(16).reshape(4, 4).shard(devices_2, axis=axis).contiguous().realize()
t_clone = t.clone().realize()
self.assertEqual(t_clone.device, t.device)
self.assertEqual(t_clone.uop.axis, axis)
self.assertEqual(t_clone.tolist(), t.tolist())
t_clone += 1
self.assertNotEqual(t_clone.tolist(), t.tolist())
@unittest.skip("RANGEIFY doesn't support multi const folding")
def test_multi_const_folding(self):

View File

@@ -72,6 +72,7 @@ class TestMultiAxis(unittest.TestCase):
self.assertEqual(e.shape, t.shape)
self.assertEqual(e.device, t.device)
self.assertEqual(e.uop.axis, 0)
self.assertTrue(e.uop.has_buffer_identity())
if __name__ == '__main__':
unittest.main()

View File

@@ -404,7 +404,7 @@ class Tensor(OpMixin):
"""
Creates a clone of this tensor allocating a separate buffer for the data.
"""
ret = Tensor.empty(self.shape, device=self.device, dtype=self.dtype)
ret = self.empty_like()
if self.grad is not None: ret.grad = self.grad.clone()
return ret.assign(self)
@@ -544,7 +544,7 @@ class Tensor(OpMixin):
"""
dtype, device = self.dtype if dtype is None else dtype, self.device if device is None else device
if isinstance(device, tuple) and (axis := self.uop.axis) is not None:
return Tensor.empty(self.shape, dtype=dtype, device=device[0], **kwargs).shard(device, axis)
return Tensor(Tensor.empty(self.uop.max_shard_shape, dtype=dtype, device=device, **kwargs).uop.multi(axis), device=device)
return Tensor.empty(self.shape, dtype=dtype, device=device, **kwargs)
@staticmethod