mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
clean up tensor random functions (#15648)
* clean up tensor random functions
* revert that
This commit is contained in:
@@ -832,7 +832,7 @@ class Tensor(OpMixin):
|
||||
if isinstance(self.device, tuple): return self._multi_like(Tensor.rand, **kwargs)
|
||||
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
|
||||
|
||||
# ***** rng hlops *****
|
||||
# ***** random functions *****
|
||||
|
||||
def randn_like(self, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
|
||||
"""
|
||||
@@ -880,9 +880,8 @@ class Tensor(OpMixin):
|
||||
print(Tensor.randint(2, 3, low=5, high=10).numpy())
|
||||
```
|
||||
"""
|
||||
if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
|
||||
dtype = to_dtype(dtype)
|
||||
if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
|
||||
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
|
||||
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
|
||||
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
@@ -898,7 +897,7 @@ class Tensor(OpMixin):
|
||||
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
|
||||
```
|
||||
"""
|
||||
return ((std * Tensor.randn(*shape, **kwargs)) + mean).requires_grad_(requires_grad)
|
||||
return (std * Tensor.randn(*shape, **kwargs) + mean).requires_grad_(requires_grad)
|
||||
|
||||
@staticmethod
|
||||
def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
|
||||
@@ -931,7 +930,6 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
|
||||
|
||||
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
|
||||
@staticmethod
|
||||
def glorot_uniform(*shape, **kwargs) -> Tensor:
|
||||
"""
|
||||
@@ -945,9 +943,9 @@ class Tensor(OpMixin):
|
||||
print(Tensor.glorot_uniform(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)
|
||||
bound = (6 / (argfix(*shape)[0]+prod(argfix(*shape)[1:]))) ** 0.5
|
||||
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
|
||||
|
||||
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
|
||||
@staticmethod
|
||||
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
|
||||
"""
|
||||
@@ -961,10 +959,9 @@ class Tensor(OpMixin):
|
||||
print(Tensor.kaiming_uniform(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
|
||||
bound = (6 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
|
||||
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
|
||||
|
||||
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
|
||||
@staticmethod
|
||||
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
|
||||
"""
|
||||
@@ -978,7 +975,7 @@ class Tensor(OpMixin):
|
||||
print(Tensor.kaiming_normal(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
|
||||
std = (2 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
|
||||
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
@@ -1059,7 +1056,7 @@ class Tensor(OpMixin):
|
||||
else: t.grad.assign(t.grad + g.to(t.grad.device))
|
||||
return self
|
||||
|
||||
# ***** movement low level ops *****
|
||||
# ***** movement ops *****
|
||||
|
||||
def _mop(self, op:Ops, arg) -> Tensor:
  """Run `UOp._mop` through `self._apply_uop`, passing `op` inside `extra_args` and `arg` as a keyword."""
  return self._apply_uop(UOp._mop, arg=arg, extra_args=(op,))
|
||||
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor:
  """Run `UOp._rop` through `self._apply_uop`, forwarding `op` and `axis` as keyword arguments."""
  return self._apply_uop(UOp._rop, axis=axis, op=op)
|
||||
@@ -1137,8 +1134,6 @@ class Tensor(OpMixin):
|
||||
if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
|
||||
raise NotImplementedError(f"{mode=} is not supported")
|
||||
|
||||
# ***** movement high level ops *****
|
||||
|
||||
def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
|
||||
# wrap single index into a list
|
||||
if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
|
||||
|
||||
Reference in New Issue
Block a user