From 749162bd2f3331e4963997b5b181bb245745d9ef Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Fri, 13 Mar 2026 03:36:23 +0800
Subject: [PATCH] llama memory tweaks (#15223)

---
 examples/mlperf/optim.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/mlperf/optim.py b/examples/mlperf/optim.py
index b6cc5ca193..4c6b4f5bc7 100644
--- a/examples/mlperf/optim.py
+++ b/examples/mlperf/optim.py
@@ -34,10 +34,10 @@ class GradAccClipAdamW(Optimizer):
       grads[0].assign((grads[0] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[0].dtype))
     else:
       for i in range(len(grads)):
-        grads[i].assign(grads[i] / self.grad_acc).realize()
-      total_norm = Tensor.stack(*[g.float().square().sum() for g in grads]).sum().sqrt().contiguous().realize()
+        grads[i].assign(grads[i] / self.grad_acc)
+      total_norm = Tensor.stack(*[g.float().square().sum() for g in grads]).sum().sqrt().contiguous()
       for i in range(len(grads)):
-        grads[i].assign((grads[i] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[i].dtype)).realize()
+        grads[i].assign((grads[i] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[i].dtype))
 
     ret = []
     self.b1_t *= self.b1
@@ -45,8 +45,8 @@ class GradAccClipAdamW(Optimizer):
     for i, g in enumerate(grads):
       self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
       self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
-      m_hat = self.m[i] / (1.0 - self.b1_t)
-      v_hat = self.v[i] / (1.0 - self.b2_t)
+      m_hat = (self.m[i] / (1.0 - self.b1_t)).cast(self.m[i].dtype)
+      v_hat = (self.v[i] / (1.0 - self.b2_t)).cast(self.v[i].dtype)
       up = m_hat / (v_hat.sqrt() + self.eps)
       ret.append((self.lr * up).cast(g.dtype))
     return ret, [self.b1_t, self.b2_t] + self.m + self.v + [total_norm]
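
Note (commentary, not part of the patch to apply): the first hunk drops eager
.realize() calls so the gradient scaling, norm, and clip steps are left to the
scheduler instead of each materializing its own buffer. The second hunk adds a
.cast() to the bias-corrected moments: if self.b1_t / self.b2_t are kept in a
wider dtype than the moment buffers, the division promotes, and without the
cast the optimizer briefly holds a full-precision copy of every m and v. A
minimal NumPy sketch of that promotion-and-cast pattern, assuming the moments
are stored in float16 (illustrative only, not the tinygrad code; names and
dtypes here are assumptions):

  import numpy as np

  # Half-precision first-moment state, standing in for self.m[i].
  m = np.full(1024, 0.5, dtype=np.float16)
  # Running beta1 power kept in float32, standing in for self.b1_t.
  b1_t = np.array([0.9], dtype=np.float32)

  raw_m_hat = m / (1.0 - b1_t)           # float16 / float32 promotes to float32
  m_hat = raw_m_hat.astype(m.dtype)      # the patch's .cast(): back to float16

  print(raw_m_hat.dtype, raw_m_hat.nbytes)  # float32 4096
  print(m_hat.dtype, m_hat.nbytes)          # float16 2048

At LLaMA scale this halving applies per parameter tensor, for both m_hat and
v_hat, which is where the memory saving in the subject line comes from.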