Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)
llama memory tweaks (#15223)
@@ -34,10 +34,10 @@ class GradAccClipAdamW(Optimizer):
       grads[0].assign((grads[0] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[0].dtype))
     else:
       for i in range(len(grads)):
-        grads[i].assign(grads[i] / self.grad_acc).realize()
-      total_norm = Tensor.stack(*[g.float().square().sum() for g in grads]).sum().sqrt().contiguous().realize()
+        grads[i].assign(grads[i] / self.grad_acc)
+      total_norm = Tensor.stack(*[g.float().square().sum() for g in grads]).sum().sqrt().contiguous()
       for i in range(len(grads)):
-        grads[i].assign((grads[i] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[i].dtype)).realize()
+        grads[i].assign((grads[i] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[i].dtype))
 
     ret = []
     self.b1_t *= self.b1
@@ -45,8 +45,8 @@ class GradAccClipAdamW(Optimizer):
     for i, g in enumerate(grads):
       self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
       self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
-      m_hat = self.m[i] / (1.0 - self.b1_t)
-      v_hat = self.v[i] / (1.0 - self.b2_t)
+      m_hat = (self.m[i] / (1.0 - self.b1_t)).cast(self.m[i].dtype)
+      v_hat = (self.v[i] / (1.0 - self.b2_t)).cast(self.v[i].dtype)
       up = m_hat / (v_hat.sqrt() + self.eps)
       ret.append((self.lr * up).cast(g.dtype))
     return ret, [self.b1_t, self.b2_t] + self.m + self.v + [total_norm]
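The change drops the eager .realize() calls (leaving materialization to tinygrad's lazy scheduler) and casts the bias-corrected moments back to their storage dtype, which appears to be the memory tweak. For reference, a minimal NumPy sketch of the clip-then-AdamW step these hunks implement is below; it illustrates the math only and is not tinygrad code, and the hyperparameter defaults (lr, b1, b2, eps, clip_norm) are illustrative assumptions, not values from the commit.

import numpy as np

def adamw_step(grads, m, v, b1_t, b2_t, lr=1e-4, b1=0.9, b2=0.95, eps=1e-8, clip_norm=1.0):
  # Global-norm gradient clipping, as in the first hunk: scale all grads
  # by min(1, clip_norm / total_norm).
  total_norm = np.sqrt(sum((g.astype(np.float32) ** 2).sum() for g in grads))
  scale = min(1.0, clip_norm / (total_norm + 1e-6))
  grads = [g * scale for g in grads]
  # Moment updates with bias correction, as in the second hunk. b1_t and b2_t
  # accumulate b1**t and b2**t across steps (cf. self.b1_t *= self.b1).
  b1_t, b2_t = b1_t * b1, b2_t * b2
  updates = []
  for i, g in enumerate(grads):
    m[i] = b1 * m[i] + (1.0 - b1) * g          # first moment (EMA of grads)
    v[i] = b2 * v[i] + (1.0 - b2) * g * g      # second moment (EMA of squared grads)
    m_hat = m[i] / (1.0 - b1_t)                # bias-corrected first moment
    v_hat = v[i] / (1.0 - b2_t)                # bias-corrected second moment
    updates.append(lr * m_hat / (np.sqrt(v_hat) + eps))
  return updates, b1_t, b2_t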