add rand_like test case with device specified (#7663)

In the single-device or copied-multi case, the device argument is applied, but for the sharded case the device is currently silently ignored. Maybe, similar to rand, we should just not allow a tuple device in rand_like.
This commit is contained in:
chenyu
2024-11-13 09:32:55 -05:00
committed by GitHub
parent 23363dee55
commit 51432bfbff

View File

@@ -569,24 +569,42 @@ class TestMultiTensor(unittest.TestCase):
def test_rand_like_on_shard(self):
  """rand_like on an evenly sharded tensor preserves shape, device, dtype, and shard axis."""
  t = Tensor.empty((16, 16)).shard(devices_2)
  t2 = Tensor.rand_like(t)
  self.assertEqual(t.shape, t2.shape)
  self.assertEqual(t.device, t2.device)
  self.assertEqual(t.dtype, t2.dtype)
  self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
def test_rand_like_uneven_shard(self):
  """rand_like on an unevenly sharded tensor preserves metadata and each shard's split shape."""
  t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1, splits=(14, 7, 21))
  t2 = Tensor.rand_like(t)
  self.assertEqual(t.shape, t2.shape)
  self.assertEqual(t.device, t2.device)
  self.assertEqual(t.dtype, t2.dtype)
  self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
  # the uneven splits (14, 7, 21) must carry over shard-by-shard
  assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.lbs, t2.lazydata.lbs))
def test_rand_like_none_shard(self):
  """rand_like on a tensor sharded with no axis (copied to all devices) preserves metadata."""
  t = Tensor.empty((16, 16)).shard(devices_2)
  t2 = Tensor.rand_like(t)
  self.assertEqual(t.shape, t2.shape)
  self.assertEqual(t.device, t2.device)
  self.assertEqual(t.dtype, t2.dtype)
  self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
def test_rand_like_arg_dtype(self):
  """An explicit dtype argument to rand_like overrides the source tensor's dtype."""
  src = Tensor.empty((16, 16), dtype=dtypes.int32).shard(devices_2, axis=1)
  sampled = Tensor.rand_like(src, dtype=dtypes.float32)
  # the source keeps its dtype; the sampled tensor takes the requested one
  self.assertEqual(src.dtype, dtypes.int32)
  self.assertEqual(sampled.dtype, dtypes.float32)
def test_rand_like_arg_device(self):
"""rand_like with an explicit device on a sharded tensor.

The device argument is currently silently dropped for sharded tensors,
so the result is asserted to stay on the source tensor's devices.
"""
t = Tensor.empty((16, 16)).shard((d1, d2), axis=1)
t2 = Tensor.rand_like(t, device=(d3, d4))
self.assertEqual(t.shape, t2.shape)
# TODO: should not silently drop device
self.assertEqual(t.device, t2.device)
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
def test_dropout_on_shard(self):
with Tensor.train():