diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index cc96095c93..85a0a502a0 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -410,7 +410,8 @@ def train_retinanet():
   @TinyJit
   def _eval_step(model, x, **kwargs):
     out = model(normalize(x, GPUS), **kwargs)
-    return out.realize()
+    # reassemble the output on GPUS[0] first; sending it back to CPU from there is faster
+    return out.to(GPUS[0]).realize()
 
   # ** hyperparameters **
   config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))