From 24286c5593820fcd00dbd1986a94f6901575e400 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 20 Feb 2026 17:21:09 -0500 Subject: [PATCH] fix clone for multi (#14919) also update empty_like to make sure it's backed by buffers --- test/backend/test_multitensor.py | 14 ++++++++------ test/null/test_multitensor.py | 1 + tinygrad/tensor.py | 4 ++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/test/backend/test_multitensor.py b/test/backend/test_multitensor.py index 2f59adaa0b..823bedddea 100644 --- a/test/backend/test_multitensor.py +++ b/test/backend/test_multitensor.py @@ -840,13 +840,15 @@ class TestMultiTensor(unittest.TestCase): t.shard_(devices, axis=0).realize() assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.uop.src]) - @unittest.skip("this is unreliable on OSX") def test_clone(self): - t = Tensor.rand(16, 16).shard(devices_2, axis=None) - np.testing.assert_allclose(t.numpy(), t.clone().numpy()) - - t = Tensor.rand(16, 16).shard(devices_2, axis=0) - np.testing.assert_allclose(t.numpy(), t.clone().numpy()) + for axis in (None, 0): + t = Tensor.arange(16).reshape(4, 4).shard(devices_2, axis=axis).contiguous().realize() + t_clone = t.clone().realize() + self.assertEqual(t_clone.device, t.device) + self.assertEqual(t_clone.uop.axis, axis) + self.assertEqual(t_clone.tolist(), t.tolist()) + t_clone += 1 + self.assertNotEqual(t_clone.tolist(), t.tolist()) @unittest.skip("RANGEIFY doesn't support multi const folding") def test_multi_const_folding(self): diff --git a/test/null/test_multitensor.py b/test/null/test_multitensor.py index c8ddaecdd5..de3590dcfa 100644 --- a/test/null/test_multitensor.py +++ b/test/null/test_multitensor.py @@ -72,6 +72,7 @@ class TestMultiAxis(unittest.TestCase): self.assertEqual(e.shape, t.shape) self.assertEqual(e.device, t.device) self.assertEqual(e.uop.axis, 0) + self.assertTrue(e.uop.has_buffer_identity()) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 38eb3d116c..3c0e1fc160 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -404,7 +404,7 @@ class Tensor(OpMixin): """ Creates a clone of this tensor allocating a separate buffer for the data. """ - ret = Tensor.empty(self.shape, device=self.device, dtype=self.dtype) + ret = self.empty_like() if self.grad is not None: ret.grad = self.grad.clone() return ret.assign(self) @@ -544,7 +544,7 @@ class Tensor(OpMixin): """ dtype, device = self.dtype if dtype is None else dtype, self.device if device is None else device if isinstance(device, tuple) and (axis := self.uop.axis) is not None: - return Tensor.empty(self.shape, dtype=dtype, device=device[0], **kwargs).shard(device, axis) + return Tensor(Tensor.empty(self.uop.max_shard_shape, dtype=dtype, device=device, **kwargs).uop.multi(axis), device=device) return Tensor.empty(self.shape, dtype=dtype, device=device, **kwargs) @staticmethod