From 4d28d5568396c1fbae6f6ee20998396bc54b30cd Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 1 Jun 2023 21:34:24 -0700 Subject: [PATCH] add nn layer tests --- test/test_randomness.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/test_randomness.py b/test/test_randomness.py index 203f15577a..f1e6a008c4 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -3,6 +3,7 @@ import unittest import numpy as np import torch from tinygrad.tensor import Tensor +import tinygrad.nn as nn # https://gist.github.com/devries/11405101 def ksprob(a): @@ -80,5 +81,20 @@ class TestRandomness(unittest.TestCase): for shape in [(128, 64, 3, 3), (20, 24)]: self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape)) + def test_conv2d_init(self): + params = (128, 256, (3,3)) + assert equal_distribution(lambda *_: nn.Conv2d(*params).weight, lambda _: torch.nn.Conv2d(*params).weight.detach()) + assert equal_distribution(lambda *_: nn.Conv2d(*params).bias, lambda _: torch.nn.Conv2d(*params).bias.detach()) + + def test_linear_init(self): + params = (64, 64) + assert equal_distribution(lambda *_: nn.Linear(*params).weight, lambda _: torch.nn.Linear(*params).weight.detach()) + assert equal_distribution(lambda *_: nn.Linear(*params).bias, lambda _: torch.nn.Linear(*params).bias.detach()) + + def test_bn_init(self): + params = (64,) + assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).weight, lambda _: torch.nn.BatchNorm2d(*params).weight.detach()) + assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).bias, lambda _: torch.nn.BatchNorm2d(*params).bias.detach()) + if __name__ == "__main__": unittest.main()