mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
unit test for batchnorm2d
This commit is contained in:
@@ -125,6 +125,9 @@ if __name__ == "__main__":
|
||||
print(img.shape, img.dtype)
|
||||
|
||||
# run the net
|
||||
import time
|
||||
st = time.time()
|
||||
out = model.forward(Tensor(img))
|
||||
print("did inference in %.2f s" % (time.time()-st))
|
||||
print(np.argmax(out.data), np.max(out.data))
|
||||
|
||||
|
||||
40
test/test_nn.py
Normal file
40
test/test_nn.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.nn import *
|
||||
import torch
|
||||
|
||||
class TestNN(unittest.TestCase):
|
||||
def test_batchnorm2d(self):
|
||||
sz = 4
|
||||
|
||||
# create in tinygrad
|
||||
bn = BatchNorm2D(sz, eps=1e-5)
|
||||
bn.weight = Tensor.randn(sz)
|
||||
bn.bias = Tensor.randn(sz)
|
||||
bn.running_mean = Tensor.randn(sz)
|
||||
bn.running_var = Tensor.randn(sz)
|
||||
bn.running_var.data[bn.running_var.data < 0] = 0
|
||||
|
||||
# create in torch
|
||||
tbn = torch.nn.BatchNorm2d(sz).eval()
|
||||
tbn.weight[:] = torch.tensor(bn.weight.data)
|
||||
tbn.bias[:] = torch.tensor(bn.bias.data)
|
||||
tbn.running_mean[:] = torch.tensor(bn.running_mean.data)
|
||||
tbn.running_var[:] = torch.tensor(bn.running_var.data)
|
||||
|
||||
# trial
|
||||
inn = Tensor.randn(2, sz, 3, 3)
|
||||
|
||||
# in tinygrad
|
||||
outt = bn(inn)
|
||||
|
||||
# in torch
|
||||
toutt = tbn(torch.tensor(inn.data))
|
||||
|
||||
# close
|
||||
np.testing.assert_allclose(outt.data, toutt.detach().numpy(), rtol=1e-5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user