mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
llama: use fused norm mul quantize for w13 (#15878)
This commit is contained in:
@@ -137,9 +137,7 @@ class FlatTransformer:
|
||||
|
||||
x, rrms = rmsnorm(x, self.norm_eps)
|
||||
saves.extend([x, rrms])
|
||||
x = x * attention_norm
|
||||
|
||||
xqkv, *ret = matmul(x, wqkv, amax_x=amax_xqkv, w_inv_scale=s_qkv)
|
||||
if FP8 and getenv("FUSED_NORM_MUL_QUANTIZE", 1):
|
||||
from extra.amax.cast_amax import fused_mul_quantize_fp8
|
||||
amax_s = amax_xqkv if amax_xqkv is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x.device)
|
||||
@@ -178,9 +176,15 @@ class FlatTransformer:
|
||||
|
||||
x, rrms = rmsnorm(x, self.norm_eps)
|
||||
saves.extend([x, rrms])
|
||||
x = x * ffn_norm
|
||||
|
||||
x_w13, *ret = matmul(x, w13, amax_x=amax_x13, w_inv_scale=s_13)
|
||||
if FP8 and getenv("FUSED_NORM_MUL_QUANTIZE", 1):
|
||||
from extra.amax.cast_amax import fused_mul_quantize_fp8
|
||||
amax_s13 = amax_x13 if amax_x13 is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x.device)
|
||||
x_fp8_13, x_inv_scale_13, new_amax_x13 = fused_mul_quantize_fp8(x, ffn_norm, amax_s13, FP8_DTYPE)
|
||||
x_w13, *ret = matmul(None, w13, w_inv_scale=s_13, x_fp8=x_fp8_13, x_scale=x_inv_scale_13, x_new_amax=new_amax_x13)
|
||||
else:
|
||||
x = x * ffn_norm
|
||||
x_w13, *ret = matmul(x, w13, amax_x=amax_x13, w_inv_scale=s_13)
|
||||
new_amaxs.extend(ret[:1])
|
||||
saves.extend(ret[1:] + [x_w13])
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[T
|
||||
dname = xw13.device.split(":")[0] if isinstance(xw13.device, str) else xw13.device
|
||||
fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname)
|
||||
fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, fxn=fxn, grad_fxn=_fused_quantize_bwd_w13)
|
||||
inv_scale = (FP8_MAX / (amax_state + 1e-8)).float().reciprocal()
|
||||
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
|
||||
return fp8_out, inv_scale, _scalar_amax(amax_buf)
|
||||
|
||||
# ** fused (x * weight) -> fp8 cast + amax (norm-mul-quantize)
|
||||
@@ -129,5 +129,5 @@ def fused_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, fp8_dtype
|
||||
fxn = functools.partial(_custom_mul_quantize_fp8, dname=dname)
|
||||
fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, x, weight, amax_state, fxn=fxn, grad_fxn=_fused_mul_quantize_fp8_bwd)
|
||||
new_amax = _scalar_amax(amax_buf)
|
||||
inv_scale = (new_amax.float() + 1e-8) / FP8_MAX
|
||||
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
|
||||
return fp8_out, inv_scale, new_amax
|
||||
|
||||
Reference in New Issue
Block a user