diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 456b2902f0..596e72aae6 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -832,7 +832,7 @@ class Tensor(OpMixin):
     if isinstance(self.device, tuple): return self._multi_like(Tensor.rand, **kwargs)
     return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
 
-  # ***** rng hlops *****
+  # ***** random functions *****
 
   def randn_like(self, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
     """
@@ -880,9 +880,8 @@ class Tensor(OpMixin):
     print(Tensor.randint(2, 3, low=5, high=10).numpy())
     ```
     """
-    if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
-    dtype = to_dtype(dtype)
-    if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
+    if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
+    if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
     return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
 
   @staticmethod
@@ -898,7 +897,7 @@ class Tensor(OpMixin):
     print(Tensor.normal(2, 3, mean=10, std=2).numpy())
     ```
     """
-    return ((std * Tensor.randn(*shape, **kwargs)) + mean).requires_grad_(requires_grad)
+    return (std * Tensor.randn(*shape, **kwargs) + mean).requires_grad_(requires_grad)
 
   @staticmethod
   def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
@@ -931,7 +930,6 @@ class Tensor(OpMixin):
     """
     return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
 
-  # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
   @staticmethod
   def glorot_uniform(*shape, **kwargs) -> Tensor:
     """
@@ -945,9 +943,9 @@ class Tensor(OpMixin):
     print(Tensor.glorot_uniform(2, 3).numpy())
     ```
     """
-    return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)
+    bound = (6 / (argfix(*shape)[0]+prod(argfix(*shape)[1:]))) ** 0.5
+    return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
 
-  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
   @staticmethod
   def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
     """
@@ -961,10 +959,9 @@ class Tensor(OpMixin):
     print(Tensor.kaiming_uniform(2, 3).numpy())
     ```
     """
-    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
+    bound = (6 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
     return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
 
-  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
   @staticmethod
   def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
     """
@@ -978,7 +975,7 @@ class Tensor(OpMixin):
     print(Tensor.kaiming_normal(2, 3).numpy())
     ```
     """
-    std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
+    std = (2 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
     return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
 
   @staticmethod
@@ -1059,7 +1056,7 @@ class Tensor(OpMixin):
       else: t.grad.assign(t.grad + g.to(t.grad.device))
     return self
 
-  # ***** movement low level ops *****
+  # ***** movement ops *****
 
   def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
   def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
@@ -1137,8 +1134,6 @@ class Tensor(OpMixin):
     if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
     raise NotImplementedError(f"{mode=} is not supported")
 
-  # ***** movement high level ops *****
-
   def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
     # wrap single index into a list
     if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]