mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
all three
This commit is contained in:
@@ -77,23 +77,22 @@ class TestMNIST(unittest.TestCase):
|
||||
assert accuracy > 0.95
|
||||
|
||||
# models
|
||||
if os.getenv("CONV") == "1":
|
||||
model = TinyConvNet()
|
||||
optim = tinygrad_optim.Adam([model.c1, model.l1, model.l2], lr=0.001)
|
||||
steps = 400
|
||||
train(model, optim, steps)
|
||||
evaluate(model)
|
||||
else:
|
||||
model = TinyBobNet()
|
||||
optim = tinygrad_optim.SGD([model.l1, model.l2], lr=0.001)
|
||||
steps = 1000
|
||||
train(model, optim, steps)
|
||||
evaluate(model)
|
||||
|
||||
# RMSprop
|
||||
optim = tinygrad_optim.RMSprop([model.l1, model.l2], lr=0.001)
|
||||
train(model, optim, steps)
|
||||
evaluate(model)
|
||||
model = TinyConvNet()
|
||||
optim = tinygrad_optim.Adam([model.c1, model.l1, model.l2], lr=0.001)
|
||||
steps = 400
|
||||
train(model, optim, steps)
|
||||
evaluate(model)
|
||||
|
||||
model = TinyBobNet()
|
||||
optim = tinygrad_optim.SGD([model.l1, model.l2], lr=0.001)
|
||||
steps = 1000
|
||||
train(model, optim, steps)
|
||||
evaluate(model)
|
||||
|
||||
# RMSprop
|
||||
optim = tinygrad_optim.RMSprop([model.l1, model.l2], lr=0.001)
|
||||
train(model, optim, steps)
|
||||
evaluate(model)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user