Refactor nn.optim (#1091)

* Refactor: nn.optim.py

* Refactor: nn.optim.py; Fix all tests

* Refactor: Replace all optim.get_parameters()

* Refactor: Revert list comp.

* Refactor: Replace optim.get_state_dict

* Refactor: Change quickstart.md
This commit is contained in:
Reza Rezvan
2023-07-03 00:07:30 +02:00
committed by GitHub
parent 10f1aeb144
commit 8ae9a054ae
15 changed files with 44 additions and 36 deletions

View File

@@ -2,6 +2,7 @@
import numpy as np
from PIL import Image
from tinygrad.state import get_parameters
from tinygrad.nn import optim
from tinygrad.helpers import getenv
from extra.training import train, evaluate
@@ -37,7 +38,7 @@ if __name__ == "__main__":
lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32),
])
for _ in range(5):
optimizer = optim.SGD(optim.get_parameters(model), lr=lr, momentum=0.9)
optimizer = optim.SGD(get_parameters(model), lr=lr, momentum=0.9)
train(model, X_train, Y_train, optimizer, 100, BS=32, transform=transform)
evaluate(model, X_test, Y_test, num_classes=classes, transform=transform)
lr /= 1.2