diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 9d149a0135..17859b7b0b 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1,4 +1,4 @@ -import os, time, math, functools, random +import os, time, math, functools, random, contextlib from pathlib import Path import multiprocessing @@ -1281,9 +1281,15 @@ def train_maskrcnn(): if __name__ == "__main__": multiprocessing.set_start_method('spawn') + + if getenv("INITMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_INIT) + elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN) + else: bench_log_manager = contextlib.nullcontext() + with Tensor.train(): for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","): nm = f"train_{m}" if nm in globals(): print(f"training {m}") - with Profiling(enabled=getenv("PYPROFILE")): globals()[nm]() + with bench_log_manager: + with Profiling(enabled=getenv("PYPROFILE")): globals()[nm]() diff --git a/extra/bench_log.py b/extra/bench_log.py index 8294bdfc2b..72a23fe13e 100644 --- a/extra/bench_log.py +++ b/extra/bench_log.py @@ -14,6 +14,8 @@ class BenchEvent(Enum): LOAD_WEIGHTS = "load_weights" STEP = "step" FULL = "full" + MLPERF_INIT = "mlperf_init" + MLPERF_RUN = "mlperf_setup" class InstantBenchEvent(Enum): GFLOPS = "gflops"