diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 83806abcf8..2912b62507 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -423,8 +423,8 @@ def train_retinanet(): # ** model setup ** backbone = resnet.ResNeXt50_32X4D(num_classes=None) - loaded_keys = backbone.load_from_pretrained() - _freeze_backbone_layers(backbone, 3, loaded_keys) + backbone.load_from_pretrained() + _freeze_backbone_layers(backbone, 3) model = RetinaNet(backbone, num_classes=NUM_CLASSES) params = get_parameters(model)