From 7dba815c47614fffc2d4beffd26284fc3d5dbd58 Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Wed, 19 Feb 2025 20:43:02 +0000 Subject: [PATCH] fix train script --- examples/mlperf/model_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 83806abcf8..2912b62507 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -423,8 +423,8 @@ def train_retinanet(): # ** model setup ** backbone = resnet.ResNeXt50_32X4D(num_classes=None) - loaded_keys = backbone.load_from_pretrained() - _freeze_backbone_layers(backbone, 3, loaded_keys) + backbone.load_from_pretrained() + _freeze_backbone_layers(backbone, 3) model = RetinaNet(backbone, num_classes=NUM_CLASSES) params = get_parameters(model)