From d1bc4aef94d19678a644bc4f4f7d614bc589fc99 Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Tue, 21 Jan 2025 13:45:35 -0800 Subject: [PATCH] do not realize when sharding model weights --- examples/mlperf/model_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 3bf5541b56..3256e30598 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -433,7 +433,7 @@ def train_retinanet(): model = RetinaNet(backbone, num_classes=NUM_CLASSES) params = get_parameters(model) - for p in params: p.realize().to_(GPUS) + for p in params: p.to_(GPUS) # ** optimizer ** optim = Adam(params, lr=lr)