mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 10:31:41 -05:00
rng hlops: add normal and kaiming_normal (#1378)
* add normal and kaiming_normal * make sure its float * add tests
This commit is contained in:
@@ -178,6 +178,9 @@ class Tensor:
|
||||
src = Tensor.rand(2, *shape, **kwargs)
|
||||
return src[0].mul(2*pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
|
||||
|
||||
@staticmethod
|
||||
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
|
||||
|
||||
@staticmethod
|
||||
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
|
||||
|
||||
@@ -194,6 +197,12 @@ class Tensor:
|
||||
bound = sqrt(3.0) * sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
|
||||
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
|
||||
|
||||
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
|
||||
@staticmethod
|
||||
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
|
||||
std = sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
|
||||
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
|
||||
|
||||
# ***** toposort and backward pass *****
|
||||
def deepwalk(self):
|
||||
def _deepwalk(node, visited, nodes):
|
||||
|
||||
Reference in New Issue
Block a user