rng hlops: add normal and kaiming_normal (#1378)

* add normal and kaiming_normal

* make sure its float

* add tests
This commit is contained in:
JaSpa99
2023-07-31 19:37:02 +02:00
committed by GitHub
parent 37fa7e96fb
commit 5ab12059da
3 changed files with 21 additions and 1 deletions

View File

@@ -178,6 +178,9 @@ class Tensor:
src = Tensor.rand(2, *shape, **kwargs)
return src[0].mul(2*pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
@staticmethod
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
@@ -194,6 +197,12 @@ class Tensor:
bound = sqrt(3.0) * sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
std = sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
# ***** toposort and backward pass *****
def deepwalk(self):
def _deepwalk(node, visited, nodes):