mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-11 07:05:04 -05:00
Import get_parameters from tinygrad.nn (#559)
* get_parameter is in optim * Update all imports for get_parameters * Clean up * use optim.get_paramters
This commit is contained in:
@@ -1,12 +1,11 @@
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d, BatchNorm2d, optim
|
||||
from extra.utils import get_parameters # TODO: move to optim
|
||||
import unittest
|
||||
|
||||
def model_step(lm):
|
||||
Tensor.training = True
|
||||
x = Tensor.ones(8,12,128,256, requires_grad=False)
|
||||
optimizer = optim.SGD(get_parameters(lm), lr=0.001)
|
||||
optimizer = optim.SGD(optim.get_parameters(lm), lr=0.001)
|
||||
loss = lm.forward(x).sum()
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
Reference in New Issue
Block a user