diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 3bf5541b56..3256e30598 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -433,7 +433,7 @@ def train_retinanet(): model = RetinaNet(backbone, num_classes=NUM_CLASSES) params = get_parameters(model) - for p in params: p.realize().to_(GPUS) + for p in params: p.to_(GPUS) # ** optimizer ** optim = Adam(params, lr=lr)