llama: per device amax (#15735)

wozeparrot
2026-04-15 10:01:17 +08:00
committed by GitHub
parent adc96cd724
commit 480ad264a4
2 changed files with 18 additions and 1 deletion


@@ -23,8 +23,20 @@ FP8_DTYPE = dtypes.fp8e4m3
 FP8_GRAD_DTYPE = dtypes.fp8e5m2
 FP8_MAX = 448.0
 
+# per-device abs max without allreduce (matches TE delayed scaling behavior)
+@functools.cache
+def _local_abs_max_fxn(x_p, device):
+  x = Tensor(x_p, device=device)
+  inner = Tensor(x.uop.src[0]) if x.uop.op is Ops.MULTI else x
+  return (inner.abs().max(),)
+
+def _local_abs_max(x:Tensor) -> Tensor:
+  param = x.as_param(0)
+  fxn = _local_abs_max_fxn(param.uop, x.device)
+  return Tensor(fxn[0].uop.call(x.uop).gettuple(0))
+
 def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
-  new_amax = x.abs().max().detach()
+  new_amax = (_local_abs_max(x) if isinstance(x.device, tuple) else x.abs().max()).detach()
   scale = FP8_MAX / ((amax_state if amax_state is not None else new_amax) + 1e-8)
   x_scaled = x * scale
   x_clamped = x_scaled + (x_scaled.detach().clamp(-FP8_MAX, FP8_MAX) - x_scaled.detach()) # STE
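
The new helpers take the abs max of only the local shard when x lives on a device tuple, so quantize_fp8 no longer needs a cross-device reduction; as in Transformer Engine's delayed scaling, each device then scales with its own (slightly different) amax. Below is a standalone sketch of that bookkeeping in plain Python, with hypothetical names, not code from this commit:

FP8_MAX = 448.0

def delayed_scale(amax_history: list[float], new_amax: float, window: int = 16) -> float:
  # the scale applied *now* is derived from amax values recorded on earlier steps
  ref_amax = max(amax_history) if amax_history else new_amax
  scale = FP8_MAX / (ref_amax + 1e-8)
  # record this step's locally observed amax; each device keeps its own history,
  # so no allreduce is needed (histories may drift slightly between devices)
  amax_history.append(new_amax)
  if len(amax_history) > window: amax_history.pop(0)
  return scale

In the hunk above, amax_state plays the role of the history-derived reference: when it is given, the freshly computed new_amax does not affect the current scale.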
@@ -193,6 +205,10 @@ class FlatTransformer:
     self.tok_embeddings.weight.shard_(device, axis=0).realize()
     self.output.shard_(device, axis=1).realize()
     self.freqs_cis.shard_(device, axis=None).realize()
+    if FP8:
+      for name in self._fp8_amax:
+        for i in range(len(self._fp8_amax[name])):
+          self._fp8_amax[name][i] = self._fp8_amax[name][i].to(device).contiguous().requires_grad_(False)
 
   def __call__(self, tokens:Tensor):
     h = self.tok_embeddings(tokens)
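
The amax buffers are running statistics rather than parameters, so they are not sharded along an axis like the weights above: each one is moved onto the same device tuple, made contiguous, and pinned with requires_grad_(False). A hedged sketch of that relocation as a reusable helper (move_state is an illustrative name, not something this commit defines):

from tinygrad import Tensor

def move_state(state: dict[str, list[Tensor]], device) -> None:
  # mirror the treatment _fp8_amax gets in the hunk above: relocate, make contiguous,
  # and make sure no gradient is ever tracked for these running statistics
  for bufs in state.values():
    for i, buf in enumerate(bufs):
      bufs[i] = buf.to(device).contiguous().requires_grad_(False)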


@@ -587,6 +587,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     return self.pad(tuple((0,0) if a != axis else (bsz*dnum, bsz*(dcount-1) - bsz*dnum) for a in range(len(self.shape))))
   def _shard(self, axis:int) -> UOp:
+    if len(self.shape) == 0: return self  # scalars broadcast, no sharding needed
     dcount = len(self.device)
     dnum = UOp.variable("_device_num", 0, dcount-1)
     if self.shape[axis] % dcount != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {dcount=}")
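
The one-line guard added here makes _shard a no-op for 0-d tensors: a scalar has no axis to split, so it is kept as-is and acts as the same value on every device, which is presumably what the per-device amax and scale values in the first file rely on. A small user-level sketch, assuming tinygrad's public shard API (device names are placeholders; whether a given call reaches _shard depends on internals outside this hunk):

from tinygrad import Tensor

devices = ("CPU:0", "CPU:1")                  # placeholder device tuple
x = Tensor.ones(4, 4).shard(devices, axis=0)  # 2-d tensor: rows are split across devices
amax = Tensor(3.5).shard(devices)             # 0-d tensor: nothing to split, replicated
scale = 448.0 / (amax + 1e-8)                 # the same scale is visible on every device
print((x * scale).tolist())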