mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Refactor nn.optim (#1091)
* Refactor: nn.optim.py * Refactor: nn.optim.py; Fix all tests * Refactor: Replace all optim.get_parameters() * Refactor: Revert list comp. * Refactor: Replace optim.get_state_dict * Refactor: Change quickstart.md
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.helpers import getenv
|
||||
from extra.training import train, evaluate
|
||||
@@ -37,7 +38,7 @@ if __name__ == "__main__":
|
||||
lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32),
|
||||
])
|
||||
for _ in range(5):
|
||||
optimizer = optim.SGD(optim.get_parameters(model), lr=lr, momentum=0.9)
|
||||
optimizer = optim.SGD(get_parameters(model), lr=lr, momentum=0.9)
|
||||
train(model, X_train, Y_train, optimizer, 100, BS=32, transform=transform)
|
||||
evaluate(model, X_test, Y_test, num_classes=classes, transform=transform)
|
||||
lr /= 1.2
|
||||
|
||||
Reference in New Issue
Block a user