mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
simpler rand_like (#7680)
This commit is contained in:
@@ -258,6 +258,11 @@ class TestTinygrad(unittest.TestCase):
|
||||
assert a.dtype == dtypes.default_int and b.dtype == dtypes.int8, "a.dtype should be int and b.dtype should be char"
|
||||
assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}"
|
||||
|
||||
def test_rand_like_device(self):
|
||||
a = Tensor.ones(3, 3, device="CLANG")
|
||||
b = Tensor.rand_like(a)
|
||||
self.assertEqual(b.device, a.device)
|
||||
|
||||
def test_ndim(self):
|
||||
assert Tensor(1).ndim == 0
|
||||
assert Tensor.randn(1).ndim == 1
|
||||
|
||||
@@ -709,16 +709,13 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
```
|
||||
"""
|
||||
dtype = kwargs.pop("dtype", self.dtype)
|
||||
device_arg, device = kwargs.get("device"), kwargs.pop("device", self.device)
|
||||
contiguous = kwargs.pop("contiguous", True)
|
||||
if isinstance(self.device, tuple) and isinstance(self.lazydata, MultiLazyBuffer):
|
||||
if device_arg is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
|
||||
if self.lazydata.axis is not None:
|
||||
rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata) \
|
||||
for lb in self.lazydata.lbs]
|
||||
return Tensor(MultiLazyBuffer(rands, self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
|
||||
return Tensor.rand(*self.shape, dtype=dtype, contiguous=contiguous, **kwargs).shard(self.device)
|
||||
return Tensor.rand(*self.shape, device=device, dtype=dtype, contiguous=contiguous, **kwargs)
|
||||
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
|
||||
if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
|
||||
contiguous = kwargs.pop("contiguous", True)
|
||||
rands = [Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.lbs]
|
||||
return Tensor(MultiLazyBuffer(cast(List[LazyBuffer], rands), self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
|
||||
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)
|
||||
|
||||
# ***** rng hlops *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user