multi like on full_like as well as rand_like (#13402)
* multi like on full_like as well as rand_like
* add test and fix bug
* mismatch, optim match
* one line
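For context, a minimal sketch of the behavior this change enables: the `*_like` constructors on a sharded tensor now keep the device tuple and shard axis instead of requiring an explicit device. The device names below are stand-ins for any two available backends; the rest uses only APIs touched by this diff.

```python
# Hedged sketch, assuming a tinygrad install; "CPU:0"/"CPU:1" are stand-in device names.
from tinygrad import Tensor

devices = ("CPU:0", "CPU:1")
t = Tensor.empty(16, 16).shard(devices, axis=0)

ones = Tensor.full_like(t, 1.0)    # keeps the device tuple and the shard axis
noise = Tensor.rand_like(t)        # same for random init
print(ones.device, ones.uop.axis)  # device tuple and axis preserved from t
```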
@@ -765,6 +765,16 @@ class TestMultiTensor(unittest.TestCase):
     with self.assertRaises(RuntimeError):
       Tensor.rand_like(t, device=(d3, d4))
 
+  def test_full_like_on_shard(self, axis=None):
+    t = Tensor.empty((16, 16)).shard(devices_2, axis=axis)
+    t2 = Tensor.full_like(t, 1.0)
+    self.assertEqual(t.shape, t2.shape)
+    self.assertEqual(t.device, t2.device)
+    self.assertEqual(t.dtype, t2.dtype)
+    self.assertEqual(t.uop.axis, t2.uop.axis)
+    t2.realize()
+  def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
+
   def test_dropout_on_shard(self):
     with Tensor.train():
       X = Tensor.ones(256).to(devices_2)
@@ -2,7 +2,7 @@
 import itertools
 from tinygrad.helpers import dedup, flatten, getenv, unwrap, FUSE_OPTIM
 from tinygrad.tensor import Tensor
-from tinygrad.dtype import dtypes, least_upper_dtype
+from tinygrad.dtype import dtypes, least_upper_dtype, to_dtype
 
 class Optimizer:
   """
@@ -24,9 +24,9 @@ class Optimizer:
     if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))
 
   def _new_optim_param(self) -> list[Tensor]:
-    param_dtype = getenv("OPTIM_DTYPE", "float32")
+    param_dtype = to_dtype(getenv("OPTIM_DTYPE", "float32"))
     if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False).contiguous()]
-    return [Tensor.zeros(*t.shape, dtype=param_dtype, device=t.device, requires_grad=False).contiguous() for t in self.params]
+    return [Tensor.zeros_like(t, dtype=param_dtype, requires_grad=False).contiguous() for t in self.params]
 
   def zero_grad(self):
     """
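The switch to `Tensor.zeros_like` is what makes the optimizer side ("optim match" in the commit message) follow the new multi path: state buffers inherit the parameter's device, including a device tuple and shard axis. A hedged sketch of that effect; the device names are stand-ins.

```python
# Sketch: zeros_like builds optimizer-style state that inherits the
# parameter's placement, including the shard axis on a multi-device tensor.
from tinygrad import Tensor

devices = ("CPU:0", "CPU:1")
p = Tensor.empty(8, 8).shard(devices, axis=0)

state = Tensor.zeros_like(p, requires_grad=False).contiguous()
print(state.device == p.device, state.uop.axis == p.uop.axis)  # True True
```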
@@ -128,7 +128,7 @@ class Tensor(OpMixin):
 
     # create a UOp from the different types of inputs
     if isinstance(data, UOp):
-      assert _dtype is None or _dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
+      assert _dtype is None or _dtype==data.dtype, f"dtype doesn't match ({_dtype} vs {data.dtype}), and casting isn't supported"
       # if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
       if data.dtype==dtypes.index: data = _index_to_concrete_int(data)
       if data.op is Ops.BIND: # type: ignore # mypy type narrowing is bugged here
@@ -300,6 +300,7 @@ class Tensor(OpMixin):
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
     assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
     assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
+    assert not isinstance(self.device, tuple) or self.uop.axis == x.uop.axis, f"multi assign axis mismatch {self.uop.axis} != {x.uop.axis}"
     return self.replace(self._apply_uop(UOp.assign, x))
 
   def detach(self) -> Tensor:
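The added assert (the "mismatch" fix in the commit message) turns a silent shard-axis mismatch on assign into an immediate error. A hedged sketch of the case it catches; device names are stand-ins, and this assumes nothing earlier in `assign` intercepts the call.

```python
# Sketch: assigning a tensor sharded on a different axis now fails loudly.
from tinygrad import Tensor

devices = ("CPU:0", "CPU:1")
dst = Tensor.zeros(4, 4).contiguous().shard(devices, axis=0).realize()
src = Tensor.ones(4, 4).contiguous().shard(devices, axis=1)

try:
  dst.assign(src)
except AssertionError as e:
  print(e)  # multi assign axis mismatch 0 != 1
```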
@@ -739,6 +740,14 @@ class Tensor(OpMixin):
     t = (Tensor.arange(n, device=device).unsqueeze(-1) == Tensor.arange(m, device=device))
     return t.cast(dtype or dtypes.default_float).requires_grad_(requires_grad)
 
+  def _multi_like(self, fxn, *args, **kwargs) -> Tensor:
+    dtype = kwargs.pop("dtype", self.dtype)
+    if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
+    if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
+    sharded_shape = tuple(s//len(self.device) if a==self.uop.axis else s for a,s in enumerate(self.shape))
+    stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(sharded_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
+    return Tensor(UOp.multi(stacked, axis=self.uop.axis), device=self.device, dtype=dtype)
+
   def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
     """
     Creates a tensor with the same shape as `self`, filled with the given value.
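A hedged sketch of the three cases `_multi_like` handles for the `*_like` constructors on a multi-device tensor: no shard axis, a concrete shard axis, and an explicit `device=` (which is rejected). Device names are stand-ins.

```python
# Sketch of _multi_like's behavior: axis=None reshards a full-shape copy,
# a concrete axis builds one per-device shard and stacks them, and device=
# on a multi tensor raises.
from tinygrad import Tensor

devices = ("CPU:0", "CPU:1")

a = Tensor.empty(4, 4).shard(devices, axis=None)
print(Tensor.full_like(a, 3.0).shape)     # (4, 4), replicated on both devices

b = Tensor.empty(4, 4).shard(devices, axis=0)
print(Tensor.full_like(b, 3.0).uop.axis)  # 0, each device holds a (2, 4) slice

try:
  Tensor.rand_like(b, device="CPU:0")
except RuntimeError as e:
  print(e)  # cannot specify `device` on `*_like` of a multi device tensor
```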
@@ -752,6 +761,7 @@ class Tensor(OpMixin):
     print(Tensor.full_like(t, 42).numpy())
     ```
     """
+    if isinstance(self.device, tuple): return self._multi_like(Tensor.full, fill_value, **kwargs)
     return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
 
   def zeros_like(self, **kwargs) -> Tensor:
@@ -794,16 +804,8 @@ class Tensor(OpMixin):
     print(Tensor.rand_like(t).numpy())
     ```
     """
-    dtype = kwargs.pop("dtype", self.dtype)
-    if isinstance(self.device, tuple):
-      if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
-      if self.uop.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
-      contiguous = kwargs.pop("contiguous", True)
-      sharded_shape = tuple(s//len(self.device) if a==self.uop.axis else s for a,s in enumerate(self.shape))
-      rands = UOp(Ops.MSTACK, dtype=dtype,
-                  src=tuple([Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).uop for d in self.device]))
-      return Tensor(UOp.multi(rands, axis=self.uop.axis), device=self.device, dtype=dtype, **kwargs)
-    return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)
+    if isinstance(self.device, tuple): return self._multi_like(Tensor.rand, **kwargs)
+    return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
 
   # ***** rng hlops *****
 