mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix clone for multi (#14919)
also update empty_like to make sure it's backed by buffers
This commit is contained in:
@@ -840,13 +840,15 @@ class TestMultiTensor(unittest.TestCase):
|
||||
t.shard_(devices, axis=0).realize()
|
||||
assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.uop.src])
|
||||
|
||||
@unittest.skip("this is unreliable on OSX")
|
||||
def test_clone(self):
|
||||
t = Tensor.rand(16, 16).shard(devices_2, axis=None)
|
||||
np.testing.assert_allclose(t.numpy(), t.clone().numpy())
|
||||
|
||||
t = Tensor.rand(16, 16).shard(devices_2, axis=0)
|
||||
np.testing.assert_allclose(t.numpy(), t.clone().numpy())
|
||||
for axis in (None, 0):
|
||||
t = Tensor.arange(16).reshape(4, 4).shard(devices_2, axis=axis).contiguous().realize()
|
||||
t_clone = t.clone().realize()
|
||||
self.assertEqual(t_clone.device, t.device)
|
||||
self.assertEqual(t_clone.uop.axis, axis)
|
||||
self.assertEqual(t_clone.tolist(), t.tolist())
|
||||
t_clone += 1
|
||||
self.assertNotEqual(t_clone.tolist(), t.tolist())
|
||||
|
||||
@unittest.skip("RANGEIFY doesn't support multi const folding")
|
||||
def test_multi_const_folding(self):
|
||||
|
||||
@@ -72,6 +72,7 @@ class TestMultiAxis(unittest.TestCase):
|
||||
self.assertEqual(e.shape, t.shape)
|
||||
self.assertEqual(e.device, t.device)
|
||||
self.assertEqual(e.uop.axis, 0)
|
||||
self.assertTrue(e.uop.has_buffer_identity())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -404,7 +404,7 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
Creates a clone of this tensor allocating a separate buffer for the data.
|
||||
"""
|
||||
ret = Tensor.empty(self.shape, device=self.device, dtype=self.dtype)
|
||||
ret = self.empty_like()
|
||||
if self.grad is not None: ret.grad = self.grad.clone()
|
||||
return ret.assign(self)
|
||||
|
||||
@@ -544,7 +544,7 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
dtype, device = self.dtype if dtype is None else dtype, self.device if device is None else device
|
||||
if isinstance(device, tuple) and (axis := self.uop.axis) is not None:
|
||||
return Tensor.empty(self.shape, dtype=dtype, device=device[0], **kwargs).shard(device, axis)
|
||||
return Tensor(Tensor.empty(self.uop.max_shard_shape, dtype=dtype, device=device, **kwargs).uop.multi(axis), device=device)
|
||||
return Tensor.empty(self.shape, dtype=dtype, device=device, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user