mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
minor cleanup
This commit is contained in:
@@ -351,9 +351,6 @@ def train_retinanet():
|
||||
|
||||
NUM_CLASSES = len(MLPERF_CLASSES)
|
||||
|
||||
# ** hyperparameters **
|
||||
LR = 1e-4
|
||||
|
||||
def _freeze_backbone_layers(backbone, trainable_layers, loaded_keys):
|
||||
model_layers = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
|
||||
for model_layer in model_layers:
|
||||
@@ -362,6 +359,9 @@ def train_retinanet():
|
||||
layer:Tensor = get_child(backbone, loaded_key)
|
||||
layer.requires_grad = False
|
||||
|
||||
# ** hyperparameters **
|
||||
LR = 1e-4
|
||||
|
||||
# ** model initializers **
|
||||
resnet.BatchNorm = FrozenBatchNorm2d
|
||||
|
||||
|
||||
Reference in New Issue
Block a user