mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
add nn layer tests
@@ -3,6 +3,7 @@ import unittest
 import numpy as np
 import torch
 from tinygrad.tensor import Tensor
+import tinygrad.nn as nn
 
 # https://gist.github.com/devries/11405101
 def ksprob(a):
@@ -80,5 +81,20 @@ class TestRandomness(unittest.TestCase):
     for shape in [(128, 64, 3, 3), (20, 24)]:
       self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape))
 
+  def test_conv2d_init(self):
+    params = (128, 256, (3,3))
+    assert equal_distribution(lambda *_: nn.Conv2d(*params).weight, lambda _: torch.nn.Conv2d(*params).weight.detach())
+    assert equal_distribution(lambda *_: nn.Conv2d(*params).bias, lambda _: torch.nn.Conv2d(*params).bias.detach())
+
+  def test_linear_init(self):
+    params = (64, 64)
+    assert equal_distribution(lambda *_: nn.Linear(*params).weight, lambda _: torch.nn.Linear(*params).weight.detach())
+    assert equal_distribution(lambda *_: nn.Linear(*params).bias, lambda _: torch.nn.Linear(*params).bias.detach())
+
+  def test_bn_init(self):
+    params = (64,)
+    assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).weight, lambda _: torch.nn.BatchNorm2d(*params).weight.detach())
+    assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).bias, lambda _: torch.nn.BatchNorm2d(*params).bias.detach())
+
 if __name__ == "__main__":
   unittest.main()
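The new tests all go through an equal_distribution helper defined earlier in the same file, next to ksprob (see the gist link in the first hunk); its body falls outside this diff. As a rough sketch of what such a helper does, assuming a two-sample Kolmogorov-Smirnov test from scipy in place of the file's own ksprob-based statistic, and with the shape and alpha defaults chosen for illustration:

    # Hypothetical sketch, not the file's actual helper: the real one
    # computes the KS probability by hand using ksprob above.
    from scipy.stats import ks_2samp

    def equal_distribution(tiny_func, torch_func, shape=(20, 23), alpha=0.05):
      # Draw one sample from each source and flatten to 1-D numpy arrays.
      x = tiny_func(*shape).numpy().flatten()   # tinygrad Tensor -> numpy
      y = torch_func(shape).numpy().flatten()   # detached torch tensor -> numpy
      # Two-sample KS test: a p-value at or above alpha means the two
      # samples are statistically indistinguishable at that level.
      return ks_2samp(x, y).pvalue >= alpha

Note how the added tests ignore the shape arguments: each lambda constructs a fresh layer per call, so what is compared is the distribution of that layer's default weight and bias initialization against torch's.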