import math

from tinygrad import Tensor, nn
from tinygrad.helpers import prod, argfix


def rand_truncn(*shape, dtype=None, truncstds=2, **kwargs) -> Tensor:
    """Sample a truncated standard normal tensor by rejection sampling.

    For every output element a fixed number of candidates is drawn from
    ``Tensor.randn`` and the first candidate whose absolute value is at most
    ``truncstds`` is kept. If none of the candidates qualifies, the element
    is set to 0.

    Args:
        *shape: Output shape, given either as separate ints or as a single
            tuple/list (normalized with ``argfix``).
        dtype: Forwarded to ``Tensor.randn``.
        truncstds: Truncation bound; kept samples satisfy ``|x| <= truncstds``.
        **kwargs: Forwarded to ``Tensor.randn``.

    Returns:
        A tensor of shape ``shape``.
    """
    # Accept both rand_truncn(2, 3) and rand_truncn((2, 3)), matching he_normal.
    shape = argfix(*shape)
    num_candidates = 8

    # One extra trailing axis holds the candidates for each output element.
    samples = Tensor.randn(*shape, num_candidates, dtype=dtype, **kwargs)

    # Candidate index along the last axis, broadcast to the sample shape.
    # NOTE(review): arange is created without the device that may be passed in
    # kwargs -- confirm this is fine when a non-default device is requested.
    index_shape = (1,) * len(shape) + (num_candidates,)
    index = Tensor.arange(num_candidates).reshape(index_shape).expand(samples.shape)

    # Index of the first in-range candidate; equals num_candidates when no
    # candidate is in range, so nothing matches below and the result is 0.
    in_range = samples.abs() <= truncstds
    first_good = in_range.where(index, num_candidates).min(axis=-1, keepdim=True)

    # Select that single candidate per element and collapse the candidate axis.
    return (index == first_good).where(samples, 0).sum(axis=-1)


# https://github.com/keras-team/keras/blob/v2.15.0/keras/initializers/initializers.py#L1026-L1065
def he_normal(*shape, a: float = 0.0, **kwargs) -> Tensor:
    """He-normal initializer built on a truncated normal distribution.

    The scale is ``sqrt(2 / (1 + a**2)) / sqrt(fan_in)``, where ``fan_in`` is
    the product of every dimension after the first. It is then divided by
    0.87962566103423978, which per the Keras reference linked above
    compensates for the reduced standard deviation of a normal truncated at
    two standard deviations.

    Args:
        *shape: Weight shape, given either as separate ints or as a single
            tuple/list.
        a: Slope term used in the gain ``sqrt(2 / (1 + a**2))``.
        **kwargs: Forwarded to ``rand_truncn``. Note that overriding
            ``truncstds`` here does not adjust the correction constant.

    Returns:
        A tensor of shape ``shape``.
    """
    dims = argfix(*shape)
    gain = math.sqrt(2.0 / (1 + a ** 2))
    fan_in = prod(dims[1:])
    std = gain / math.sqrt(fan_in) / 0.87962566103423978
    # Forward the normalized dims so a tuple-style shape is sampled correctly.
    return std * rand_truncn(*dims, **kwargs)


class Conv2dHeNormal(nn.Conv2d):
    """``nn.Conv2d`` variant whose weight is produced by ``he_normal``."""

    def initialize_weight(self, out_channels, in_channels, groups):
        """Return a He-normal weight of shape (out, in // groups, *kernel_size).

        NOTE(review): presumably invoked by the base class constructor as a
        weight-initialization hook -- verify against the nn.Conv2d in use.
        """
        return he_normal(out_channels, in_channels // groups, *self.kernel_size, a=0.0)


class Linear(nn.Linear):
    """``nn.Linear`` with normal(0, 0.01) weights and a zero bias."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        # Replace the base-class initialization with the values used here.
        self.weight = Tensor.normal((out_features, in_features), mean=0.0, std=0.01)
        if bias:
            self.bias = Tensor.zeros(out_features)