clean up tensor random functions (#15648)

* clean up tensor random functions

* revert that
This commit is contained in:
chenyu
2026-04-08 09:44:37 -04:00
committed by GitHub
parent 1ebeb52e59
commit dae9dea903

View File

@@ -832,7 +832,7 @@ class Tensor(OpMixin):
if isinstance(self.device, tuple): return self._multi_like(Tensor.rand, **kwargs)
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
# ***** rng hlops *****
# ***** random functions *****
def randn_like(self, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
"""
@@ -880,9 +880,8 @@ class Tensor(OpMixin):
print(Tensor.randint(2, 3, low=5, high=10).numpy())
```
"""
if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
dtype = to_dtype(dtype)
if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
@staticmethod
@@ -898,7 +897,7 @@ class Tensor(OpMixin):
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
```
"""
return ((std * Tensor.randn(*shape, **kwargs)) + mean).requires_grad_(requires_grad)
return (std * Tensor.randn(*shape, **kwargs) + mean).requires_grad_(requires_grad)
@staticmethod
def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
@@ -931,7 +930,6 @@ class Tensor(OpMixin):
"""
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor:
"""
@@ -945,9 +943,9 @@ class Tensor(OpMixin):
print(Tensor.glorot_uniform(2, 3).numpy())
```
"""
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)
bound = (6 / (argfix(*shape)[0]+prod(argfix(*shape)[1:]))) ** 0.5
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
"""
@@ -961,10 +959,9 @@ class Tensor(OpMixin):
print(Tensor.kaiming_uniform(2, 3).numpy())
```
"""
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
bound = (6 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
"""
@@ -978,7 +975,7 @@ class Tensor(OpMixin):
print(Tensor.kaiming_normal(2, 3).numpy())
```
"""
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
std = (2 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
@staticmethod
@@ -1059,7 +1056,7 @@ class Tensor(OpMixin):
else: t.grad.assign(t.grad + g.to(t.grad.device))
return self
# ***** movement low level ops *****
# ***** movement ops *****
def _mop(self, op:Ops, arg) -> Tensor:
  """Dispatch a low-level movement op through the UOp layer.

  NOTE(review): `op` is presumably a movement-class `Ops` member and `arg`
  its op-specific argument (shape / axes / padding) — confirm against callers.
  """
  return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor:
  """Dispatch a reduce-style op over the given axes through the UOp layer.

  NOTE(review): `axis` is the tuple of dimensions the op operates over;
  semantics of `op` are defined by `UOp._rop` — confirm against callers.
  """
  return self._apply_uop(UOp._rop, op=op, axis=axis)
@@ -1137,8 +1134,6 @@ class Tensor(OpMixin):
if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
raise NotImplementedError(f"{mode=} is not supported")
# ***** movement high level ops *****
def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
# wrap single index into a list
if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]