mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
Import get_parameters from tinygrad.nn (#559)
* get_parameter is in optim * Update all imports for get_parameters * Clean up * use optim.get_paramters
This commit is contained in:
@@ -1,11 +1,10 @@
|
||||
import unittest
|
||||
import time
|
||||
import tinygrad.nn.optim as optim
|
||||
import numpy as np
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.tensor import Device
|
||||
from tinygrad.helpers import getenv
|
||||
from extra.training import train
|
||||
from extra.utils import get_parameters
|
||||
from models.efficientnet import EfficientNet
|
||||
from models.transformer import Transformer
|
||||
from models.vit import ViT
|
||||
@@ -14,7 +13,7 @@ from models.resnet import ResNet18
|
||||
BS = getenv("BS", 2)
|
||||
|
||||
def train_one_step(model,X,Y):
|
||||
params = get_parameters(model)
|
||||
params = optim.get_parameters(model)
|
||||
pcount = 0
|
||||
for p in params:
|
||||
pcount += np.prod(p.shape)
|
||||
|
||||
Reference in New Issue
Block a user