mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
add rand_like test case with device specified (#7663)
in single device or copied multi case, device is applied. but for sharded case the device is silently ignored now. maybe similar to rand we just don't allow tuple device in rand_like
This commit is contained in:
@@ -569,24 +569,42 @@ class TestMultiTensor(unittest.TestCase):
|
||||
def test_rand_like_on_shard(self):
|
||||
t = Tensor.empty((16, 16)).shard(devices_2)
|
||||
t2 = Tensor.rand_like(t)
|
||||
assert t2.shape == t.shape
|
||||
assert t2.device == t.device
|
||||
assert t2.lazydata.axis == t.lazydata.axis
|
||||
self.assertEqual(t.shape, t2.shape)
|
||||
self.assertEqual(t.device, t2.device)
|
||||
self.assertEqual(t.dtype, t2.dtype)
|
||||
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
|
||||
|
||||
def test_rand_like_uneven_shard(self):
|
||||
t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1, splits=(14, 7, 21))
|
||||
t2 = Tensor.rand_like(t)
|
||||
assert t2.shape == t.shape
|
||||
assert t2.device == t.device
|
||||
assert t2.lazydata.axis == t.lazydata.axis
|
||||
self.assertEqual(t.shape, t2.shape)
|
||||
self.assertEqual(t.device, t2.device)
|
||||
self.assertEqual(t.dtype, t2.dtype)
|
||||
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
|
||||
assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.lbs, t2.lazydata.lbs))
|
||||
|
||||
def test_rand_like_none_shard(self):
|
||||
t = Tensor.empty((16, 16)).shard(devices_2)
|
||||
t2 = Tensor.rand_like(t)
|
||||
assert t2.shape == t.shape
|
||||
assert t2.device == t.device
|
||||
assert t2.lazydata.axis == t.lazydata.axis
|
||||
self.assertEqual(t.shape, t2.shape)
|
||||
self.assertEqual(t.device, t2.device)
|
||||
self.assertEqual(t.dtype, t2.dtype)
|
||||
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
|
||||
|
||||
def test_rand_like_arg_dtype(self):
|
||||
t = Tensor.empty((16, 16), dtype=dtypes.int32).shard(devices_2, axis=1)
|
||||
t2 = Tensor.rand_like(t, dtype=dtypes.float32)
|
||||
self.assertEqual(t.dtype, dtypes.int32)
|
||||
self.assertEqual(t2.dtype, dtypes.float32)
|
||||
|
||||
def test_rand_like_arg_device(self):
|
||||
t = Tensor.empty((16, 16)).shard((d1, d2), axis=1)
|
||||
t2 = Tensor.rand_like(t, device=(d3, d4))
|
||||
self.assertEqual(t.shape, t2.shape)
|
||||
# TODO: should no silently drop device
|
||||
self.assertEqual(t.device, t2.device)
|
||||
self.assertEqual(t.dtype, t2.dtype)
|
||||
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
|
||||
|
||||
def test_dropout_on_shard(self):
|
||||
with Tensor.train():
|
||||
|
||||
Reference in New Issue
Block a user