Added kaiming_uniform initialization for Conv2d and Linear layers (#756)

* added kaiming_uniform init for conv2d and linear layers

* fix: set getattr

* up

* fix: set getattr

* fix comments

* better does not mean it is good

* more nonlinearities

* added test

checks the distribution of default relu option

* prettier

* fix kernel size

* edit distribution of returned tensor

* complete tests and fix fan_mode

* added higher dim test

* prettier test

* fix silly blank

* just leaky_relu mode

* default fan in and leaky relu

* update params

* fix test

* shorter

* generalize Tensor.uniform and adjust kaiming init

- added low and high parameters to the Tensor.uniform function, so it can sample from a specific range (the default was 0 to 1 at this point; a later change below makes it -1 to 1)
- adjusted return line of kaiming_uniform

* range from -1 to 1

* delete comment

* adjusted test_uniform

* fixed

* delete comment
This commit is contained in:
Rabia Eda Yılmaz
2023-05-30 01:09:55 +03:00
committed by GitHub
parent 174c65b7d9
commit 3075988468
3 changed files with 21 additions and 9 deletions

View File

@@ -63,7 +63,7 @@ class TestRandomness(unittest.TestCase):
def test_uniform(self):
self.assertFalse(normal_test(Tensor.uniform))
self.assertTrue(equal_distribution(Tensor.uniform, lambda x: torch.nn.init.uniform_(torch.empty(x), a=-1, b=1), lambda x: np.random.rand(*x) * 2 - 1))
self.assertTrue(equal_distribution(Tensor.uniform, lambda x: torch.nn.init.uniform_(torch.empty(x), a=-1, b=1), lambda x: np.random.uniform(low=-1, high=1, size=x)))
def test_scaled_uniform(self):
self.assertFalse(normal_test(Tensor.scaled_uniform))
@@ -73,5 +73,13 @@ class TestRandomness(unittest.TestCase):
self.assertFalse(normal_test(Tensor.glorot_uniform))
self.assertTrue(equal_distribution(Tensor.glorot_uniform, lambda x: torch.nn.init.xavier_uniform_(torch.empty(x)), lambda x: (np.random.rand(*x) * 2 - 1) * math.sqrt(6 / (x[0] + math.prod(x[1:])))))
def test_kaiming_uniform(self, shape=(20, 23), a=0.01):
  """Compare Tensor.kaiming_uniform with torch's kaiming_uniform_ and a numpy uniform reference.

  shape: shape used to derive the expected bound (needs at least two dims, shape[1] is indexed).
  a: leaky-relu negative slope; passed to all three samplers so they use the same gain.
  """
  from functools import partial
  # seed tinygrad, torch and numpy before sampling
  Tensor.manual_seed(1337)
  torch.manual_seed(1337)
  np.random.seed(1337)
  fan_in = shape[1] * math.prod(shape[2:])
  # NOTE(review): bound is derived from `shape`; assumes equal_distribution samples with the same shape -- confirm
  bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(fan_in)
  self.assertTrue(equal_distribution(
    partial(Tensor.kaiming_uniform, a=a),
    lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x), a=a),
    # use the shape handed to the lambda, like the other distribution tests do
    lambda x: np.random.uniform(low=-bound, high=bound, size=x)))
if __name__ == "__main__":
unittest.main()

View File

@@ -34,12 +34,11 @@ class BatchNorm2d:
return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
# TODO: is this good weight init?
class Conv2d:
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, initialization: str='kaiming_uniform'):
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
self.weight = Tensor.glorot_uniform(out_channels, in_channels//groups, *self.kernel_size)
self.weight = getattr(Tensor, initialization)(out_channels, in_channels//groups, *self.kernel_size)
self.bias = Tensor.zeros(out_channels) if bias else None
def __call__(self, x):
@@ -56,8 +55,8 @@ class ConvTranspose2d:
return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
class Linear:
def __init__(self, in_features, out_features, bias=True):
self.weight = Tensor.glorot_uniform(out_features, in_features)
def __init__(self, in_features, out_features, bias=True, initialization: str='kaiming_uniform'):
self.weight = getattr(Tensor, initialization)(out_features, in_features)
self.bias = Tensor.zeros(out_features) if bias else None
def __call__(self, x):

View File

@@ -164,7 +164,7 @@ class Tensor:
# ***** rng hlops *****
@staticmethod
def uniform(*shape, **kwargs) -> Tensor: return Tensor.rand(*shape, **kwargs) * 2 - 1
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(prod(shape)**-0.5)
@@ -173,8 +173,13 @@ class Tensor:
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)
# ***** toposort and backward pass *****
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
  """Return a tensor of `shape` sampled uniformly from [-bound, bound] (Kaiming uniform init).

  bound = sqrt(3) * gain / sqrt(fan_in), where gain = sqrt(2 / (1 + a**2)) and
  fan_in = shape[1] * prod(shape[2:]).

  shape: tensor dimensions; at least two are required because shape[1] is indexed.
  a: negative slope used in the gain term.
  **kwargs: forwarded unchanged to Tensor.uniform.
  """
  fan_in = shape[1] * prod(shape[2:])
  gain = math.sqrt(2.0 / (1 + a ** 2))
  bound = math.sqrt(3.0) * gain / math.sqrt(fan_in)
  # forward **kwargs like the other initializers do; previously they were accepted but silently dropped
  return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
# ***** toposort and backward pass *****
def deepwalk(self):
def _deepwalk(node, visited, nodes):
visited.add(node)
@@ -542,4 +547,4 @@ for device in Device._buffers:
from tinygrad.nn.image import image_conv2d, image_dot
if IMAGE:
setattr(Tensor, "conv2d", image_conv2d)
setattr(Tensor, "dot", image_dot)
setattr(Tensor, "dot", image_dot)