all three

This commit is contained in:
f0ti
2020-10-23 11:53:01 +02:00
parent 6a38ccb6b0
commit c5f726ec2e

View File

@@ -77,23 +77,22 @@ class TestMNIST(unittest.TestCase):
assert accuracy > 0.95
# models
if os.getenv("CONV") == "1":
model = TinyConvNet()
optim = tinygrad_optim.Adam([model.c1, model.l1, model.l2], lr=0.001)
steps = 400
train(model, optim, steps)
evaluate(model)
else:
model = TinyBobNet()
optim = tinygrad_optim.SGD([model.l1, model.l2], lr=0.001)
steps = 1000
train(model, optim, steps)
evaluate(model)
# RMSprop
optim = tinygrad_optim.RMSprop([model.l1, model.l2], lr=0.001)
train(model, optim, steps)
evaluate(model)
model = TinyConvNet()
optim = tinygrad_optim.Adam([model.c1, model.l1, model.l2], lr=0.001)
steps = 400
train(model, optim, steps)
evaluate(model)
model = TinyBobNet()
optim = tinygrad_optim.SGD([model.l1, model.l2], lr=0.001)
steps = 1000
train(model, optim, steps)
evaluate(model)
# RMSprop
optim = tinygrad_optim.RMSprop([model.l1, model.l2], lr=0.001)
train(model, optim, steps)
evaluate(model)
if __name__ == '__main__':
unittest.main()