diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 19fa1a8e07..ef92c98a69 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -687,9 +687,9 @@ def train_bert(): model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None) - for _, x in get_state_dict(model).items(): - x.realize().to_(GPUS) parameters = get_parameters(model) + for p in parameters: + p.to_(GPUS) # ** Log run config ** for key, value in config.items(): print(f'HParam: "{key}": {value}')