diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index fd64b16512..76032db841 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -385,6 +385,7 @@ def train_retinanet(): return LambdaLR(optim, _lr_lambda) @Tensor.train() + @TinyJit def _train_step(model, optim, lr_scheduler, x, **kwargs): optim.zero_grad()