diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 82ec80b5a0..e1d7d4a022 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -593,7 +593,8 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te optimizer.step() scheduler.step() # TODO: no to("CPU") here because it blocks and messes the python time - return loss.realize(), global_norm.realize(), optimizer.optimizers[0].lr.realize() + Tensor.realize(loss, global_norm, optimizer.optimizers[0].lr) + return loss, global_norm, optimizer.optimizers[0].lr @TinyJit def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, @@ -604,8 +605,10 @@ def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:T lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids) masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \ model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels) - return masked_lm_accuracy.to("CPU").realize(), seq_relationship_accuracy.realize().to("CPU"), \ - masked_lm_loss.to("CPU").realize(), next_sentence_loss.to("CPU").realize() + for t in [masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss]: + t.to_("CPU") + Tensor.realize(masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss) + return masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss def train_bert(): # NOTE: pip install tensorflow, wandb required