llama: use fused norm mul quantize for w13 (#15878)

This commit is contained in:
wozeparrot
2026-04-23 12:27:41 +08:00
committed by GitHub
parent 0c3260d5d9
commit d3cbd781d9
2 changed files with 10 additions and 6 deletions

View File

@@ -137,9 +137,7 @@ class FlatTransformer:
x, rrms = rmsnorm(x, self.norm_eps)
saves.extend([x, rrms])
- x = x * attention_norm
- xqkv, *ret = matmul(x, wqkv, amax_x=amax_xqkv, w_inv_scale=s_qkv)
if FP8 and getenv("FUSED_NORM_MUL_QUANTIZE", 1):
from extra.amax.cast_amax import fused_mul_quantize_fp8
amax_s = amax_xqkv if amax_xqkv is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x.device)
@@ -178,9 +176,15 @@ class FlatTransformer:
x, rrms = rmsnorm(x, self.norm_eps)
saves.extend([x, rrms])
- x = x * ffn_norm
- x_w13, *ret = matmul(x, w13, amax_x=amax_x13, w_inv_scale=s_13)
+ if FP8 and getenv("FUSED_NORM_MUL_QUANTIZE", 1):
+ from extra.amax.cast_amax import fused_mul_quantize_fp8
+ amax_s13 = amax_x13 if amax_x13 is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x.device)
+ x_fp8_13, x_inv_scale_13, new_amax_x13 = fused_mul_quantize_fp8(x, ffn_norm, amax_s13, FP8_DTYPE)
+ x_w13, *ret = matmul(None, w13, w_inv_scale=s_13, x_fp8=x_fp8_13, x_scale=x_inv_scale_13, x_new_amax=new_amax_x13)
+ else:
+ x = x * ffn_norm
+ x_w13, *ret = matmul(x, w13, amax_x=amax_x13, w_inv_scale=s_13)
new_amaxs.extend(ret[:1])
saves.extend(ret[1:] + [x_w13])

View File

@@ -81,7 +81,7 @@ def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[T
dname = xw13.device.split(":")[0] if isinstance(xw13.device, str) else xw13.device
fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname)
fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, fxn=fxn, grad_fxn=_fused_quantize_bwd_w13)
- inv_scale = (FP8_MAX / (amax_state + 1e-8)).float().reciprocal()
+ inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
return fp8_out, inv_scale, _scalar_amax(amax_buf)
# ** fused (x * weight) -> fp8 cast + amax (norm-mul-quantize)
@@ -129,5 +129,5 @@ def fused_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, fp8_dtype
fxn = functools.partial(_custom_mul_quantize_fp8, dname=dname)
fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, x, weight, amax_state, fxn=fxn, grad_fxn=_fused_mul_quantize_fp8_bwd)
new_amax = _scalar_amax(amax_buf)
- inv_scale = (new_amax.float() + 1e-8) / FP8_MAX
+ inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
return fp8_out, inv_scale, new_amax