mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Refactor nn.optim (#1091)
* Refactor: nn.optim.py * Refactor: nn.optim.py; Fix all tests * Refactor: Replace all optim.get_parameters() * Refactor: Revert list comp. * Refactor: Replace optim.get_state_dict * Refactor: Change quickstart.md
This commit is contained in:
@@ -3,6 +3,7 @@ import gc
|
||||
import time
|
||||
from tqdm import trange
|
||||
from models.efficientnet import EfficientNet
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import GlobalCounters
|
||||
@@ -22,7 +23,7 @@ CLCACHE = getenv("CLCACHE", 0)
|
||||
if __name__ == "__main__":
|
||||
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
|
||||
model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False)
|
||||
parameters = optim.get_parameters(model)
|
||||
parameters = get_parameters(model)
|
||||
for p in parameters: p.realize()
|
||||
if ADAM: optimizer = optim.Adam(parameters, lr=0.001)
|
||||
else: optimizer = optim.SGD(parameters, lr=0.001)
|
||||
|
||||
Reference in New Issue
Block a user