assert specifying device to rand_like a multi tensor (#7678)

* assert specifying device to rand_like a multi tensor

raise RuntimeError instead of dropping it silently

* fix that
This commit is contained in:
chenyu
2024-11-13 10:24:40 -05:00
committed by GitHub
parent 51432bfbff
commit d1dfd598a2
2 changed files with 11 additions and 9 deletions

View File

@@ -598,13 +598,15 @@ class TestMultiTensor(unittest.TestCase):
self.assertEqual(t2.dtype, dtypes.float32)
def test_rand_like_arg_device(self):
# axis=None
t = Tensor.empty((16, 16)).shard((d1, d2), axis=None)
with self.assertRaises(RuntimeError):
Tensor.rand_like(t, device=(d3, d4))
# axis=1
t = Tensor.empty((16, 16)).shard((d1, d2), axis=1)
t2 = Tensor.rand_like(t, device=(d3, d4))
self.assertEqual(t.shape, t2.shape)
# TODO: should no silently drop device
self.assertEqual(t.device, t2.device)
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
with self.assertRaises(RuntimeError):
Tensor.rand_like(t, device=(d3, d4))
def test_dropout_on_shard(self):
with Tensor.train():