From 95cdbbf237271d8e32b39af338b8dc58b47ecd75 Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Wed, 22 Jan 2025 12:31:29 -0800 Subject: [PATCH] add jit to the training loop --- examples/mlperf/model_train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index fd64b16512..76032db841 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -385,6 +385,7 @@ def train_retinanet(): return LambdaLR(optim, _lr_lambda) @Tensor.train() + @TinyJit def _train_step(model, optim, lr_scheduler, x, **kwargs): optim.zero_grad()