From 1ca178f3796c47e9cfd06165fb542ee450b3d57a Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Thu, 26 Mar 2026 09:16:31 +0800
Subject: [PATCH] llama: stochastic rounding (#15456)

---
 examples/mlperf/optim.py         | 32 ++++++++++++++-----
 .../tinybox_8xMI350X/dev_beam.sh |  2 +-
 .../tinybox_8xMI350X/dev_run.sh  |  2 +-
 3 files changed, 26 insertions(+), 10 deletions(-)

diff --git a/examples/mlperf/optim.py b/examples/mlperf/optim.py
index 1b19441d99..cb79e117a2 100644
--- a/examples/mlperf/optim.py
+++ b/examples/mlperf/optim.py
@@ -1,7 +1,21 @@
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
 from tinygrad.nn.optim import Optimizer
-from tinygrad.helpers import FUSE_OPTIM
+from tinygrad.helpers import FUSE_OPTIM, getenv
+from tinygrad.uop.ops import UOp, Ops
+
+STOCHASTIC_ROUND = getenv("STOCHASTIC_ROUND", 0)
+MASTER_WEIGHTS = getenv("MASTER_WEIGHTS", 0)
+
+def stochastic_round_bf16(x:Tensor) -> Tensor:
+  bits = x.bitcast(dtypes.uint32)
+  if isinstance(x.device, tuple):
+    shape = x.uop.shard_shape if x.uop.axis is not None else x.shape
+    noise = Tensor(UOp(Ops.MSTACK, dtypes.default_float, tuple(Tensor.rand(*shape, device=d).uop for d in x.device)))
+  else:
+    noise = x.rand_like()
+  noise = (noise * 0xFFFF).cast(dtypes.uint32)
+  return ((bits + noise) & 0xFFFF0000).bitcast(dtypes.float32).cast(dtypes.bfloat16)
 
 class GradAccClipAdamW(Optimizer):
   def __init__(self, params:list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, grad_acc=1, clip_norm=1.0, device=None, fused=FUSE_OPTIM):
@@ -11,8 +25,7 @@ class GradAccClipAdamW(Optimizer):
     self.m = self._new_optim_param()
     self.v = self._new_optim_param()
     self.grad_acc, self.clip_norm = grad_acc, clip_norm
-    # fp32 master weights for mixed precision training
-    self.master_params:list[Tensor]|None = [p.float().contiguous() for p in self.params] if self.params[0].dtype != dtypes.float32 else None
+    self.master_params:list[Tensor]|None = [p.float().contiguous() for p in self.params] if MASTER_WEIGHTS and self.params[0].dtype != dtypes.float32 else None
 
   def fstep(self, grads:list[Tensor]):
     if self.fused:
@@ -47,12 +60,14 @@
     self.b1_t *= self.b1
     self.b2_t *= self.b2
     for i, g in enumerate(grads):
-      self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
-      self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
-      m_hat = (self.m[i] / (1.0 - self.b1_t)).cast(self.m[i].dtype)
-      v_hat = (self.v[i] / (1.0 - self.b2_t)).cast(self.v[i].dtype)
+      m_new = self.b1 * self.m[i].float() + (1.0 - self.b1) * g.float()
+      v_new = self.b2 * self.v[i].float() + (1.0 - self.b2) * (g.float() * g.float())
+      self.m[i].assign(m_new.cast(self.m[i].dtype))
+      self.v[i].assign(v_new.cast(self.v[i].dtype))
+      m_hat = m_new / (1.0 - self.b1_t)
+      v_hat = v_new / (1.0 - self.b2_t)
       up = m_hat / (v_hat.sqrt() + self.eps)
-      ret.append((self.lr * up).cast(g.dtype))
+      ret.append(self.lr * up)
     return ret, [self.b1_t, self.b2_t] + self.m + self.v + [total_norm]
 
   def _apply_update(self, t:Tensor, up:Tensor, master:Tensor|None=None) -> Tensor:
@@ -61,4 +76,5 @@ def _apply_update(self, t:Tensor, up:Tensor, master:Tensor|None=None) -> Tensor:
     up = up.float().shard_like(w) + self.lr.to(w.device) * wd * w.detach()
     new_w = w.detach() - up
     if master is not None: master.assign(new_w)
+    if STOCHASTIC_ROUND and t.dtype == dtypes.bfloat16: return stochastic_round_bf16(new_w)
     return new_w.cast(t.dtype)
diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh
index 277c340f5b..0d67d7086c 100755
--- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh
+++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh
@@ -15,7 +15,7 @@
 export ASM_GEMM=${ASM_GEMM:-1}
 export WQKV=${WQKV:-0}
 export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
-export DP=${DP:-8} MP=${MP:-1} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
+export DP=${DP:-8} MP=${MP:-1} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-4}
 export GBS=$((BS * GRADIENT_ACC_STEPS))
 
 export MODEL="llama3"
diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
index 0741efbb19..5c175514c8 100755
--- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
+++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
@@ -14,7 +14,7 @@
 export USE_ATOMICS=${USE_ATOMICS:-1}
 export ASM_GEMM=${ASM_GEMM:-1}
 export WQKV=${WQKV:-0}
-export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="float32"
+export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
 export DP=${DP:-8} MP=${MP:-1} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-4}
 export GBS=$((BS * GRADIENT_ACC_STEPS))
 
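
How the rounding works: stochastic_round_bf16 adds 16 bits of uniform noise to the float32 bit pattern before truncating the 16 low mantissa bits that bfloat16 discards, so each weight rounds up or down to its bfloat16 neighbor with probability proportional to the discarded fraction, making the rounding error zero in expectation. That is presumably what lets dev_run.sh switch OPTIM_DTYPE to "bfloat16" and leave fp32 master weights opt-in behind MASTER_WEIGHTS: per-step updates too small for round-to-nearest to register still land on average. A minimal standalone sketch of the same bit trick, written against numpy rather than tinygrad (stochastic_round_bf16_np is a hypothetical demo helper, not part of this patch):

import numpy as np

def stochastic_round_bf16_np(x: np.ndarray, rng: np.random.Generator) -> np.ndarray:
  # reinterpret float32 as raw bits, like Tensor.bitcast in the patch
  bits = x.astype(np.float32).view(np.uint32)
  # uniform noise over the 16 bits that bfloat16 drops
  noise = rng.integers(0, 1 << 16, size=x.shape, dtype=np.uint32)
  # a carry into the kept bits rounds up; the mask truncates to a
  # bf16-representable value (kept as float32: numpy has no bfloat16)
  return ((bits + noise) & 0xFFFF0000).view(np.float32)

rng = np.random.default_rng(0)
x = np.full(100_000, 1.0001, dtype=np.float32)
r = stochastic_round_bf16_np(x, rng)
print(np.unique(r))  # the two bf16 neighbors of 1.0001: 1.0 and 1.0078125
print(r.mean())      # ~1.0001; round-to-nearest would give exactly 1.0

With round-to-nearest every element collapses to 1.0 and the 1e-4 increment is lost entirely; with stochastic rounding roughly 1.3% of the elements land on 1.0078125, so the sample mean stays near 1.0001.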