add jit to the training loop

This commit is contained in:
Francis Lata
2025-01-22 12:31:29 -08:00
parent efe64ebeaf
commit 95cdbbf237

View File

@@ -385,6 +385,7 @@ def train_retinanet():
return LambdaLR(optim, _lr_lambda)
@Tensor.train()
@TinyJit
def _train_step(model, optim, lr_scheduler, x, **kwargs):
optim.zero_grad()