mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
WallTimeEvent for mlperf ci (#10506)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import os, time, math, functools, random
|
||||
import os, time, math, functools, random, contextlib
|
||||
from pathlib import Path
|
||||
import multiprocessing
|
||||
|
||||
@@ -1281,9 +1281,15 @@ def train_maskrcnn():
|
||||
|
||||
if __name__ == "__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
|
||||
if getenv("INITMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_INIT)
|
||||
elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN)
|
||||
else: bench_log_manager = contextlib.nullcontext()
|
||||
|
||||
with Tensor.train():
|
||||
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
|
||||
nm = f"train_{m}"
|
||||
if nm in globals():
|
||||
print(f"training {m}")
|
||||
with Profiling(enabled=getenv("PYPROFILE")): globals()[nm]()
|
||||
with bench_log_manager:
|
||||
with Profiling(enabled=getenv("PYPROFILE")): globals()[nm]()
|
||||
|
||||
@@ -14,6 +14,8 @@ class BenchEvent(Enum):
|
||||
LOAD_WEIGHTS = "load_weights"
|
||||
STEP = "step"
|
||||
FULL = "full"
|
||||
MLPERF_INIT = "mlperf_init"
|
||||
MLPERF_RUN = "mlperf_setup"
|
||||
class InstantBenchEvent(Enum):
|
||||
GFLOPS = "gflops"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user