From c5f726ec2ef61b4ccd77ed627307176045fe4ba0 Mon Sep 17 00:00:00 2001 From: f0ti Date: Fri, 23 Oct 2020 11:53:01 +0200 Subject: [PATCH] all three --- test/test_mnist.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/test/test_mnist.py b/test/test_mnist.py index 1e60a4c058..1a01132745 100644 --- a/test/test_mnist.py +++ b/test/test_mnist.py @@ -77,23 +77,22 @@ class TestMNIST(unittest.TestCase): assert accuracy > 0.95 # models - if os.getenv("CONV") == "1": - model = TinyConvNet() - optim = tinygrad_optim.Adam([model.c1, model.l1, model.l2], lr=0.001) - steps = 400 - train(model, optim, steps) - evaluate(model) - else: - model = TinyBobNet() - optim = tinygrad_optim.SGD([model.l1, model.l2], lr=0.001) - steps = 1000 - train(model, optim, steps) - evaluate(model) - - # RMSprop - optim = tinygrad_optim.RMSprop([model.l1, model.l2], lr=0.001) - train(model, optim, steps) - evaluate(model) + model = TinyConvNet() + optim = tinygrad_optim.Adam([model.c1, model.l1, model.l2], lr=0.001) + steps = 400 + train(model, optim, steps) + evaluate(model) + + model = TinyBobNet() + optim = tinygrad_optim.SGD([model.l1, model.l2], lr=0.001) + steps = 1000 + train(model, optim, steps) + evaluate(model) + + # RMSprop + optim = tinygrad_optim.RMSprop([model.l1, model.l2], lr=0.001) + train(model, optim, steps) + evaluate(model) if __name__ == '__main__': unittest.main()