From 480ad264a4cb782296642d900cb02ece910c0294 Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Wed, 15 Apr 2026 10:01:17 +0800
Subject: [PATCH] llama: per device amax (#15735)

---
 examples/mlperf/models/flat_llama.py | 18 +++++++++++++++++-
 tinygrad/uop/ops.py                  |  1 +
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/examples/mlperf/models/flat_llama.py b/examples/mlperf/models/flat_llama.py
index 069e18df6f..a92d057d72 100644
--- a/examples/mlperf/models/flat_llama.py
+++ b/examples/mlperf/models/flat_llama.py
@@ -23,8 +23,20 @@
 FP8_DTYPE = dtypes.fp8e4m3
 FP8_GRAD_DTYPE = dtypes.fp8e5m2
 FP8_MAX = 448.0
+# per-device abs max without allreduce (matches TE delayed scaling behavior)
+@functools.cache
+def _local_abs_max_fxn(x_p, device):
+  x = Tensor(x_p, device=device)
+  inner = Tensor(x.uop.src[0]) if x.uop.op is Ops.MULTI else x
+  return (inner.abs().max(),)
+
+def _local_abs_max(x:Tensor) -> Tensor:
+  param = x.as_param(0)
+  fxn = _local_abs_max_fxn(param.uop, x.device)
+  return Tensor(fxn[0].uop.call(x.uop).gettuple(0))
+
 def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
-  new_amax = x.abs().max().detach()
+  new_amax = (_local_abs_max(x) if isinstance(x.device, tuple) else x.abs().max()).detach()
   scale = FP8_MAX / ((amax_state if amax_state is not None else new_amax) + 1e-8)
   x_scaled = x * scale
   x_clamped = x_scaled + (x_scaled.detach().clamp(-FP8_MAX, FP8_MAX) - x_scaled.detach()) # STE
@@ -193,6 +205,10 @@ class FlatTransformer:
     self.tok_embeddings.weight.shard_(device, axis=0).realize()
     self.output.shard_(device, axis=1).realize()
     self.freqs_cis.shard_(device, axis=None).realize()
+    if FP8:
+      for name in self._fp8_amax:
+        for i in range(len(self._fp8_amax[name])):
+          self._fp8_amax[name][i] = self._fp8_amax[name][i].to(device).contiguous().requires_grad_(False)
 
   def __call__(self, tokens:Tensor):
     h = self.tok_embeddings(tokens)
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 55d9918473..72857f290e 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -587,6 +587,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     return self.pad(tuple((0,0) if a != axis else (bsz*dnum, bsz*(dcount-1) - bsz*dnum) for a in range(len(self.shape))))
 
   def _shard(self, axis:int) -> UOp:
+    if len(self.shape) == 0: return self # scalars broadcast, no sharding needed
     dcount = len(self.device)
     dnum = UOp.variable("_device_num", 0, dcount-1)
     if self.shape[axis] % dcount != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {dcount=}")
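
Note (not part of the patch): a minimal, self-contained sketch of the per-device delayed-scaling idea the change builds on, using NumPy stand-ins rather than tinygrad Tensors. The helper names local_abs_max and quantize_shard are hypothetical; the point is that each shard keeps its own abs-max (no allreduce), and quantization uses the previous step's amax when one is supplied, else the freshly computed one.

# Illustrative sketch only (assumed NumPy stand-ins, hypothetical helper names).
import numpy as np

FP8_MAX = 448.0  # max representable magnitude of fp8e4m3

def local_abs_max(shards):
  # per-shard abs-max, computed locally with no cross-device reduction
  return [float(np.abs(s).max()) for s in shards]

def quantize_shard(x, amax_state=None):
  # delayed scaling: prefer last step's amax if available, else the current one
  new_amax = float(np.abs(x).max())
  scale = FP8_MAX / ((amax_state if amax_state is not None else new_amax) + 1e-8)
  x_q = np.clip(x * scale, -FP8_MAX, FP8_MAX)  # a real kernel would cast to fp8e4m3 here
  return x_q, scale, new_amax

# usage: split a tensor into two "device" shards, quantize each against its own amax
shards = np.split(np.random.randn(8, 4).astype(np.float32), 2, axis=0)
amax_states = local_abs_max(shards)
out = [quantize_shard(s, a) for s, a in zip(shards, amax_states)]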