llama: only support wqkv path + cleanups (#15680)

* llama: only support wqkv path + cleanups

* llama: missing transpose
wozeparrot
2026-04-11 07:39:27 +08:00
committed by GitHub
parent aa012d6f08
commit 590464c8d8


@@ -18,7 +18,6 @@ from tinygrad.uop.ops import Ops, UOp
 from extra.models.llama import apply_rotary_emb, precompute_freqs_cis
 FP8 = getenv("FP8", 0)
-WQKV = getenv("WQKV", 0)
 FP8_DTYPE = dtypes.fp8e4m3
 FP8_GRAD_DTYPE = dtypes.fp8e5m2
@@ -63,12 +62,7 @@ class FlatTransformer:
     scaled_std = 0.02 / math.sqrt(2 * n_layers)
     # Attention
-    if WQKV:
-      self.wqkv = self.lin_per_layer(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2)
-    else:
-      self.wq = self.lin_per_layer(dim, self.n_heads * self.head_dim)
-      self.wk = self.lin_per_layer(dim, self.n_kv_heads * self.head_dim)
-      self.wv = self.lin_per_layer(dim, self.n_kv_heads * self.head_dim)
+    self.wqkv = self.lin_per_layer(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2)
     self.wo = self.lin_per_layer(self.n_heads * self.head_dim, dim, std=scaled_std)
     # FeedForward
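
Note: the fused out_features above factors cleanly into the grouped layout that the attention reshape later relies on. A quick check with Llama-3-8B-style shapes (illustrative numbers, not taken from this diff):

n_heads, n_kv_heads, head_dim = 32, 8, 128                      # illustrative GQA shapes
n_rep = n_heads // n_kv_heads                                   # 4 query heads per kv group
out_features = n_heads * head_dim + n_kv_heads * head_dim * 2   # 4096 + 2048 = 6144
assert out_features == n_kv_heads * (n_rep + 2) * head_dim      # the factored form the reshape uses
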
@@ -89,8 +83,7 @@ class FlatTransformer:
     if FP8:
       def _amax(): return Tensor.full((), FP8_MAX).contiguous().requires_grad_(False)
-      names = (["xqkv", "wqkv"] if WQKV else ["xq", "wq", "xk", "wk", "xv", "wv"]) + \
-        ["xo", "wo", "x1", "w1", "x2", "w2", "x3", "w3"]
+      names = ["xqkv", "wqkv", "xo", "wo", "x1", "w1", "x2", "w2", "x3", "w3"]
       # _fp8_amax[name][layer_idx] = scalar amax tensor
       self._fp8_amax = {name: [_amax() for _ in range(n_layers)] for name in names}
       self._fp8_amax["xout"] = [_amax()]
@@ -100,40 +93,26 @@ class FlatTransformer:
     if getenv("ZEROS"): return Tensor.zeros(self.n_layers, out_features, in_features)
     return Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
-  def attention(self, x:Tensor, freqs_cis:Tensor, attention_norm:Tensor, wo:Tensor, wqkv:Tensor|None=None,
-                wq:Tensor|None=None, wk:Tensor|None=None, wv:Tensor|None=None,
-                amax_xqkv=None, amax_wqkv=None, amax_xq=None, amax_wq=None, amax_xk=None, amax_wk=None,
-                amax_xv=None, amax_wv=None, amax_xo=None, amax_wo=None):
+  def attention(self, x:Tensor, freqs_cis:Tensor, attention_norm:Tensor, wqkv:Tensor, wo:Tensor,
+                amax_xqkv=None, amax_wqkv=None, amax_xo=None, amax_wo=None):
     x = rmsnorm(x, self.norm_eps) * attention_norm
     bsz, seqlen, _ = x.shape
     new_amaxs = []
-    if wqkv is not None:
-      xqkv, *amaxs = matmul(x, wqkv, amax_x=amax_xqkv, amax_w=amax_wqkv)
-      new_amaxs.extend(amaxs)
-      xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim)
-      xq = xqkv[:, :, :, :self.n_rep].reshape(bsz, seqlen, self.n_heads, self.head_dim)
-      xk = xqkv[:, :, :, self.n_rep].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
-      xv = xqkv[:, :, :, self.n_rep+1].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
-    else:
-      assert wq is not None and wk is not None and wv is not None
-      xq, *amaxs = matmul(x, wq, amax_x=amax_xq, amax_w=amax_wq)
-      new_amaxs.extend(amaxs)
-      xq = xq.reshape(bsz, seqlen, self.n_heads, self.head_dim)
-      xk, *amaxs = matmul(x, wk, amax_x=amax_xk, amax_w=amax_wk)
-      new_amaxs.extend(amaxs)
-      xk = xk.reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
-      xv, *amaxs = matmul(x, wv, amax_x=amax_xv, amax_w=amax_wv)
-      new_amaxs.extend(amaxs)
-      xv = xv.reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
+    xqkv, *amaxs = matmul(x, wqkv, amax_x=amax_xqkv, amax_w=amax_wqkv)
+    new_amaxs.extend(amaxs)
+    xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim)
+    xq = xqkv[:, :, :, :self.n_rep].reshape(bsz, seqlen, self.n_heads, self.head_dim)
+    xk = xqkv[:, :, :, self.n_rep].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
+    xv = xqkv[:, :, :, self.n_rep+1].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
     xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
     if FP8: xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
     xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
     if getenv("HK_FLASH_ATTENTION"):
       from extra.thunder.amd.fa import flash_attention
-      attn = flash_attention(xq, xk, xv, is_causal=True)
+      attn = flash_attention(xq, xk, xv, is_causal=True).transpose(1, 2)
     else:
       attn = xq.scaled_dot_product_attention(xk, xv, is_causal=True, enable_gqa=True).transpose(1, 2)
     attn = attn.reshape(bsz, seqlen, -1)
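
Note: the single reshape to (bsz, seqlen, n_kv_heads, n_rep + 2, head_dim) only works if wqkv's output rows are interleaved per kv group: n_rep query heads, then one key head, then one value head. A numpy sketch of packing separate wq/wk/wv checkpoint weights into that layout and verifying the split recovers all three projections (assumes matmul computes x @ w.T, i.e. a standard linear with (out, in) weights):

import numpy as np

n_heads, n_kv_heads, head_dim, dim = 8, 2, 4, 16
n_rep = n_heads // n_kv_heads   # query heads per kv group

wq = np.random.randn(n_heads * head_dim, dim)
wk = np.random.randn(n_kv_heads * head_dim, dim)
wv = np.random.randn(n_kv_heads * head_dim, dim)

# interleave per kv group: n_rep q heads, then the k head, then the v head
q = wq.reshape(n_kv_heads, n_rep, head_dim, dim)
k = wk.reshape(n_kv_heads, 1, head_dim, dim)
v = wv.reshape(n_kv_heads, 1, head_dim, dim)
wqkv = np.concatenate([q, k, v], axis=1).reshape(-1, dim)  # (n_kv_heads*(n_rep+2)*head_dim, dim)

bsz, seqlen = 1, 3
x = np.random.randn(bsz, seqlen, dim)
xqkv = (x @ wqkv.T).reshape(bsz, seqlen, n_kv_heads, n_rep + 2, head_dim)
xq = xqkv[:, :, :, :n_rep].reshape(bsz, seqlen, n_heads, head_dim)
xk = xqkv[:, :, :, n_rep]
xv = xqkv[:, :, :, n_rep + 1]
assert np.allclose(xq, (x @ wq.T).reshape(bsz, seqlen, n_heads, head_dim))
assert np.allclose(xk, (x @ wk.T).reshape(bsz, seqlen, n_kv_heads, head_dim))
assert np.allclose(xv, (x @ wv.T).reshape(bsz, seqlen, n_kv_heads, head_dim))

The follow-up commit's .transpose(1, 2) fix then makes the flash-attention branch hand back (bsz, seqlen, n_heads, head_dim) like the SDPA branch, so the final reshape sees the same layout on both paths.
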
@@ -158,20 +137,15 @@ class FlatTransformer:
   @function(precompile=True, precompile_backward=True)
   def run_layer(self, x:Tensor, freqs_cis:Tensor,
-                attention_norm:Tensor, wo:Tensor,
+                attention_norm:Tensor, wqkv:Tensor, wo:Tensor,
                 ffn_norm:Tensor, w1:Tensor, w2:Tensor, w3:Tensor,
-                wqkv:Tensor|None=None, wq:Tensor|None=None, wk:Tensor|None=None, wv:Tensor|None=None,
-                amax_xqkv=None, amax_wqkv=None, amax_xq=None, amax_wq=None, amax_xk=None, amax_wk=None,
-                amax_xv=None, amax_wv=None, amax_xo=None, amax_wo=None,
+                amax_xqkv=None, amax_wqkv=None, amax_xo=None, amax_wo=None,
                 amax_x1=None, amax_w1=None, amax_x2=None, amax_w2=None, amax_x3=None, amax_w3=None):
-    attn, *attn_amaxs = self.attention(x, freqs_cis, attention_norm, wo, wqkv=wqkv, wq=wq, wk=wk, wv=wv,
-                                       amax_xqkv=amax_xqkv, amax_wqkv=amax_wqkv, amax_xq=amax_xq, amax_wq=amax_wq,
-                                       amax_xk=amax_xk, amax_wk=amax_wk, amax_xv=amax_xv, amax_wv=amax_wv,
-                                       amax_xo=amax_xo, amax_wo=amax_wo)
+    attn, *attn_amaxs = self.attention(x, freqs_cis, attention_norm, wqkv, wo,
+                                       amax_xqkv=amax_xqkv, amax_wqkv=amax_wqkv, amax_xo=amax_xo, amax_wo=amax_wo)
     h = x + attn
     ffn, *ffn_amaxs = self.feed_forward(h, ffn_norm, w1, w2, w3,
-                                        amax_x1=amax_x1, amax_w1=amax_w1, amax_x2=amax_x2, amax_w2=amax_w2,
-                                        amax_x3=amax_x3, amax_w3=amax_w3)
+                                        amax_x1=amax_x1, amax_w1=amax_w1, amax_x2=amax_x2, amax_w2=amax_w2, amax_x3=amax_x3, amax_w3=amax_w3)
     h = h + ffn
     return (h, *attn_amaxs, *ffn_amaxs)
@@ -181,12 +155,7 @@ class FlatTransformer:
       for v in get_parameters(self): v.shard_(device, axis=None)
     else:
       # flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
-      if WQKV:
-        self.wqkv.shard_(device, axis=1).realize()  # (n_layers, out, dim) shard out
-      else:
-        self.wq.shard_(device, axis=1).realize()  # (n_layers, n_heads*head_dim, dim) shard out
-        self.wk.shard_(device, axis=1).realize()  # (n_layers, n_kv_heads*head_dim, dim) shard out
-        self.wv.shard_(device, axis=1).realize()  # (n_layers, n_kv_heads*head_dim, dim) shard out
+      self.wqkv.shard_(device, axis=1).realize()  # (n_layers, out, dim) shard out
       self.wo.shard_(device, axis=2).realize()  # (n_layers, dim, in) shard in
       self.w1.shard_(device, axis=1).realize()  # (n_layers, hidden, dim) shard out
       self.w2.shard_(device, axis=2).realize()  # (n_layers, dim, hidden) shard in
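
Note: the axis choices pair up as the classic column-parallel / row-parallel split: wqkv and w1/w3 shard their output features, wo and w2 shard their input features, so only the second matmul of each pair needs a reduction across devices. A numpy sketch of that algebra (D, the shapes, and the explicit sum standing in for the all-reduce are illustrative):

import numpy as np

D, dim, out = 2, 8, 12                 # D devices; per-layer view, so the n_layers axis is dropped
x = np.random.randn(3, dim)
wqkv = np.random.randn(out, dim)       # sharding axis 1 of (n_layers, out, dim) == axis 0 here
wo = np.random.randn(dim, out)         # sharding axis 2 of (n_layers, dim, in)  == axis 1 here

qkv_shards = np.split(wqkv, D, axis=0)             # column-parallel: out/D rows per device
acts = [x @ w.T for w in qkv_shards]               # per-device activations, no communication
wo_shards = np.split(wo, D, axis=1)                # row-parallel: matching in/D columns per device
y = sum(a @ w.T for a, w in zip(acts, wo_shards))  # the sum is what the all-reduce computes
assert np.allclose(y, (x @ wqkv.T) @ wo.T)

The sketch omits the attention between the two matmuls; since attention is per-head, sharding the out axis keeps whole heads on one device and the pattern still holds.
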
@@ -203,25 +172,17 @@ class FlatTransformer:
     freqs_cis = self.freqs_cis.cast(h.dtype)[:, :tokens.shape[1], :, :, :]
     a = self._fp8_amax if FP8 else None
     for i in range(self.n_layers):
-      if WQKV:
-        attn_kwargs = {"wqkv": self.wqkv[i]}
-        amax_attn = {"amax_xqkv": a["xqkv"][i], "amax_wqkv": a["wqkv"][i]} if a else {}
-      else:
-        attn_kwargs = {"wq": self.wq[i], "wk": self.wk[i], "wv": self.wv[i]}
-        amax_attn = {"amax_xq": a["xq"][i], "amax_wq": a["wq"][i],
-                     "amax_xk": a["xk"][i], "amax_wk": a["wk"][i],
-                     "amax_xv": a["xv"][i], "amax_wv": a["wv"][i]} if a else {}
-      amax_layer = {"amax_xo": a["xo"][i], "amax_wo": a["wo"][i],
+      amax_layer = {"amax_xqkv": a["xqkv"][i], "amax_wqkv": a["wqkv"][i],
+                    "amax_xo": a["xo"][i], "amax_wo": a["wo"][i],
                     "amax_x1": a["x1"][i], "amax_w1": a["w1"][i],
                     "amax_x2": a["x2"][i], "amax_w2": a["w2"][i],
                     "amax_x3": a["x3"][i], "amax_w3": a["w3"][i]} if a else {}
       h, *amaxs = self.run_layer(h, freqs_cis,
-                                 self.attention_norm[i], self.wo[i],
+                                 self.attention_norm[i], self.wqkv[i], self.wo[i],
                                  self.ffn_norm[i], self.w1[i], self.w2[i], self.w3[i],
-                                 **attn_kwargs, **amax_attn, **amax_layer)
+                                 **amax_layer)
       if a:
-        if WQKV: amax_names = ["xqkv", "wqkv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
-        else: amax_names = ["xq", "wq", "xk", "wk", "xv", "wv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
+        amax_names = ["xqkv", "wqkv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
         for name, new_val in zip(amax_names, amaxs):
          a[name][i].assign(new_val)
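
Note: amax_names is positional; it must list names in exactly the order run_layer returns the fresh amaxs, i.e. attention's (xqkv, wqkv, xo, wo) followed by feed_forward's, where the name order implies the x1/w1 and x3/w3 matmuls run before x2/w2. A toy illustration of that coupling (the new_* strings stand in for the tensors run_layer actually returns):

attn_amaxs = ["new_xqkv", "new_wqkv", "new_xo", "new_wo"]                  # order attention() extends them
ffn_amaxs = ["new_x1", "new_w1", "new_x3", "new_w3", "new_x2", "new_w2"]   # w1/w3 before w2, per the list above
amax_names = ["xqkv", "wqkv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
for name, new_val in zip(amax_names, attn_amaxs + ffn_amaxs):
  print(f"a[{name!r}][i].assign({new_val})")  # mirrors the assignment in the loop above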