mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
llama3: optim does grad acc in correct order (#14965)
This commit is contained in:
@@ -21,12 +21,13 @@ class GradAccClipAdamW(Optimizer):
|
||||
total_norm = grads[0].float().square().sum().sqrt()
|
||||
grads[0] = (grads[0] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[0].dtype)
|
||||
else:
|
||||
for i in range(len(grads)):
|
||||
grads[i] = grads[i] / self.grad_acc
|
||||
total_norm = Tensor.zeros((), dtype=dtypes.float32, device=self.device)
|
||||
for g in grads:
|
||||
total_norm += g.float().square().sum()
|
||||
total_norm = total_norm.sqrt()
|
||||
for i in range(len(grads)):
|
||||
grads[i] = grads[i] / self.grad_acc
|
||||
grads[i] = (grads[i] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[i].dtype)
|
||||
|
||||
ret = []
|
||||
|
||||
Reference in New Issue
Block a user