minor rand_like change [run_process_replay] (#6848)

This commit is contained in:
chenyu
2024-10-02 07:27:51 -04:00
committed by GitHub
parent 7214450c23
commit 08850da026

View File

@@ -663,9 +663,8 @@ class Tensor:
if isinstance(self.device, tuple):
assert isinstance(self.lazydata, MultiLazyBuffer)
if self.lazydata.axis is not None:
rands = [Tensor.rand(*lb.shape, device=lb.device, dtype=dtype) for lb in self.lazydata.lbs]
return Tensor(MultiLazyBuffer([cast(LazyBuffer, r.lazydata) for r in rands], self.lazydata.axis, None),
device=self.device, dtype=dtype, **kwargs)
rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype).lazydata) for lb in self.lazydata.lbs]
return Tensor(MultiLazyBuffer(rands, self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
return Tensor.rand(*self.shape, device=device, dtype=dtype, **kwargs)