diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index f9023bd6e7..5817b484aa 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -395,9 +395,9 @@ def train_retinanet(): optim.zero_grad() losses = model(normalize(x, GPUS), **kwargs) - loss = (sum([l for l in losses.values()]) * loss_scaler) + loss = sum([l for l in losses.values()]) - loss.backward() + (loss * loss_scaler).backward() for t in optim.params: t.grad = t.grad.contiguous() / loss_scaler optim.step()