mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
only resnet18, it's too slow otherwise
This commit is contained in:
@@ -8,7 +8,7 @@ from extra.training import train
|
||||
from extra.utils import get_parameters
|
||||
from models.efficientnet import EfficientNet
|
||||
from models.transformer import Transformer
|
||||
from models.resnet import ResNet18, ResNet34, ResNet50
|
||||
from models.resnet import ResNet18
|
||||
|
||||
BS = int(os.getenv("BS", "4"))
|
||||
|
||||
@@ -41,7 +41,7 @@ class TestTrain(unittest.TestCase):
|
||||
def test_resnet(self):
|
||||
X = np.zeros((BS, 3, 224, 224), dtype=np.float32)
|
||||
Y = np.zeros((BS), dtype=np.int32)
|
||||
for resnet_v in [ResNet18, ResNet34, ResNet50]:
|
||||
for resnet_v in [ResNet18]:
|
||||
model = resnet_v(num_classes=1000, pretrained=True)
|
||||
train_one_step(model, X, Y)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user