From e1051d00d7d26646586ebe624c846ac0b55839cf Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 20 Nov 2025 20:46:48 -0800
Subject: [PATCH] multi like on full_like as well as rand_like (#13402)

* multi like on full_like as well as rand_like

* add test and fix bug

* mismatch, optim match

* one line
---
 test/test_multitensor.py | 10 ++++++++++
 tinygrad/nn/optim.py     |  6 +++---
 tinygrad/tensor.py       | 24 +++++++++++++-----------
 3 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index bca243c5e5..609144c4e1 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -765,6 +765,16 @@ class TestMultiTensor(unittest.TestCase):
     with self.assertRaises(RuntimeError):
       Tensor.rand_like(t, device=(d3, d4))
 
+  def test_full_like_on_shard(self, axis=None):
+    t = Tensor.empty((16, 16)).shard(devices_2, axis=axis)
+    t2 = Tensor.full_like(t, 1.0)
+    self.assertEqual(t.shape, t2.shape)
+    self.assertEqual(t.device, t2.device)
+    self.assertEqual(t.dtype, t2.dtype)
+    self.assertEqual(t.uop.axis, t2.uop.axis)
+    t2.realize()
+  def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
+
   def test_dropout_on_shard(self):
     with Tensor.train():
       X = Tensor.ones(256).to(devices_2)
diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index 53cb043acb..8cd5fe6984 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -2,7 +2,7 @@
 import itertools
 from tinygrad.helpers import dedup, flatten, getenv, unwrap, FUSE_OPTIM
 from tinygrad.tensor import Tensor
-from tinygrad.dtype import dtypes, least_upper_dtype
+from tinygrad.dtype import dtypes, least_upper_dtype, to_dtype
 
 class Optimizer:
   """
@@ -24,9 +24,9 @@ class Optimizer:
     if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))
 
   def _new_optim_param(self) -> list[Tensor]:
-    param_dtype = getenv("OPTIM_DTYPE", "float32")
+    param_dtype = to_dtype(getenv("OPTIM_DTYPE", "float32"))
     if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False).contiguous()]
-    return [Tensor.zeros(*t.shape, dtype=param_dtype, device=t.device, requires_grad=False).contiguous() for t in self.params]
+    return [Tensor.zeros_like(t, dtype=param_dtype, requires_grad=False).contiguous() for t in self.params]
 
   def zero_grad(self):
     """
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index f6ed8fad66..1c1b74e9ee 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -128,7 +128,7 @@ class Tensor(OpMixin):
 
     # create a UOp from the different types of inputs
     if isinstance(data, UOp):
-      assert _dtype is None or _dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
+      assert _dtype is None or _dtype==data.dtype, f"dtype doesn't match ({_dtype} vs {data.dtype}), and casting isn't supported"
       # if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
       if data.dtype==dtypes.index: data = _index_to_concrete_int(data)
       if data.op is Ops.BIND: # type: ignore # mypy type narrowing is bugged here
@@ -300,6 +300,7 @@ class Tensor(OpMixin):
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
     assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
     assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
+    assert not isinstance(self.device, tuple) or self.uop.axis == x.uop.axis, f"multi assign axis mismatch {self.uop.axis} != {x.uop.axis}"
     return self.replace(self._apply_uop(UOp.assign, x))
 
   def detach(self) -> Tensor:
@@ -739,6 +740,14 @@ class Tensor(OpMixin):
     t = (Tensor.arange(n, device=device).unsqueeze(-1) == Tensor.arange(m, device=device))
     return t.cast(dtype or dtypes.default_float).requires_grad_(requires_grad)
 
+  def _multi_like(self, fxn, *args, **kwargs) -> Tensor:
+    dtype = kwargs.pop("dtype", self.dtype)
+    if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
+    if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
+    sharded_shape = tuple(s//len(self.device) if a==self.uop.axis else s for a,s in enumerate(self.shape))
+    stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(sharded_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
+    return Tensor(UOp.multi(stacked, axis=self.uop.axis), device=self.device, dtype=dtype)
+
   def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
     """
     Creates a tensor with the same shape as `self`, filled with the given value.
@@ -752,6 +761,7 @@ class Tensor(OpMixin):
     print(Tensor.full_like(t, 42).numpy())
     ```
    """
+    if isinstance(self.device, tuple): return self._multi_like(Tensor.full, fill_value, **kwargs)
     return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
 
   def zeros_like(self, **kwargs) -> Tensor:
@@ -794,16 +804,8 @@ class Tensor(OpMixin):
     print(Tensor.rand_like(t).numpy())
     ```
     """
-    dtype = kwargs.pop("dtype", self.dtype)
-    if isinstance(self.device, tuple):
-      if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
-      if self.uop.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
-      contiguous = kwargs.pop("contiguous", True)
-      sharded_shape = tuple(s//len(self.device) if a==self.uop.axis else s for a,s in enumerate(self.shape))
-      rands = UOp(Ops.MSTACK, dtype=dtype,
-        src=tuple([Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).uop for d in self.device]))
-      return Tensor(UOp.multi(rands, axis=self.uop.axis), device=self.device, dtype=dtype, **kwargs)
-    return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)
+    if isinstance(self.device, tuple): return self._multi_like(Tensor.rand, **kwargs)
+    return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
 
 # ***** rng hlops *****
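
A brief usage sketch (not part of the patch) of what the consolidated `_multi_like` path enables: `full_like` and `rand_like` on a sharded tensor now keep the device tuple and shard axis without an explicit `device` kwarg. The two-device CPU tuple below is a placeholder; any device pair works the same way.

```python
# Usage sketch, assuming two local CPU devices (hypothetical names).
from tinygrad import Tensor

devices = ("CPU:0", "CPU:1")          # placeholder device tuple
t = Tensor.empty(16, 16).shard(devices, axis=0)

ones = Tensor.full_like(t, 1.0)       # sharded like t, same axis, no device kwarg needed
noise = Tensor.rand_like(t)           # one independent rand per shard, stacked via Ops.MSTACK
assert ones.device == t.device and ones.uop.axis == t.uop.axis
```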