mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add jit to the training loop
This commit is contained in:
@@ -385,6 +385,7 @@ def train_retinanet():
|
||||
return LambdaLR(optim, _lr_lambda)
|
||||
|
||||
@Tensor.train()
|
||||
@TinyJit
|
||||
def _train_step(model, optim, lr_scheduler, x, **kwargs):
|
||||
optim.zero_grad()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user