mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Added kaiming_uniform initialization for Conv2d and Linear layers (#756)
* added kaiming_uniform init for conv2d and linear layers * fix: set getattr * up * fix: set getattr * fix comments * better does not mean it is good * more nonlinearities * added test checks the distribution of default relu option * prettier * fix kernel size * edit distribution of returned tensor * complete tests and fix fan_mode * added higher dim test * prettier test * fix silly blank * just leaky_relu mode * default fan in and leaky relu * update params * fix test * shorter * generalize Tensor.uniform and adjust kaiming init - added low and high parameters to Tensor.uniform function, so it can have a specific range (default is 0 to 1) - adjusted return line of kaiming_uniform * range from -1 to 1 * delete comment * adjusted test_uniform * fixed * delete comment
This commit is contained in:
@@ -63,7 +63,7 @@ class TestRandomness(unittest.TestCase):
|
||||
|
||||
def test_uniform(self):
|
||||
self.assertFalse(normal_test(Tensor.uniform))
|
||||
self.assertTrue(equal_distribution(Tensor.uniform, lambda x: torch.nn.init.uniform_(torch.empty(x), a=-1, b=1), lambda x: np.random.rand(*x) * 2 - 1))
|
||||
self.assertTrue(equal_distribution(Tensor.uniform, lambda x: torch.nn.init.uniform_(torch.empty(x), a=-1, b=1), lambda x: np.random.uniform(low=-1, high=1, size=x)))
|
||||
|
||||
def test_scaled_uniform(self):
|
||||
self.assertFalse(normal_test(Tensor.scaled_uniform))
|
||||
@@ -73,5 +73,13 @@ class TestRandomness(unittest.TestCase):
|
||||
self.assertFalse(normal_test(Tensor.glorot_uniform))
|
||||
self.assertTrue(equal_distribution(Tensor.glorot_uniform, lambda x: torch.nn.init.xavier_uniform_(torch.empty(x)), lambda x: (np.random.rand(*x) * 2 - 1) * math.sqrt(6 / (x[0] + math.prod(x[1:])))))
|
||||
|
||||
def test_kaiming_uniform(self, shape=(20, 23), a=0.01):
|
||||
Tensor.manual_seed(1337)
|
||||
torch.manual_seed(1337)
|
||||
np.random.seed(1337)
|
||||
|
||||
bound = (math.sqrt(3.0) * (math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(shape[1] * np.prod(shape[2:]))))
|
||||
self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), lambda x: np.random.uniform(low=-bound, high=bound, size=shape)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -34,12 +34,11 @@ class BatchNorm2d:
|
||||
|
||||
return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
|
||||
|
||||
# TODO: is this good weight init?
|
||||
class Conv2d:
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, initialization: str='kaiming_uniform'):
|
||||
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
|
||||
self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
|
||||
self.weight = Tensor.glorot_uniform(out_channels, in_channels//groups, *self.kernel_size)
|
||||
self.weight = getattr(Tensor, initialization)(out_channels, in_channels//groups, *self.kernel_size)
|
||||
self.bias = Tensor.zeros(out_channels) if bias else None
|
||||
|
||||
def __call__(self, x):
|
||||
@@ -56,8 +55,8 @@ class ConvTranspose2d:
|
||||
return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
|
||||
|
||||
class Linear:
|
||||
def __init__(self, in_features, out_features, bias=True):
|
||||
self.weight = Tensor.glorot_uniform(out_features, in_features)
|
||||
def __init__(self, in_features, out_features, bias=True, initialization: str='kaiming_uniform'):
|
||||
self.weight = getattr(Tensor, initialization)(out_features, in_features)
|
||||
self.bias = Tensor.zeros(out_features) if bias else None
|
||||
|
||||
def __call__(self, x):
|
||||
|
||||
@@ -164,7 +164,7 @@ class Tensor:
|
||||
# ***** rng hlops *****
|
||||
|
||||
@staticmethod
|
||||
def uniform(*shape, **kwargs) -> Tensor: return Tensor.rand(*shape, **kwargs) * 2 - 1
|
||||
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
|
||||
|
||||
@staticmethod
|
||||
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(prod(shape)**-0.5)
|
||||
@@ -173,8 +173,13 @@ class Tensor:
|
||||
@staticmethod
|
||||
def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)
|
||||
|
||||
# ***** toposort and backward pass *****
|
||||
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
|
||||
@staticmethod
|
||||
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
|
||||
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(shape[1] * prod(shape[2:]))
|
||||
return Tensor.uniform(*shape, low=-bound, high=bound)
|
||||
|
||||
# ***** toposort and backward pass *****
|
||||
def deepwalk(self):
|
||||
def _deepwalk(node, visited, nodes):
|
||||
visited.add(node)
|
||||
@@ -542,4 +547,4 @@ for device in Device._buffers:
|
||||
from tinygrad.nn.image import image_conv2d, image_dot
|
||||
if IMAGE:
|
||||
setattr(Tensor, "conv2d", image_conv2d)
|
||||
setattr(Tensor, "dot", image_dot)
|
||||
setattr(Tensor, "dot", image_dot)
|
||||
Reference in New Issue
Block a user