diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index 975dbed4c3..6f62a4d1aa 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -1,5 +1,5 @@ import math -from typing import Optional, Union, Tuple, cast +from typing import Optional, Union, Tuple from tinygrad.tensor import Tensor from tinygrad.helpers import prod from tinygrad.nn import optim, state, datasets # noqa: F401 @@ -98,16 +98,13 @@ class Conv2d: def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups - self.weight = self.initialize_weight(out_channels, in_channels, groups) - bound = 1 / math.sqrt(cast(int, prod(self.weight.shape[1:]))) # weight shape is always ints but mypy cannot tell - self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None + scale = 1 / math.sqrt(in_channels * prod(self.kernel_size)) + self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale) + self.bias = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None def __call__(self, x:Tensor): return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups) - def initialize_weight(self, out_channels, in_channels, groups): - return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5)) - def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True): """ Applies a 1D transposed convolution operator over an input signal composed of several input planes. @@ -144,15 +141,14 @@ class ConvTranspose2d(Conv2d): """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) + scale = 1 / math.sqrt(in_channels * prod(self.kernel_size)) + self.weight = Tensor.uniform(in_channels, out_channels//groups, *self.kernel_size, low=-scale, high=scale) self.output_padding = output_padding def __call__(self, x:Tensor): return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups) - def initialize_weight(self, out_channels, in_channels, groups): - return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5)) - class Linear: """ Applies a linear transformation to the incoming data. @@ -170,9 +166,8 @@ class Linear: ``` """ def __init__(self, in_features, out_features, bias=True): - # TODO: is this init good? torch inits to uniform(-1/sqrt(in_features), 1/sqrt(in_features)) - self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5)) bound = 1 / math.sqrt(in_features) + self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound) self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None def __call__(self, x:Tensor):