only resnet18, it's too slow otherwise

This commit is contained in:
George Hotz
2021-10-30 16:48:39 -07:00
parent c05867dcbb
commit fc6597a6d9
2 changed files with 2 additions and 2 deletions

View File

@@ -8,7 +8,7 @@ from extra.training import train
from extra.utils import get_parameters
from models.efficientnet import EfficientNet
from models.transformer import Transformer
from models.resnet import ResNet18, ResNet34, ResNet50
from models.resnet import ResNet18
BS = int(os.getenv("BS", "4"))
@@ -41,7 +41,7 @@ class TestTrain(unittest.TestCase):
def test_resnet(self):
X = np.zeros((BS, 3, 224, 224), dtype=np.float32)
Y = np.zeros((BS), dtype=np.int32)
for resnet_v in [ResNet18, ResNet34, ResNet50]:
for resnet_v in [ResNet18]:
model = resnet_v(num_classes=1000, pretrained=True)
train_one_step(model, X, Y)