mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
llama: stochastic rounding (#15456)
This commit is contained in:
@@ -1,7 +1,21 @@
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.nn.optim import Optimizer
|
||||
from tinygrad.helpers import FUSE_OPTIM
|
||||
from tinygrad.helpers import FUSE_OPTIM, getenv
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
|
||||
STOCHASTIC_ROUND = getenv("STOCHASTIC_ROUND", 0)
|
||||
MASTER_WEIGHTS = getenv("MASTER_WEIGHTS", 0)
|
||||
|
||||
def stochastic_round_bf16(x:Tensor) -> Tensor:
|
||||
bits = x.bitcast(dtypes.uint32)
|
||||
if isinstance(x.device, tuple):
|
||||
shape = x.uop.shard_shape if x.uop.axis is not None else x.shape
|
||||
noise = Tensor(UOp(Ops.MSTACK, dtypes.default_float, tuple(Tensor.rand(*shape, device=d).uop for d in x.device)))
|
||||
else:
|
||||
noise = x.rand_like()
|
||||
noise = (noise * 0xFFFF).cast(dtypes.uint32)
|
||||
return ((bits + noise) & 0xFFFF0000).bitcast(dtypes.float32).cast(dtypes.bfloat16)
|
||||
|
||||
class GradAccClipAdamW(Optimizer):
|
||||
def __init__(self, params:list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, grad_acc=1, clip_norm=1.0, device=None, fused=FUSE_OPTIM):
|
||||
@@ -11,8 +25,7 @@ class GradAccClipAdamW(Optimizer):
|
||||
self.m = self._new_optim_param()
|
||||
self.v = self._new_optim_param()
|
||||
self.grad_acc, self.clip_norm = grad_acc, clip_norm
|
||||
# fp32 master weights for mixed precision training
|
||||
self.master_params:list[Tensor]|None = [p.float().contiguous() for p in self.params] if self.params[0].dtype != dtypes.float32 else None
|
||||
self.master_params:list[Tensor]|None = [p.float().contiguous() for p in self.params] if MASTER_WEIGHTS and self.params[0].dtype != dtypes.float32 else None
|
||||
|
||||
def fstep(self, grads:list[Tensor]):
|
||||
if self.fused:
|
||||
@@ -47,12 +60,14 @@ class GradAccClipAdamW(Optimizer):
|
||||
self.b1_t *= self.b1
|
||||
self.b2_t *= self.b2
|
||||
for i, g in enumerate(grads):
|
||||
self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
|
||||
self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
|
||||
m_hat = (self.m[i] / (1.0 - self.b1_t)).cast(self.m[i].dtype)
|
||||
v_hat = (self.v[i] / (1.0 - self.b2_t)).cast(self.v[i].dtype)
|
||||
m_new = self.b1 * self.m[i].float() + (1.0 - self.b1) * g.float()
|
||||
v_new = self.b2 * self.v[i].float() + (1.0 - self.b2) * (g.float() * g.float())
|
||||
self.m[i].assign(m_new.cast(self.m[i].dtype))
|
||||
self.v[i].assign(v_new.cast(self.v[i].dtype))
|
||||
m_hat = m_new / (1.0 - self.b1_t)
|
||||
v_hat = v_new / (1.0 - self.b2_t)
|
||||
up = m_hat / (v_hat.sqrt() + self.eps)
|
||||
ret.append((self.lr * up).cast(g.dtype))
|
||||
ret.append(self.lr * up)
|
||||
return ret, [self.b1_t, self.b2_t] + self.m + self.v + [total_norm]
|
||||
|
||||
def _apply_update(self, t:Tensor, up:Tensor, master:Tensor|None=None) -> Tensor:
|
||||
@@ -61,4 +76,5 @@ class GradAccClipAdamW(Optimizer):
|
||||
up = up.float().shard_like(w) + self.lr.to(w.device) * wd * w.detach()
|
||||
new_w = w.detach() - up
|
||||
if master is not None: master.assign(new_w)
|
||||
if STOCHASTIC_ROUND and t.dtype == dtypes.bfloat16: return stochastic_round_bf16(new_w)
|
||||
return new_w.cast(t.dtype)
|
||||
|
||||
@@ -15,7 +15,7 @@ export ASM_GEMM=${ASM_GEMM:-1}
|
||||
export WQKV=${WQKV:-0}
|
||||
|
||||
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
|
||||
export DP=${DP:-8} MP=${MP:-1} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
|
||||
export DP=${DP:-8} MP=${MP:-1} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-4}
|
||||
export GBS=$((BS * GRADIENT_ACC_STEPS))
|
||||
|
||||
export MODEL="llama3"
|
||||
|
||||
@@ -14,7 +14,7 @@ export USE_ATOMICS=${USE_ATOMICS:-1}
|
||||
export ASM_GEMM=${ASM_GEMM:-1}
|
||||
export WQKV=${WQKV:-0}
|
||||
|
||||
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="float32"
|
||||
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
|
||||
export DP=${DP:-8} MP=${MP:-1} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-4}
|
||||
export GBS=$((BS * GRADIENT_ACC_STEPS))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user