diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 771a6cd5f3..fe947394fc 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -663,9 +663,8 @@ class Tensor: if isinstance(self.device, tuple): assert isinstance(self.lazydata, MultiLazyBuffer) if self.lazydata.axis is not None: - rands = [Tensor.rand(*lb.shape, device=lb.device, dtype=dtype) for lb in self.lazydata.lbs] - return Tensor(MultiLazyBuffer([cast(LazyBuffer, r.lazydata) for r in rands], self.lazydata.axis, None), - device=self.device, dtype=dtype, **kwargs) + rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype).lazydata) for lb in self.lazydata.lbs] + return Tensor(MultiLazyBuffer(rands, self.lazydata.axis), device=self.device, dtype=dtype, **kwargs) return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device) return Tensor.rand(*self.shape, device=device, dtype=dtype, **kwargs)