mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
assert specifying device to rand_like a multi tensor (#7678)
* assert specifying device to rand_like a multi tensor raise RuntimeError instead of dropping it silently * fix that
This commit is contained in:
@@ -598,13 +598,15 @@ class TestMultiTensor(unittest.TestCase):
|
||||
self.assertEqual(t2.dtype, dtypes.float32)
|
||||
|
||||
def test_rand_like_arg_device(self):
|
||||
# axis=None
|
||||
t = Tensor.empty((16, 16)).shard((d1, d2), axis=None)
|
||||
with self.assertRaises(RuntimeError):
|
||||
Tensor.rand_like(t, device=(d3, d4))
|
||||
|
||||
# axis=1
|
||||
t = Tensor.empty((16, 16)).shard((d1, d2), axis=1)
|
||||
t2 = Tensor.rand_like(t, device=(d3, d4))
|
||||
self.assertEqual(t.shape, t2.shape)
|
||||
# TODO: should no silently drop device
|
||||
self.assertEqual(t.device, t2.device)
|
||||
self.assertEqual(t.dtype, t2.dtype)
|
||||
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
|
||||
with self.assertRaises(RuntimeError):
|
||||
Tensor.rand_like(t, device=(d3, d4))
|
||||
|
||||
def test_dropout_on_shard(self):
|
||||
with Tensor.train():
|
||||
|
||||
@@ -709,10 +709,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
```
|
||||
"""
|
||||
dtype = kwargs.pop("dtype", self.dtype)
|
||||
device = kwargs.pop("device", self.device)
|
||||
device_arg, device = kwargs.get("device"), kwargs.pop("device", self.device)
|
||||
contiguous = kwargs.pop("contiguous", True)
|
||||
if isinstance(self.device, tuple):
|
||||
assert isinstance(self.lazydata, MultiLazyBuffer)
|
||||
if isinstance(self.device, tuple) and isinstance(self.lazydata, MultiLazyBuffer):
|
||||
if device_arg is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
|
||||
if self.lazydata.axis is not None:
|
||||
rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata) \
|
||||
for lb in self.lazydata.lbs]
|
||||
|
||||
Reference in New Issue
Block a user