mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
BatchNorm2D -> BatchNorm2d (#558)
* BatchNorm2D -> BatchNorm2d * Fix typo
This commit is contained in:
@@ -46,7 +46,7 @@ class TestOpt(unittest.TestCase):
|
||||
# TODO: with Tensor.training
|
||||
Tensor.training = True
|
||||
img = Tensor.ones(1,32,4,4)
|
||||
bn = nn.BatchNorm2D(32, track_running_stats=False)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
with CLCache():
|
||||
img_bn = bn(img).realize()
|
||||
print(img_bn)
|
||||
@@ -73,7 +73,7 @@ class TestOpt(unittest.TestCase):
|
||||
Tensor.training = True
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
c1 = nn.Conv2d(3,32,3)
|
||||
bn = nn.BatchNorm2D(32, track_running_stats=False)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
opt = optim.SGD(optim.get_parameters([c1, bn]))
|
||||
with CLCache():
|
||||
img_bn = bn(c1(img)).elu().sum()
|
||||
@@ -86,7 +86,7 @@ class TestOpt(unittest.TestCase):
|
||||
def test_fold_conv_batchnorm_notrain(self):
|
||||
img = Tensor.ones(1,3,8,8)
|
||||
c1 = nn.Conv2d(3,32,3)
|
||||
bn = nn.BatchNorm2D(32, track_running_stats=False)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
# precache the bn
|
||||
img_conv = bn(c1(img)).relu().realize()
|
||||
with CLCache():
|
||||
@@ -97,7 +97,7 @@ class TestOpt(unittest.TestCase):
|
||||
Tensor.training = True
|
||||
img = Tensor.ones(1,3,8,8)
|
||||
c1 = nn.Conv2d(3,32,3)
|
||||
bn = nn.BatchNorm2D(32, track_running_stats=False)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
with CLCache():
|
||||
img_conv = bn(c1(img)).relu().realize()
|
||||
print(img_conv)
|
||||
@@ -132,4 +132,4 @@ class TestOpt(unittest.TestCase):
|
||||
assert len(GlobalCounters.cache) == 2, "optimizer didn't fold conv/relu"
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user