diff --git a/examples/mlperf/optim.py b/examples/mlperf/optim.py
index 25df87ab3b..7961bd8fb8 100644
--- a/examples/mlperf/optim.py
+++ b/examples/mlperf/optim.py
@@ -21,12 +21,13 @@ class GradAccClipAdamW(Optimizer):
       total_norm = grads[0].float().square().sum().sqrt()
       grads[0] = (grads[0] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[0].dtype)
     else:
+      for i in range(len(grads)):
+        grads[i] = grads[i] / self.grad_acc
       total_norm = Tensor.zeros((), dtype=dtypes.float32, device=self.device)
       for g in grads:
         total_norm += g.float().square().sum()
       total_norm = total_norm.sqrt()
       for i in range(len(grads)):
-        grads[i] = grads[i] / self.grad_acc
         grads[i] = (grads[i] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[i].dtype)

     ret = []