From c14473f87da8f28683fb6173111d81a1e79a24c8 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Fri, 30 Oct 2020 08:19:58 -0700 Subject: [PATCH] unit test for batchnorm2d --- examples/efficientnet.py | 3 +++ test/test_nn.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 test/test_nn.py diff --git a/examples/efficientnet.py b/examples/efficientnet.py index cd347b56e6..44f9c9962a 100644 --- a/examples/efficientnet.py +++ b/examples/efficientnet.py @@ -125,6 +125,9 @@ if __name__ == "__main__": print(img.shape, img.dtype) # run the net + import time + st = time.time() out = model.forward(Tensor(img)) + print("did inference in %.2f s" % (time.time()-st)) print(np.argmax(out.data), np.max(out.data)) diff --git a/test/test_nn.py b/test/test_nn.py new file mode 100644 index 0000000000..6ebb9d4afb --- /dev/null +++ b/test/test_nn.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +import unittest +import numpy as np +from tinygrad.nn import * +import torch + +class TestNN(unittest.TestCase): + def test_batchnorm2d(self): + sz = 4 + + # create in tinygrad + bn = BatchNorm2D(sz, eps=1e-5) + bn.weight = Tensor.randn(sz) + bn.bias = Tensor.randn(sz) + bn.running_mean = Tensor.randn(sz) + bn.running_var = Tensor.randn(sz) + bn.running_var.data[bn.running_var.data < 0] = 0 + + # create in torch + tbn = torch.nn.BatchNorm2d(sz).eval() + tbn.weight[:] = torch.tensor(bn.weight.data) + tbn.bias[:] = torch.tensor(bn.bias.data) + tbn.running_mean[:] = torch.tensor(bn.running_mean.data) + tbn.running_var[:] = torch.tensor(bn.running_var.data) + + # trial + inn = Tensor.randn(2, sz, 3, 3) + + # in tinygrad + outt = bn(inn) + + # in torch + toutt = tbn(torch.tensor(inn.data)) + + # close + np.testing.assert_allclose(outt.data, toutt.detach().numpy(), rtol=1e-5) + + +if __name__ == '__main__': + unittest.main()