Import get_parameters from tinygrad.nn (#559)

* get_parameter is in optim

* Update all imports for get_parameters

* Clean up

* use optim.get_paramters
This commit is contained in:
Jacky Lee
2023-02-17 15:22:26 -08:00
committed by GitHub
parent fae7654924
commit 9fd41632c6
13 changed files with 32 additions and 46 deletions

View File

@@ -1,12 +1,11 @@
import unittest
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d, optim
from extra.utils import get_parameters # TODO: move to optim
import unittest
def model_step(lm):
Tensor.training = True
x = Tensor.ones(8,12,128,256, requires_grad=False)
optimizer = optim.SGD(get_parameters(lm), lr=0.001)
optimizer = optim.SGD(optim.get_parameters(lm), lr=0.001)
loss = lm.forward(x).sum()
optimizer.zero_grad()
loss.backward()