diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 0c705edc65..6b4f9b4c87 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -497,6 +497,10 @@ def train_retinanet(): cl = time.perf_counter() if BENCHMARK: step_times.append(cl - st) + if not math.isfinite(loss): + print("loss is nan") + return + tqdm.write( f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, " f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {losses['classification_loss'].item():5.4f} classification loss, {losses['regression_loss'].item():5.4f} regression loss, "