mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add optimizer
This commit is contained in:
@@ -347,9 +347,13 @@ def train_retinanet():
from extra.models.retinanet import RetinaNet
from extra.models import resnet
from tinygrad.helpers import get_child
from tinygrad.nn.optim import Adam

NUM_CLASSES = len(MLPERF_CLASSES)

# ** hyperparameters **
LR = 1e-4

def _freeze_backbone_layers(backbone, trainable_layers, loaded_keys):
  model_layers = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
  for model_layer in model_layers:
@@ -368,6 +372,10 @@ def train_retinanet():
model = RetinaNet(backbone, num_classes=NUM_CLASSES)

# ** optimizer **
params = get_parameters(model)
optim = Adam(params, lr=LR)
def train_unet3d():
  """
  Trains the UNet3D model.
Reference in New Issue
Block a user