diff --git a/datasets/__init__.py b/datasets/__init__.py
index 4fd37a21fa..77fc0325e2 100644
--- a/datasets/__init__.py
+++ b/datasets/__init__.py
@@ -19,7 +19,6 @@ def fetch_cifar(train=True):
   cifar10_std = np.array([0.24703225141799082, 0.24348516474564, 0.26158783926049628], dtype=np.float32).reshape(1,3,1,1)
   tt = tarfile.open(fileobj=io.BytesIO(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')), mode='r:gz')
   if train:
-    # TODO: data_batch 2-5
     db = [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)]
   else:
     db = [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")]
diff --git a/test/models/test_mnist.py b/test/models/test_mnist.py
index 54ae2cf356..2bc687a92e 100644
--- a/test/models/test_mnist.py
+++ b/test/models/test_mnist.py
@@ -94,13 +94,12 @@ class TestMNIST(unittest.TestCase):
     train(model, X_train, Y_train, optimizer, steps=100)
     assert evaluate(model, X_test, Y_test) > 0.94   # torch gets 0.9415 sometimes

-  @unittest.skip("slow and training batchnorm is broken")
   def test_conv_with_bn(self):
     np.random.seed(1337)
     model = TinyConvNet(has_batchnorm=True)
-    optimizer = optim.Adam(model.parameters(), lr=0.001)
-    train(model, X_train, Y_train, optimizer, steps=100)
-    assert evaluate(model, X_test, Y_test) > 0.7    # TODO: batchnorm doesn't work!!!
+    optimizer = optim.AdamW(model.parameters(), lr=0.003)
+    train(model, X_train, Y_train, optimizer, steps=200)
+    assert evaluate(model, X_test, Y_test) > 0.94

   def test_sgd(self):
     np.random.seed(1337)