llama amax outside (#15670)

This commit is contained in:
wozeparrot
2026-04-10 14:08:03 +08:00
committed by GitHub
parent 16f3448b26
commit 55bcd7cc9e
2 changed files with 59 additions and 34 deletions

View File

@@ -25,28 +25,25 @@ FP8_GRAD_DTYPE = dtypes.fp8e5m2
FP8_MAX = 448.0
def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
if amax_state is not None:
scale = FP8_MAX / (amax_state + 1e-8)
amax_state.assign(x.abs().max().detach())
else:
scale = FP8_MAX / (x.abs().max().detach() + 1e-8)
new_amax = x.abs().max().detach()
scale = FP8_MAX / ((amax_state if amax_state is not None else new_amax) + 1e-8)
x_scaled = x * scale
x_clamped = x_scaled + (x_scaled.detach().clamp(-FP8_MAX, FP8_MAX) - x_scaled.detach()) # STE
return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal()
return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal(), new_amax
def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, amax_w:Tensor|None=None) -> Tensor:
def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, amax_w:Tensor|None=None) -> tuple[Tensor,...]:
if not fp8:
if getenv("ASM_GEMM"):
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(x, w.T): return asm_gemm(x, w.T)
return x @ w.T
x_fp8, x_scale = quantize_fp8(x, amax_state=amax_x)
w_fp8, w_scale = quantize_fp8(w, amax_state=amax_w)
if can_use_asm_gemm(x, w.T): return (asm_gemm(x, w.T),)
return (x @ w.T,)
x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
w_fp8, w_scale, w_new_amax = quantize_fp8(w, amax_state=amax_w)
combined_scale = x_scale * w_scale
if getenv("ASM_GEMM"):
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(x_fp8, w_fp8.T): return asm_gemm(x_fp8, w_fp8.T, combined_scale=combined_scale)
return x_fp8.dot(w_fp8.T, dtype=dtypes.float) * combined_scale
if can_use_asm_gemm(x_fp8, w_fp8.T): return asm_gemm(x_fp8, w_fp8.T, combined_scale=combined_scale), x_new_amax, w_new_amax
return x_fp8.dot(w_fp8.T, dtype=dtypes.float) * combined_scale, x_new_amax, w_new_amax
def rmsnorm(x_in:Tensor, eps:float):
x = x_in.float()
@@ -108,19 +105,28 @@ class FlatTransformer:
amax_xqkv=None, amax_wqkv=None, amax_xq=None, amax_wq=None, amax_xk=None, amax_wk=None,
amax_xv=None, amax_wv=None, amax_xo=None, amax_wo=None):
x = rmsnorm(x, self.norm_eps) * attention_norm
bsz, seqlen, _ = x.shape
new_amaxs = []
if wqkv is not None:
xqkv = matmul(x, wqkv, amax_x=amax_xqkv, amax_w=amax_wqkv)
xqkv, *amaxs = matmul(x, wqkv, amax_x=amax_xqkv, amax_w=amax_wqkv)
new_amaxs.extend(amaxs)
xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim)
xq = xqkv[:, :, :, :self.n_rep].reshape(bsz, seqlen, self.n_heads, self.head_dim)
xk = xqkv[:, :, :, self.n_rep].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
xv = xqkv[:, :, :, self.n_rep+1].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
else:
assert wq is not None and wk is not None and wv is not None
xq = matmul(x, wq, amax_x=amax_xq, amax_w=amax_wq).reshape(bsz, seqlen, self.n_heads, self.head_dim)
xk = matmul(x, wk, amax_x=amax_xk, amax_w=amax_wk).reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
xv = matmul(x, wv, amax_x=amax_xv, amax_w=amax_wv).reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
xq, *amaxs = matmul(x, wq, amax_x=amax_xq, amax_w=amax_wq)
new_amaxs.extend(amaxs)
xq = xq.reshape(bsz, seqlen, self.n_heads, self.head_dim)
xk, *amaxs = matmul(x, wk, amax_x=amax_xk, amax_w=amax_wk)
new_amaxs.extend(amaxs)
xk = xk.reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
xv, *amaxs = matmul(x, wv, amax_x=amax_xv, amax_w=amax_wv)
new_amaxs.extend(amaxs)
xv = xv.reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
if FP8: xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
@@ -131,14 +137,24 @@ class FlatTransformer:
else:
attn = xq.scaled_dot_product_attention(xk, xv, is_causal=True, enable_gqa=True).transpose(1, 2)
attn = attn.reshape(bsz, seqlen, -1)
return matmul(attn, wo, amax_x=amax_xo, amax_w=amax_wo)
out, *amaxs = matmul(attn, wo, amax_x=amax_xo, amax_w=amax_wo)
new_amaxs.extend(amaxs)
return (out, *new_amaxs)
def feed_forward(self, x:Tensor, ffn_norm:Tensor, w1:Tensor, w2:Tensor, w3:Tensor,
amax_x1=None, amax_w1=None, amax_x2=None, amax_w2=None, amax_x3=None, amax_w3=None):
x = rmsnorm(x, self.norm_eps) * ffn_norm
x_w1 = matmul(x, w1, amax_x=amax_x1, amax_w=amax_w1).silu()
x_w3 = matmul(x.contiguous_backward(), w3, amax_x=amax_x3, amax_w=amax_w3)
return matmul(x_w1 * x_w3, w2, amax_x=amax_x2, amax_w=amax_w2)
new_amaxs = []
x_w1, *amaxs = matmul(x, w1, amax_x=amax_x1, amax_w=amax_w1)
new_amaxs.extend(amaxs)
x_w3, *amaxs = matmul(x.contiguous_backward(), w3, amax_x=amax_x3, amax_w=amax_w3)
new_amaxs.extend(amaxs)
out, *amaxs = matmul(x_w1.silu() * x_w3, w2, amax_x=amax_x2, amax_w=amax_w2)
new_amaxs.extend(amaxs)
return (out, *new_amaxs)
@function(precompile=True, precompile_backward=True)
def run_layer(self, x:Tensor, freqs_cis:Tensor,
@@ -148,13 +164,16 @@ class FlatTransformer:
amax_xqkv=None, amax_wqkv=None, amax_xq=None, amax_wq=None, amax_xk=None, amax_wk=None,
amax_xv=None, amax_wv=None, amax_xo=None, amax_wo=None,
amax_x1=None, amax_w1=None, amax_x2=None, amax_w2=None, amax_x3=None, amax_w3=None):
h = x + self.attention(x, freqs_cis, attention_norm, wo, wqkv=wqkv, wq=wq, wk=wk, wv=wv,
amax_xqkv=amax_xqkv, amax_wqkv=amax_wqkv, amax_xq=amax_xq, amax_wq=amax_wq,
amax_xk=amax_xk, amax_wk=amax_wk, amax_xv=amax_xv, amax_wv=amax_wv,
amax_xo=amax_xo, amax_wo=amax_wo)
return h + self.feed_forward(h, ffn_norm, w1, w2, w3,
amax_x1=amax_x1, amax_w1=amax_w1, amax_x2=amax_x2, amax_w2=amax_w2,
amax_x3=amax_x3, amax_w3=amax_w3)
attn, *attn_amaxs = self.attention(x, freqs_cis, attention_norm, wo, wqkv=wqkv, wq=wq, wk=wk, wv=wv,
amax_xqkv=amax_xqkv, amax_wqkv=amax_wqkv, amax_xq=amax_xq, amax_wq=amax_wq,
amax_xk=amax_xk, amax_wk=amax_wk, amax_xv=amax_xv, amax_wv=amax_wv,
amax_xo=amax_xo, amax_wo=amax_wo)
h = x + attn
ffn, *ffn_amaxs = self.feed_forward(h, ffn_norm, w1, w2, w3,
amax_x1=amax_x1, amax_w1=amax_w1, amax_x2=amax_x2, amax_w2=amax_w2,
amax_x3=amax_x3, amax_w3=amax_w3)
h = h + ffn
return (h, *attn_amaxs, *ffn_amaxs)
def shard(self, device:tuple[str, ...], mp:bool=False):
from tinygrad.nn.state import get_parameters
@@ -196,11 +215,17 @@ class FlatTransformer:
"amax_x1": a["x1"][i], "amax_w1": a["w1"][i],
"amax_x2": a["x2"][i], "amax_w2": a["w2"][i],
"amax_x3": a["x3"][i], "amax_w3": a["w3"][i]} if a else {}
h = self.run_layer(h, freqs_cis,
self.attention_norm[i], self.wo[i],
self.ffn_norm[i], self.w1[i], self.w2[i], self.w3[i],
**attn_kwargs, **amax_attn, **amax_layer)
logits = matmul(self.norm(h).contiguous().contiguous_backward(), self.output[0], fp8=False).contiguous_backward()
h, *amaxs = self.run_layer(h, freqs_cis,
self.attention_norm[i], self.wo[i],
self.ffn_norm[i], self.w1[i], self.w2[i], self.w3[i],
**attn_kwargs, **amax_attn, **amax_layer)
if a:
if WQKV: amax_names = ["xqkv", "wqkv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
else: amax_names = ["xq", "wq", "xk", "wk", "xv", "wv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
for name, new_val in zip(amax_names, amaxs):
a[name][i].assign(new_val)
logits = matmul(self.norm(h).contiguous().contiguous_backward(), self.output[0], fp8=False)[0].contiguous_backward()
return logits
def _get_pads(uop:UOp) -> list[UOp]:

View File

@@ -2704,7 +2704,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp):
a_t, b_t, g_t, s_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device), Tensor(scale, device=a.device)
g_t = g_t[:a.shape[0]]
# backward GEMMs in fp8 with scale applied inside kernel to prevent bf16 overflow
g_fp8, g_scale = quantize_fp8(g_t)
g_fp8, g_scale, _ = quantize_fp8(g_t)
bw_scale = g_scale * s_t
# dgrad: g_fp8 @ weight (asm_gemm computes a@b)
grad_a = asm_gemm(g_fp8, b_t, combined_scale=bw_scale)