mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
llama: per device amax (#15735)
This commit is contained in:
@@ -23,8 +23,20 @@ FP8_DTYPE = dtypes.fp8e4m3
|
||||
FP8_GRAD_DTYPE = dtypes.fp8e5m2
|
||||
FP8_MAX = 448.0
|
||||
|
||||
# per-device abs max without allreduce (matches TE delayed scaling behavior)
@functools.cache
def _local_abs_max_fxn(x_p, device):
  # Build (and memoize, keyed on the param uop + device) the abs-max reduction graph.
  tensor = Tensor(x_p, device=device)
  # For a multi-device tensor, unwrap the MULTI node so the reduction stays local
  # to each device instead of forcing an allreduce.
  if tensor.uop.op is Ops.MULTI:
    tensor = Tensor(tensor.uop.src[0])
  return (tensor.abs().max(),)
|
||||
|
||||
def _local_abs_max(x:Tensor) -> Tensor:
  # Compute the per-device abs max of x by calling the cached reduction graph
  # built by _local_abs_max_fxn against a placeholder param for x.
  placeholder = x.as_param(0)
  compiled = _local_abs_max_fxn(placeholder.uop, x.device)
  # Invoke the compiled graph on x's actual uop and pull out the single output.
  out_uop = compiled[0].uop.call(x.uop).gettuple(0)
  return Tensor(out_uop)
|
||||
|
||||
def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
|
||||
new_amax = x.abs().max().detach()
|
||||
new_amax = (_local_abs_max(x) if isinstance(x.device, tuple) else x.abs().max()).detach()
|
||||
scale = FP8_MAX / ((amax_state if amax_state is not None else new_amax) + 1e-8)
|
||||
x_scaled = x * scale
|
||||
x_clamped = x_scaled + (x_scaled.detach().clamp(-FP8_MAX, FP8_MAX) - x_scaled.detach()) # STE
|
||||
@@ -193,6 +205,10 @@ class FlatTransformer:
|
||||
self.tok_embeddings.weight.shard_(device, axis=0).realize()
|
||||
self.output.shard_(device, axis=1).realize()
|
||||
self.freqs_cis.shard_(device, axis=None).realize()
|
||||
if FP8:
|
||||
for name in self._fp8_amax:
|
||||
for i in range(len(self._fp8_amax[name])):
|
||||
self._fp8_amax[name][i] = self._fp8_amax[name][i].to(device).contiguous().requires_grad_(False)
|
||||
|
||||
def __call__(self, tokens:Tensor):
|
||||
h = self.tok_embeddings(tokens)
|
||||
|
||||
@@ -587,6 +587,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return self.pad(tuple((0,0) if a != axis else (bsz*dnum, bsz*(dcount-1) - bsz*dnum) for a in range(len(self.shape))))
|
||||
|
||||
def _shard(self, axis:int) -> UOp:
|
||||
if len(self.shape) == 0: return self # scalars broadcast, no sharding needed
|
||||
dcount = len(self.device)
|
||||
dnum = UOp.variable("_device_num", 0, dcount-1)
|
||||
if self.shape[axis] % dcount != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {dcount=}")
|
||||
|
||||
Reference in New Issue
Block a user