From d1dfd598a2d96cfcb9cbed3d6235508852e63880 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 13 Nov 2024 10:24:40 -0500 Subject: [PATCH] assert specifying device to rand_like a multi tensor (#7678) * assert specifying device to rand_like a multi tensor raise RuntimeError instead of dropping it silently * fix that --- test/test_multitensor.py | 14 ++++++++------ tinygrad/tensor.py | 6 +++--- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 6c15186ac3..5b1704c364 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -598,13 +598,15 @@ class TestMultiTensor(unittest.TestCase): self.assertEqual(t2.dtype, dtypes.float32) def test_rand_like_arg_device(self): + # axis=None + t = Tensor.empty((16, 16)).shard((d1, d2), axis=None) + with self.assertRaises(RuntimeError): + Tensor.rand_like(t, device=(d3, d4)) + + # axis=1 t = Tensor.empty((16, 16)).shard((d1, d2), axis=1) - t2 = Tensor.rand_like(t, device=(d3, d4)) - self.assertEqual(t.shape, t2.shape) - # TODO: should no silently drop device - self.assertEqual(t.device, t2.device) - self.assertEqual(t.dtype, t2.dtype) - self.assertEqual(t.lazydata.axis, t2.lazydata.axis) + with self.assertRaises(RuntimeError): + Tensor.rand_like(t, device=(d3, d4)) def test_dropout_on_shard(self): with Tensor.train(): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 05ab20fb4f..530aba92da 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -709,10 +709,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ dtype = kwargs.pop("dtype", self.dtype) - device = kwargs.pop("device", self.device) + device_arg, device = kwargs.get("device"), kwargs.pop("device", self.device) contiguous = kwargs.pop("contiguous", True) - if isinstance(self.device, tuple): - assert isinstance(self.lazydata, MultiLazyBuffer) + if isinstance(self.device, tuple) and isinstance(self.lazydata, MultiLazyBuffer): + if device_arg is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor") if self.lazydata.axis is not None: rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata) \ for lb in self.lazydata.lbs]