diff --git a/test/test_randomness.py b/test/test_randomness.py index 203f15577a..f1e6a008c4 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -3,6 +3,7 @@ import unittest import numpy as np import torch from tinygrad.tensor import Tensor +import tinygrad.nn as nn # https://gist.github.com/devries/11405101 def ksprob(a): @@ -80,5 +81,20 @@ class TestRandomness(unittest.TestCase): for shape in [(128, 64, 3, 3), (20, 24)]: self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape)) + def test_conv2d_init(self): + params = (128, 256, (3,3)) + assert equal_distribution(lambda *_: nn.Conv2d(*params).weight, lambda _: torch.nn.Conv2d(*params).weight.detach()) + assert equal_distribution(lambda *_: nn.Conv2d(*params).bias, lambda _: torch.nn.Conv2d(*params).bias.detach()) + + def test_linear_init(self): + params = (64, 64) + assert equal_distribution(lambda *_: nn.Linear(*params).weight, lambda _: torch.nn.Linear(*params).weight.detach()) + assert equal_distribution(lambda *_: nn.Linear(*params).bias, lambda _: torch.nn.Linear(*params).bias.detach()) + + def test_bn_init(self): + params = (64,) + assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).weight, lambda _: torch.nn.BatchNorm2d(*params).weight.detach()) + assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).bias, lambda _: torch.nn.BatchNorm2d(*params).bias.detach()) + if __name__ == "__main__": unittest.main()