unit test for batchnorm2d

This commit is contained in:
George Hotz
2020-10-30 08:19:58 -07:00
parent 843b1cb7d6
commit c14473f87d
2 changed files with 43 additions and 0 deletions

View File

@@ -125,6 +125,9 @@ if __name__ == "__main__":
print(img.shape, img.dtype)
# run the net
import time
st = time.time()
out = model.forward(Tensor(img))
print("did inference in %.2f s" % (time.time()-st))
print(np.argmax(out.data), np.max(out.data))

40
test/test_nn.py Normal file
View File

@@ -0,0 +1,40 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.nn import *
import torch
class TestNN(unittest.TestCase):
def test_batchnorm2d(self):
sz = 4
# create in tinygrad
bn = BatchNorm2D(sz, eps=1e-5)
bn.weight = Tensor.randn(sz)
bn.bias = Tensor.randn(sz)
bn.running_mean = Tensor.randn(sz)
bn.running_var = Tensor.randn(sz)
bn.running_var.data[bn.running_var.data < 0] = 0
# create in torch
tbn = torch.nn.BatchNorm2d(sz).eval()
tbn.weight[:] = torch.tensor(bn.weight.data)
tbn.bias[:] = torch.tensor(bn.bias.data)
tbn.running_mean[:] = torch.tensor(bn.running_mean.data)
tbn.running_var[:] = torch.tensor(bn.running_var.data)
# trial
inn = Tensor.randn(2, sz, 3, 3)
# in tinygrad
outt = bn(inn)
# in torch
toutt = tbn(torch.tensor(inn.data))
# close
np.testing.assert_allclose(outt.data, toutt.detach().numpy(), rtol=1e-5)
if __name__ == '__main__':
unittest.main()