llama3: optim does grad acc in correct order (#14965)

Author: wozeparrot
Date: 2026-02-23 22:25:13 -08:00
Committed by: GitHub
Parent: e2b1f2620d
Commit: a36a26d4ed

@@ -21,12 +21,13 @@ class GradAccClipAdamW(Optimizer):
       total_norm = grads[0].float().square().sum().sqrt()
       grads[0] = (grads[0] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[0].dtype)
     else:
+      for i in range(len(grads)):
+        grads[i] = grads[i] / self.grad_acc
       total_norm = Tensor.zeros((), dtype=dtypes.float32, device=self.device)
       for g in grads:
         total_norm += g.float().square().sum()
       total_norm = total_norm.sqrt()
       for i in range(len(grads)):
-        grads[i] = grads[i] / self.grad_acc
         grads[i] = (grads[i] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[i].dtype)
     ret = []
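
Why the order matters: with gradient accumulation, grads hold the sum over grad_acc micro-batches. The old code computed the global norm on those summed gradients and only divided by grad_acc afterwards, so the clip scale was derived from a norm roughly grad_acc times larger than that of the averaged gradients, over-clipping whenever clipping triggered. Below is a minimal NumPy sketch (not the tinygrad code; function names are illustrative) contrasting the two orders under the assumption that grads already hold micro-batch sums:

# Minimal NumPy sketch (not the tinygrad code) of global-norm clipping
# under gradient accumulation; grads hold sums over grad_acc micro-batches.
import numpy as np

def clip_after_averaging(grads, grad_acc, clip_norm):
  # Fixed order: average first, then compute the global norm of the
  # averaged gradients and clip against clip_norm.
  grads = [g / grad_acc for g in grads]
  total_norm = np.sqrt(sum(float((g * g).sum()) for g in grads))
  scale = min(1.0, clip_norm / (total_norm + 1e-6))
  return [g * scale for g in grads]

def clip_before_averaging(grads, grad_acc, clip_norm):
  # Old order: the norm is taken over the *summed* gradients, which is
  # grad_acc times the averaged norm, so the clip scale comes out far
  # too small and the update is over-clipped by a factor of grad_acc.
  total_norm = np.sqrt(sum(float((g * g).sum()) for g in grads))
  scale = min(1.0, clip_norm / (total_norm + 1e-6))
  return [g / grad_acc * scale for g in grads]

rng = np.random.default_rng(0)
grads = [rng.normal(size=1024) * 8 for _ in range(4)]  # pretend grad_acc=8 sums
for f in (clip_after_averaging, clip_before_averaging):
  out = f(grads, grad_acc=8, clip_norm=1.0)
  print(f.__name__, np.sqrt(sum(float((g * g).sum()) for g in out)))

When clipping triggers, the old order lands at roughly clip_norm / grad_acc instead of clip_norm, which is the behavior this commit fixes.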