mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-11 07:05:04 -05:00
BatchNorm2D -> BatchNorm2d (#558)
* BatchNorm2D -> BatchNorm2d * Fix typo
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d, BatchNorm2D, optim
|
||||
from tinygrad.nn import Conv2d, BatchNorm2d, optim
|
||||
from extra.utils import get_parameters # TODO: move to optim
|
||||
import unittest
|
||||
|
||||
@@ -38,9 +38,9 @@ class TestBatchnorm(unittest.TestCase):
|
||||
class LilModel:
|
||||
def __init__(self):
|
||||
self.c = Conv2d(12, 24, 3, padding=1, bias=False)
|
||||
self.bn = BatchNorm2D(24, track_running_stats=False)
|
||||
self.bn = BatchNorm2d(24, track_running_stats=False)
|
||||
self.c2 = Conv2d(24, 32, 3, padding=1, bias=False)
|
||||
self.bn2 = BatchNorm2D(32, track_running_stats=False)
|
||||
self.bn2 = BatchNorm2d(32, track_running_stats=False)
|
||||
def forward(self, x):
|
||||
x = self.bn(self.c(x)).relu()
|
||||
return self.bn2(self.c2(x)).relu()
|
||||
@@ -51,7 +51,7 @@ class TestBatchnorm(unittest.TestCase):
|
||||
class LilModel:
|
||||
def __init__(self):
|
||||
self.c = Conv2d(12, 32, 3, padding=1, bias=False)
|
||||
self.bn = BatchNorm2D(32, track_running_stats=False)
|
||||
self.bn = BatchNorm2d(32, track_running_stats=False)
|
||||
def forward(self, x):
|
||||
return self.bn(self.c(x)).relu()
|
||||
lm = LilModel()
|
||||
@@ -59,4 +59,4 @@ class TestBatchnorm(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user