llama: stochastic rounding (#15456)

wozeparrot
2026-03-26 09:16:31 +08:00
committed by GitHub
parent 7c8f992894
commit 1ca178f379
3 changed files with 26 additions and 10 deletions


@@ -1,7 +1,21 @@
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.nn.optim import Optimizer
from tinygrad.helpers import FUSE_OPTIM
from tinygrad.helpers import FUSE_OPTIM, getenv
from tinygrad.uop.ops import UOp, Ops
STOCHASTIC_ROUND = getenv("STOCHASTIC_ROUND", 0)
MASTER_WEIGHTS = getenv("MASTER_WEIGHTS", 0)
def stochastic_round_bf16(x:Tensor) -> Tensor:
bits = x.bitcast(dtypes.uint32)
if isinstance(x.device, tuple):
shape = x.uop.shard_shape if x.uop.axis is not None else x.shape
noise = Tensor(UOp(Ops.MSTACK, dtypes.default_float, tuple(Tensor.rand(*shape, device=d).uop for d in x.device)))
else:
noise = x.rand_like()
noise = (noise * 0xFFFF).cast(dtypes.uint32)
return ((bits + noise) & 0xFFFF0000).bitcast(dtypes.float32).cast(dtypes.bfloat16)
class GradAccClipAdamW(Optimizer):
def __init__(self, params:list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, grad_acc=1, clip_norm=1.0, device=None, fused=FUSE_OPTIM):
@@ -11,8 +25,7 @@ class GradAccClipAdamW(Optimizer):
self.m = self._new_optim_param()
self.v = self._new_optim_param()
self.grad_acc, self.clip_norm = grad_acc, clip_norm
# fp32 master weights for mixed precision training
self.master_params:list[Tensor]|None = [p.float().contiguous() for p in self.params] if self.params[0].dtype != dtypes.float32 else None
self.master_params:list[Tensor]|None = [p.float().contiguous() for p in self.params] if MASTER_WEIGHTS and self.params[0].dtype != dtypes.float32 else None
def fstep(self, grads:list[Tensor]):
if self.fused:
@@ -47,12 +60,14 @@ class GradAccClipAdamW(Optimizer):
self.b1_t *= self.b1
self.b2_t *= self.b2
for i, g in enumerate(grads):
self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
m_hat = (self.m[i] / (1.0 - self.b1_t)).cast(self.m[i].dtype)
v_hat = (self.v[i] / (1.0 - self.b2_t)).cast(self.v[i].dtype)
m_new = self.b1 * self.m[i].float() + (1.0 - self.b1) * g.float()
v_new = self.b2 * self.v[i].float() + (1.0 - self.b2) * (g.float() * g.float())
self.m[i].assign(m_new.cast(self.m[i].dtype))
self.v[i].assign(v_new.cast(self.v[i].dtype))
m_hat = m_new / (1.0 - self.b1_t)
v_hat = v_new / (1.0 - self.b2_t)
up = m_hat / (v_hat.sqrt() + self.eps)
ret.append((self.lr * up).cast(g.dtype))
ret.append(self.lr * up)
return ret, [self.b1_t, self.b2_t] + self.m + self.v + [total_norm]
def _apply_update(self, t:Tensor, up:Tensor, master:Tensor|None=None) -> Tensor:
@@ -61,4 +76,5 @@ class GradAccClipAdamW(Optimizer):
up = up.float().shard_like(w) + self.lr.to(w.device) * wd * w.detach()
new_w = w.detach() - up
if master is not None: master.assign(new_w)
if STOCHASTIC_ROUND and t.dtype == dtypes.bfloat16: return stochastic_round_bf16(new_w)
return new_w.cast(t.dtype)
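
For reference, a minimal NumPy sketch (illustrative only, not part of the commit) of the bit trick stochastic_round_bf16 uses: add 16 bits of uniform noise to the fp32 bit pattern, then truncate the low 16 mantissa bits. Those low bits are exactly the fraction that truncation would discard, so the carry into the kept bf16 bits fires with probability proportional to that fraction, making the rounding unbiased in expectation. NumPy has no bfloat16, so the sketch leaves the result as an fp32 whose low mantissa bits are zero.

import numpy as np

def stochastic_round_bf16_np(x: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    # reinterpret fp32 as uint32, add uniform noise in [0, 0xFFFF], keep only the top 16 bits
    bits = x.astype(np.float32).view(np.uint32)
    noise = rng.integers(0, 1 << 16, size=x.shape, dtype=np.uint32)
    return ((bits + noise) & np.uint32(0xFFFF0000)).view(np.float32)

rng = np.random.default_rng(0)
x = np.full(1_000_000, 1.0 + 2**-10, dtype=np.float32)   # lies between two bf16 values, 1.0 and 1.0078125
print(stochastic_round_bf16_np(x, rng).mean())           # ~1.00098, matches x; plain truncation would give 1.0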


@@ -15,7 +15,7 @@ export ASM_GEMM=${ASM_GEMM:-1}
export WQKV=${WQKV:-0}
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
export DP=${DP:-8} MP=${MP:-1} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
export DP=${DP:-8} MP=${MP:-1} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-4}
export GBS=$((BS * GRADIENT_ACC_STEPS))
export MODEL="llama3"


@@ -14,7 +14,7 @@ export USE_ATOMICS=${USE_ATOMICS:-1}
export ASM_GEMM=${ASM_GEMM:-1}
export WQKV=${WQKV:-0}
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="float32"
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
export DP=${DP:-8} MP=${MP:-1} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-4}
export GBS=$((BS * GRADIENT_ACC_STEPS))
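
The second script above switches OPTIM_DTYPE from float32 to bfloat16, and with MASTER_WEIGHTS off by default, stochastic rounding is what keeps sub-ulp weight updates from being rounded away every step. A small self-contained NumPy experiment (illustrative only; the 1e-3 update size and step count are arbitrary assumptions, not values from the commit) showing that failure mode and how stochastic rounding avoids it:

import numpy as np

def round_bf16(x: np.ndarray, noise_bits: np.ndarray) -> np.ndarray:
    # truncate fp32 toward zero to a bf16-representable value, optionally after adding rounding noise
    bits = x.astype(np.float32).view(np.uint32)
    return ((bits + noise_bits) & np.uint32(0xFFFF0000)).view(np.float32)

rng = np.random.default_rng(0)
n, steps, update = 4096, 1000, 1e-3        # bf16 ulp at 1.0 is 2**-7 ~= 0.0078, so 1e-3 is sub-ulp
w_trunc = np.full(n, 1.0, dtype=np.float32)
w_stoch = np.full(n, 1.0, dtype=np.float32)
for _ in range(steps):
    w_trunc = round_bf16(w_trunc + update, np.zeros(n, dtype=np.uint32))                       # deterministic truncation
    w_stoch = round_bf16(w_stoch + update, rng.integers(0, 1 << 16, size=n, dtype=np.uint32))  # stochastic rounding
print(w_trunc.mean())   # stays 1.0: every update is smaller than the rounding step and is lost
print(w_stoch.mean())   # ~2.0: unbiased rounding preserves the accumulated 1000 * 1e-3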