From d281e64411ddcdbc4c542107639b4e1333d58013 Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Wed, 2 Oct 2024 04:46:08 -0700 Subject: [PATCH] add optimizer --- examples/mlperf/model_train.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 77ee680797..adf09ccba7 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -347,9 +347,13 @@ def train_retinanet(): from extra.models.retinanet import RetinaNet from extra.models import resnet from tinygrad.helpers import get_child + from tinygrad.nn.optim import Adam NUM_CLASSES = len(MLPERF_CLASSES) + # ** hyperparameters ** + LR = 1e-4 + def _freeze_backbone_layers(backbone, trainable_layers, loaded_keys): model_layers = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers] for model_layer in model_layers: @@ -368,6 +372,10 @@ def train_retinanet(): model = RetinaNet(backbone, num_classes=NUM_CLASSES) + # ** optimizer ** + params = get_parameters(model) + optim = Adam(params, lr=LR) + def train_unet3d(): """ Trains the UNet3D model.