Import get_parameters from tinygrad.nn (#559)

* get_parameter is in optim

* Update all imports for get_parameters

* Clean up

* use optim.get_paramters
This commit is contained in:
Jacky Lee
2023-02-17 15:22:26 -08:00
committed by GitHub
parent fae7654924
commit 9fd41632c6
13 changed files with 32 additions and 46 deletions

View File

@@ -1,11 +1,10 @@
import unittest
import time
import tinygrad.nn.optim as optim
import numpy as np
from tinygrad.nn import optim
from tinygrad.tensor import Device
from tinygrad.helpers import getenv
from extra.training import train
from extra.utils import get_parameters
from models.efficientnet import EfficientNet
from models.transformer import Transformer
from models.vit import ViT
@@ -14,7 +13,7 @@ from models.resnet import ResNet18
BS = getenv("BS", 2)
def train_one_step(model,X,Y):
params = get_parameters(model)
params = optim.get_parameters(model)
pcount = 0
for p in params:
pcount += np.prod(p.shape)