add optimizer

This commit is contained in:
Francis Lata
2024-10-02 04:46:08 -07:00
parent b8e24b4f4d
commit d281e64411

View File

@@ -347,9 +347,13 @@ def train_retinanet():
from extra.models.retinanet import RetinaNet
from extra.models import resnet
from tinygrad.helpers import get_child
from tinygrad.nn.optim import Adam
NUM_CLASSES = len(MLPERF_CLASSES)
# ** hyperparameters **
LR = 1e-4
def _freeze_backbone_layers(backbone, trainable_layers, loaded_keys):
model_layers = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
for model_layer in model_layers:
@@ -368,6 +372,10 @@ def train_retinanet():
model = RetinaNet(backbone, num_classes=NUM_CLASSES)
# ** optimizer **
params = get_parameters(model)
optim = Adam(params, lr=LR)
def train_unet3d():
"""
Trains the UNet3D model.