mirror of https://github.com/tinygrad/tinygrad.git
llama amax outside (#15670)
@@ -25,28 +25,25 @@ FP8_GRAD_DTYPE = dtypes.fp8e5m2
 FP8_MAX = 448.0
 
 def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
-  if amax_state is not None:
-    scale = FP8_MAX / (amax_state + 1e-8)
-    amax_state.assign(x.abs().max().detach())
-  else:
-    scale = FP8_MAX / (x.abs().max().detach() + 1e-8)
+  new_amax = x.abs().max().detach()
+  scale = FP8_MAX / ((amax_state if amax_state is not None else new_amax) + 1e-8)
   x_scaled = x * scale
   x_clamped = x_scaled + (x_scaled.detach().clamp(-FP8_MAX, FP8_MAX) - x_scaled.detach()) # STE
-  return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal()
+  return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal(), new_amax
 
-def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, amax_w:Tensor|None=None) -> Tensor:
+def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, amax_w:Tensor|None=None) -> tuple[Tensor,...]:
   if not fp8:
     if getenv("ASM_GEMM"):
       from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
-      if can_use_asm_gemm(x, w.T): return asm_gemm(x, w.T)
-    return x @ w.T
-  x_fp8, x_scale = quantize_fp8(x, amax_state=amax_x)
-  w_fp8, w_scale = quantize_fp8(w, amax_state=amax_w)
+      if can_use_asm_gemm(x, w.T): return (asm_gemm(x, w.T),)
+    return (x @ w.T,)
+  x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
+  w_fp8, w_scale, w_new_amax = quantize_fp8(w, amax_state=amax_w)
   combined_scale = x_scale * w_scale
   if getenv("ASM_GEMM"):
     from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
-    if can_use_asm_gemm(x_fp8, w_fp8.T): return asm_gemm(x_fp8, w_fp8.T, combined_scale=combined_scale)
-  return x_fp8.dot(w_fp8.T, dtype=dtypes.float) * combined_scale
+    if can_use_asm_gemm(x_fp8, w_fp8.T): return asm_gemm(x_fp8, w_fp8.T, combined_scale=combined_scale), x_new_amax, w_new_amax
+  return x_fp8.dot(w_fp8.T, dtype=dtypes.float) * combined_scale, x_new_amax, w_new_amax
 
 def rmsnorm(x_in:Tensor, eps:float):
   x = x_in.float()
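
The change to quantize_fp8 above is a switch to delayed scaling: the scale is computed from the previous step's amax (amax_state) while the freshly measured amax is returned to the caller instead of being assigned in place. A minimal numpy sketch of that recipe, separate from tinygrad; quantize_delayed is illustrative and np.clip stands in for the fp8 cast:

import numpy as np

FP8_MAX = 448.0  # matches the module constant above

def quantize_delayed(x: np.ndarray, amax_state: float | None):
    new_amax = float(np.abs(x).max())              # amax measured this step
    amax = amax_state if amax_state is not None else new_amax
    scale = FP8_MAX / (amax + 1e-8)                # scale built from the *previous* amax
    x_q = np.clip(x * scale, -FP8_MAX, FP8_MAX)    # stand-in for the fp8 cast
    return x_q, 1.0 / scale, new_amax              # values, dequant scale, new state

amax = None
for step in range(3):
    x = np.random.randn(4, 4) * (step + 1)         # activation magnitudes drift over time
    x_q, inv_scale, amax = quantize_delayed(x, amax)  # caller owns the state update

The history-based scale can lag a sudden magnitude jump by one step; the clamp (the STE line in the real code) is what keeps such a step from overflowing the fp8 range.
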
@@ -108,19 +105,28 @@ class FlatTransformer:
                 amax_xqkv=None, amax_wqkv=None, amax_xq=None, amax_wq=None, amax_xk=None, amax_wk=None,
                 amax_xv=None, amax_wv=None, amax_xo=None, amax_wo=None):
     x = rmsnorm(x, self.norm_eps) * attention_norm
 
     bsz, seqlen, _ = x.shape
+    new_amaxs = []
+
     if wqkv is not None:
-      xqkv = matmul(x, wqkv, amax_x=amax_xqkv, amax_w=amax_wqkv)
+      xqkv, *amaxs = matmul(x, wqkv, amax_x=amax_xqkv, amax_w=amax_wqkv)
+      new_amaxs.extend(amaxs)
       xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim)
       xq = xqkv[:, :, :, :self.n_rep].reshape(bsz, seqlen, self.n_heads, self.head_dim)
       xk = xqkv[:, :, :, self.n_rep].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
       xv = xqkv[:, :, :, self.n_rep+1].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
     else:
       assert wq is not None and wk is not None and wv is not None
-      xq = matmul(x, wq, amax_x=amax_xq, amax_w=amax_wq).reshape(bsz, seqlen, self.n_heads, self.head_dim)
-      xk = matmul(x, wk, amax_x=amax_xk, amax_w=amax_wk).reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
-      xv = matmul(x, wv, amax_x=amax_xv, amax_w=amax_wv).reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
+      xq, *amaxs = matmul(x, wq, amax_x=amax_xq, amax_w=amax_wq)
+      new_amaxs.extend(amaxs)
+      xq = xq.reshape(bsz, seqlen, self.n_heads, self.head_dim)
+      xk, *amaxs = matmul(x, wk, amax_x=amax_xk, amax_w=amax_wk)
+      new_amaxs.extend(amaxs)
+      xk = xk.reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
+      xv, *amaxs = matmul(x, wv, amax_x=amax_xv, amax_w=amax_wv)
+      new_amaxs.extend(amaxs)
+      xv = xv.reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
 
     xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
     if FP8: xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
@@ -131,14 +137,24 @@ class FlatTransformer:
     else:
       attn = xq.scaled_dot_product_attention(xk, xv, is_causal=True, enable_gqa=True).transpose(1, 2)
     attn = attn.reshape(bsz, seqlen, -1)
-    return matmul(attn, wo, amax_x=amax_xo, amax_w=amax_wo)
+
+    out, *amaxs = matmul(attn, wo, amax_x=amax_xo, amax_w=amax_wo)
+    new_amaxs.extend(amaxs)
+    return (out, *new_amaxs)
 
   def feed_forward(self, x:Tensor, ffn_norm:Tensor, w1:Tensor, w2:Tensor, w3:Tensor,
                    amax_x1=None, amax_w1=None, amax_x2=None, amax_w2=None, amax_x3=None, amax_w3=None):
     x = rmsnorm(x, self.norm_eps) * ffn_norm
-    x_w1 = matmul(x, w1, amax_x=amax_x1, amax_w=amax_w1).silu()
-    x_w3 = matmul(x.contiguous_backward(), w3, amax_x=amax_x3, amax_w=amax_w3)
-    return matmul(x_w1 * x_w3, w2, amax_x=amax_x2, amax_w=amax_w2)
+
+    new_amaxs = []
+
+    x_w1, *amaxs = matmul(x, w1, amax_x=amax_x1, amax_w=amax_w1)
+    new_amaxs.extend(amaxs)
+    x_w3, *amaxs = matmul(x.contiguous_backward(), w3, amax_x=amax_x3, amax_w=amax_w3)
+    new_amaxs.extend(amaxs)
+    out, *amaxs = matmul(x_w1.silu() * x_w3, w2, amax_x=amax_x2, amax_w=amax_w2)
+    new_amaxs.extend(amaxs)
+    return (out, *new_amaxs)
 
   @function(precompile=True, precompile_backward=True)
   def run_layer(self, x:Tensor, freqs_cis:Tensor,
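
With this change matmul always returns a tuple: a 1-tuple on the non-fp8 path and (result, x_new_amax, w_new_amax) on the fp8 path. Star-unpacking at the call sites absorbs both arities, so new_amaxs picks up zero or two entries per call. A sketch of that contract with a stub (matmul_stub is hypothetical):

def matmul_stub(fp8: bool) -> tuple[str, ...]:
    if not fp8: return ("out",)                 # bf16 path: result only
    return ("out", "amax_x", "amax_w")          # fp8 path: result plus fresh amaxes

new_amaxs = []
for fp8 in (False, True):
    out, *amaxs = matmul_stub(fp8)              # amaxs == [] or ["amax_x", "amax_w"]
    new_amaxs.extend(amaxs)
assert new_amaxs == ["amax_x", "amax_w"]
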
@@ -148,13 +164,16 @@ class FlatTransformer:
                 amax_xqkv=None, amax_wqkv=None, amax_xq=None, amax_wq=None, amax_xk=None, amax_wk=None,
                 amax_xv=None, amax_wv=None, amax_xo=None, amax_wo=None,
                 amax_x1=None, amax_w1=None, amax_x2=None, amax_w2=None, amax_x3=None, amax_w3=None):
-    h = x + self.attention(x, freqs_cis, attention_norm, wo, wqkv=wqkv, wq=wq, wk=wk, wv=wv,
-                           amax_xqkv=amax_xqkv, amax_wqkv=amax_wqkv, amax_xq=amax_xq, amax_wq=amax_wq,
-                           amax_xk=amax_xk, amax_wk=amax_wk, amax_xv=amax_xv, amax_wv=amax_wv,
-                           amax_xo=amax_xo, amax_wo=amax_wo)
-    return h + self.feed_forward(h, ffn_norm, w1, w2, w3,
-                                 amax_x1=amax_x1, amax_w1=amax_w1, amax_x2=amax_x2, amax_w2=amax_w2,
-                                 amax_x3=amax_x3, amax_w3=amax_w3)
+    attn, *attn_amaxs = self.attention(x, freqs_cis, attention_norm, wo, wqkv=wqkv, wq=wq, wk=wk, wv=wv,
+                                       amax_xqkv=amax_xqkv, amax_wqkv=amax_wqkv, amax_xq=amax_xq, amax_wq=amax_wq,
+                                       amax_xk=amax_xk, amax_wk=amax_wk, amax_xv=amax_xv, amax_wv=amax_wv,
+                                       amax_xo=amax_xo, amax_wo=amax_wo)
+    h = x + attn
+    ffn, *ffn_amaxs = self.feed_forward(h, ffn_norm, w1, w2, w3,
+                                        amax_x1=amax_x1, amax_w1=amax_w1, amax_x2=amax_x2, amax_w2=amax_w2,
+                                        amax_x3=amax_x3, amax_w3=amax_w3)
+    h = h + ffn
+    return (h, *attn_amaxs, *ffn_amaxs)
 
   def shard(self, device:tuple[str, ...], mp:bool=False):
     from tinygrad.nn.state import get_parameters
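
run_layer is wrapped in @function(precompile=True, precompile_backward=True), and moving the amax assignments out of it keeps the precompiled function free of in-place state writes: the layer computes new state and the caller applies it. A minimal sketch of that "state out, assign outside" pattern, assuming only that the compiled step should stay pure; step is illustrative:

def step(x: float, amax: float) -> tuple[float, float]:
    scale = 1.0 / (amax + 1e-8)    # read the old state
    return x * scale, abs(x)       # emit output and new state, no mutation inside

amax = 1.0
for x in (2.0, 8.0, 3.0):
    y, new_amax = step(x, amax)
    amax = new_amax                # the mutation happens outside the step
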
@@ -196,11 +215,17 @@ class FlatTransformer:
                    "amax_x1": a["x1"][i], "amax_w1": a["w1"][i],
                    "amax_x2": a["x2"][i], "amax_w2": a["w2"][i],
                    "amax_x3": a["x3"][i], "amax_w3": a["w3"][i]} if a else {}
-      h = self.run_layer(h, freqs_cis,
-                         self.attention_norm[i], self.wo[i],
-                         self.ffn_norm[i], self.w1[i], self.w2[i], self.w3[i],
-                         **attn_kwargs, **amax_attn, **amax_layer)
-    logits = matmul(self.norm(h).contiguous().contiguous_backward(), self.output[0], fp8=False).contiguous_backward()
+      h, *amaxs = self.run_layer(h, freqs_cis,
+                                 self.attention_norm[i], self.wo[i],
+                                 self.ffn_norm[i], self.w1[i], self.w2[i], self.w3[i],
+                                 **attn_kwargs, **amax_attn, **amax_layer)
+      if a:
+        if WQKV: amax_names = ["xqkv", "wqkv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
+        else: amax_names = ["xq", "wq", "xk", "wk", "xv", "wv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
+        for name, new_val in zip(amax_names, amaxs):
+          a[name][i].assign(new_val)
+
+    logits = matmul(self.norm(h).contiguous().contiguous_backward(), self.output[0], fp8=False)[0].contiguous_backward()
     return logits
 
 def _get_pads(uop:UOp) -> list[UOp]:
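
The assign-back above is purely positional: zip pairs the returned amaxs with amax_names, so the name list must mirror the order of the matmul calls inside the layer, hence x3/w3 before x2/w2 (matching feed_forward, which runs the w3 matmul before the final w2 one). A toy version of the contract, with illustrative names:

amax_names = ["x1", "w1", "x3", "w3", "x2", "w2"]         # ffn slice of the real list
amaxs = ["a_x1", "a_w1", "a_x3", "a_w3", "a_x2", "a_w2"]  # as returned, in call order
state = {}
for name, new_val in zip(amax_names, amaxs):
    state[name] = new_val                                 # stands in for a[name][i].assign(new_val)
assert state["w3"] == "a_w3" and state["w2"] == "a_w2"
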
@@ -2704,7 +2704,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp):
   a_t, b_t, g_t, s_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device), Tensor(scale, device=a.device)
   g_t = g_t[:a.shape[0]]
   # backward GEMMs in fp8 with scale applied inside kernel to prevent bf16 overflow
-  g_fp8, g_scale = quantize_fp8(g_t)
+  g_fp8, g_scale, _ = quantize_fp8(g_t)
   bw_scale = g_scale * s_t
   # dgrad: g_fp8 @ weight (asm_gemm computes a@b)
   grad_a = asm_gemm(g_fp8, b_t, combined_scale=bw_scale)
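
Since quantize_fp8 returns reciprocal scales, bw_scale = g_scale * s_t here (like combined_scale = x_scale * w_scale in the forward) is the single factor that dequantizes the fp8 product, and applying it inside the kernel avoids materializing a large bf16 intermediate. A numpy sketch of the algebra, with the fp8 cast elided:

import numpy as np

FP8_MAX = 448.0
x, w = np.random.randn(4, 8), np.random.randn(8, 4)
s_x, s_w = FP8_MAX / np.abs(x).max(), FP8_MAX / np.abs(w).max()
x_q, w_q = x * s_x, w * s_w                    # quantized operands (cast elided)
combined_scale = (1.0 / s_x) * (1.0 / s_w)     # product of the two reciprocal scales
out = (x_q @ w_q) * combined_scale             # equals x @ w up to float error
assert np.allclose(out, x @ w)
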