From a36a26d4ed1cf32df6eb869213ea056c6d1fed09 Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Mon, 23 Feb 2026 22:25:13 -0800
Subject: [PATCH] llama3: optim does grad acc in correct order (#14965)

---
 examples/mlperf/optim.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/mlperf/optim.py b/examples/mlperf/optim.py
index 25df87ab3b..7961bd8fb8 100644
--- a/examples/mlperf/optim.py
+++ b/examples/mlperf/optim.py
@@ -21,12 +21,13 @@ class GradAccClipAdamW(Optimizer):
       total_norm = grads[0].float().square().sum().sqrt()
       grads[0] = (grads[0] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[0].dtype)
     else:
+      for i in range(len(grads)):
+        grads[i] = grads[i] / self.grad_acc
       total_norm = Tensor.zeros((), dtype=dtypes.float32, device=self.device)
       for g in grads:
         total_norm += g.float().square().sum()
       total_norm = total_norm.sqrt()
       for i in range(len(grads)):
-        grads[i] = grads[i] / self.grad_acc
         grads[i] = (grads[i] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[i].dtype)
       ret = []
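
Below is a minimal sketch (plain Python with made-up numbers, not tinygrad) of why the order matters: the old code computed the clipping coefficient from the norm of the summed, un-averaged gradients and only then divided by grad_acc, so the update was scaled down twice; dividing by grad_acc first lets the clip threshold apply to the averaged gradients, as the patch does. The helper clip_coeff and the concrete values are hypothetical, for illustration only.

import math

def clip_coeff(grads, clip_norm):
  # global-norm clipping coefficient: min(1, clip_norm / (||g|| + eps)),
  # mirroring (clip_norm / (total_norm + 1e-6)).clamp(max_=1.0) in the patch
  total_norm = math.sqrt(sum(g * g for g in grads))
  return min(1.0, clip_norm / (total_norm + 1e-6))

grads, grad_acc, clip_norm = [3.0, 4.0], 4, 1.0   # ||grads|| = 5.0

# Old order: norm taken over the summed (un-averaged) grads, division after.
old = [g / grad_acc * clip_coeff(grads, clip_norm) for g in grads]

# New order: divide by grad_acc first, so clipping sees the averaged norm.
avg = [g / grad_acc for g in grads]               # ||avg|| = 1.25
new = [g * clip_coeff(avg, clip_norm) for g in avg]

print(old)  # ~[0.15, 0.20]: result norm 0.25, over-clipped by grad_acc
print(new)  # ~[0.60, 0.80]: result norm 1.00, exactly clip_norm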