multi like on full_like as well as rand_like (#13402)

* multi like on full_like as well as rand_like

* add test and fix bug

* mismatch, optim match

* one line
Author: George Hotz
Date: 2025-11-20 20:46:48 -08:00
Committed by: GitHub
Parent: fa3def2f12
Commit: e1051d00d7
3 changed files with 26 additions and 14 deletions
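For context, a minimal usage sketch of what the change enables: `full_like` (and the `*_like` constructors that call it, e.g. `zeros_like`) now takes the same multi-device path that `rand_like` already had, so the result keeps the input's devices and shard axis. Device names below are illustrative; any two devices work.

from tinygrad import Tensor

devices = ("CPU", "CPU:1")                        # illustrative pair of devices
t = Tensor.empty(16, 16).shard(devices, axis=0)   # rows split across the two devices

full = Tensor.full_like(t, 1.0)    # stays sharded like t, no device= needed
rand = Tensor.rand_like(t)         # behavior unchanged, now routed through the shared helper
print(full.device, full.uop.axis)  # -> ('CPU', 'CPU:1') 0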

@@ -765,6 +765,16 @@ class TestMultiTensor(unittest.TestCase):
     with self.assertRaises(RuntimeError):
       Tensor.rand_like(t, device=(d3, d4))
 
+  def test_full_like_on_shard(self, axis=None):
+    t = Tensor.empty((16, 16)).shard(devices_2, axis=axis)
+    t2 = Tensor.full_like(t, 1.0)
+    self.assertEqual(t.shape, t2.shape)
+    self.assertEqual(t.device, t2.device)
+    self.assertEqual(t.dtype, t2.dtype)
+    self.assertEqual(t.uop.axis, t2.uop.axis)
+    t2.realize()
+  def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
+
   def test_dropout_on_shard(self):
     with Tensor.train():
       X = Tensor.ones(256).to(devices_2)

@@ -2,7 +2,7 @@
 import itertools
 from tinygrad.helpers import dedup, flatten, getenv, unwrap, FUSE_OPTIM
 from tinygrad.tensor import Tensor
-from tinygrad.dtype import dtypes, least_upper_dtype
+from tinygrad.dtype import dtypes, least_upper_dtype, to_dtype
 
 class Optimizer:
   """
@@ -24,9 +24,9 @@ class Optimizer:
     if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))
 
   def _new_optim_param(self) -> list[Tensor]:
-    param_dtype = getenv("OPTIM_DTYPE", "float32")
+    param_dtype = to_dtype(getenv("OPTIM_DTYPE", "float32"))
     if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False).contiguous()]
-    return [Tensor.zeros(*t.shape, dtype=param_dtype, device=t.device, requires_grad=False).contiguous() for t in self.params]
+    return [Tensor.zeros_like(t, dtype=param_dtype, requires_grad=False).contiguous() for t in self.params]
 
   def zero_grad(self):
     """

@@ -128,7 +128,7 @@ class Tensor(OpMixin):
 
     # create a UOp from the different types of inputs
     if isinstance(data, UOp):
-      assert _dtype is None or _dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
+      assert _dtype is None or _dtype==data.dtype, f"dtype doesn't match ({_dtype} vs {data.dtype}), and casting isn't supported"
       # if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
       if data.dtype==dtypes.index: data = _index_to_concrete_int(data)
       if data.op is Ops.BIND: # type: ignore # mypy type narrowing is bugged here
@@ -300,6 +300,7 @@ class Tensor(OpMixin):
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
     assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
     assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
+    assert not isinstance(self.device, tuple) or self.uop.axis == x.uop.axis, f"multi assign axis mismatch {self.uop.axis} != {x.uop.axis}"
     return self.replace(self._apply_uop(UOp.assign, x))
 
   def detach(self) -> Tensor:
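The added assert turns a cross-axis assign between multi tensors into an immediate, descriptive failure instead of a silently mismatched graph. A small sketch of the failure mode (devices illustrative):

from tinygrad import Tensor

devices = ("CPU", "CPU:1")
a = Tensor.zeros(16, 16).shard(devices, axis=0).contiguous().realize()
b = Tensor.ones(16, 16).shard(devices, axis=1).contiguous().realize()
a.assign(b)   # AssertionError: multi assign axis mismatch 0 != 1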
@@ -739,6 +740,14 @@ class Tensor(OpMixin):
     t = (Tensor.arange(n, device=device).unsqueeze(-1) == Tensor.arange(m, device=device))
     return t.cast(dtype or dtypes.default_float).requires_grad_(requires_grad)
 
+  def _multi_like(self, fxn, *args, **kwargs) -> Tensor:
+    dtype = kwargs.pop("dtype", self.dtype)
+    if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
+    if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
+    sharded_shape = tuple(s//len(self.device) if a==self.uop.axis else s for a,s in enumerate(self.shape))
+    stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(sharded_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
+    return Tensor(UOp.multi(stacked, axis=self.uop.axis), device=self.device, dtype=dtype)
+
   def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
     """
     Creates a tensor with the same shape as `self`, filled with the given value.
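`_multi_like` creates one per-device tensor of the per-shard shape and MSTACKs them back into a multi tensor on the original axis; for a replicated (axis=None) multi tensor it builds the full shape once and re-shards it. The per-shard shape arithmetic, pulled out as a standalone sketch:

shape, axis, ndev = (16, 16), 0, 2   # global shape, shard axis, device count
sharded_shape = tuple(s//ndev if a==axis else s for a,s in enumerate(shape))
print(sharded_shape)                 # -> (8, 16): only the sharded axis is divided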
@@ -752,6 +761,7 @@ class Tensor(OpMixin):
     print(Tensor.full_like(t, 42).numpy())
     ```
     """
+    if isinstance(self.device, tuple): return self._multi_like(Tensor.full, fill_value, **kwargs)
     return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
 
   def zeros_like(self, **kwargs) -> Tensor:
@@ -794,16 +804,8 @@ class Tensor(OpMixin):
     print(Tensor.rand_like(t).numpy())
     ```
     """
-    dtype = kwargs.pop("dtype", self.dtype)
-    if isinstance(self.device, tuple):
-      if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
-      if self.uop.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
-      contiguous = kwargs.pop("contiguous", True)
-      sharded_shape = tuple(s//len(self.device) if a==self.uop.axis else s for a,s in enumerate(self.shape))
-      rands = UOp(Ops.MSTACK, dtype=dtype,
-        src=tuple([Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).uop for d in self.device]))
-      return Tensor(UOp.multi(rands, axis=self.uop.axis), device=self.device, dtype=dtype, **kwargs)
-    return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)
+    if isinstance(self.device, tuple): return self._multi_like(Tensor.rand, **kwargs)
+    return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
 
   # ***** rng hlops *****