mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
llama: only support wqkv path + cleanups (#15680)
* llama: only support wqkv path + cleanups * llama: missing transpose
This commit is contained in:
@@ -18,7 +18,6 @@ from tinygrad.uop.ops import Ops, UOp
|
||||
from extra.models.llama import apply_rotary_emb, precompute_freqs_cis
|
||||
|
||||
FP8 = getenv("FP8", 0)
|
||||
WQKV = getenv("WQKV", 0)
|
||||
|
||||
FP8_DTYPE = dtypes.fp8e4m3
|
||||
FP8_GRAD_DTYPE = dtypes.fp8e5m2
|
||||
@@ -63,12 +62,7 @@ class FlatTransformer:
|
||||
scaled_std = 0.02 / math.sqrt(2 * n_layers)
|
||||
|
||||
# Attention
|
||||
if WQKV:
|
||||
self.wqkv = self.lin_per_layer(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2)
|
||||
else:
|
||||
self.wq = self.lin_per_layer(dim, self.n_heads * self.head_dim)
|
||||
self.wk = self.lin_per_layer(dim, self.n_kv_heads * self.head_dim)
|
||||
self.wv = self.lin_per_layer(dim, self.n_kv_heads * self.head_dim)
|
||||
self.wqkv = self.lin_per_layer(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2)
|
||||
self.wo = self.lin_per_layer(self.n_heads * self.head_dim, dim, std=scaled_std)
|
||||
|
||||
# FeedForward
|
||||
@@ -89,8 +83,7 @@ class FlatTransformer:
|
||||
|
||||
if FP8:
|
||||
def _amax(): return Tensor.full((), FP8_MAX).contiguous().requires_grad_(False)
|
||||
names = (["xqkv", "wqkv"] if WQKV else ["xq", "wq", "xk", "wk", "xv", "wv"]) + \
|
||||
["xo", "wo", "x1", "w1", "x2", "w2", "x3", "w3"]
|
||||
names = ["xqkv", "wqkv", "xo", "wo", "x1", "w1", "x2", "w2", "x3", "w3"]
|
||||
# _fp8_amax[name][layer_idx] = scalar amax tensor
|
||||
self._fp8_amax = {name: [_amax() for _ in range(n_layers)] for name in names}
|
||||
self._fp8_amax["xout"] = [_amax()]
|
||||
@@ -100,40 +93,26 @@ class FlatTransformer:
|
||||
if getenv("ZEROS"): return Tensor.zeros(self.n_layers, out_features, in_features)
|
||||
return Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
|
||||
|
||||
def attention(self, x:Tensor, freqs_cis:Tensor, attention_norm:Tensor, wo:Tensor, wqkv:Tensor|None=None,
|
||||
wq:Tensor|None=None, wk:Tensor|None=None, wv:Tensor|None=None,
|
||||
amax_xqkv=None, amax_wqkv=None, amax_xq=None, amax_wq=None, amax_xk=None, amax_wk=None,
|
||||
amax_xv=None, amax_wv=None, amax_xo=None, amax_wo=None):
|
||||
def attention(self, x:Tensor, freqs_cis:Tensor, attention_norm:Tensor, wqkv:Tensor, wo:Tensor,
|
||||
amax_xqkv=None, amax_wqkv=None, amax_xo=None, amax_wo=None):
|
||||
x = rmsnorm(x, self.norm_eps) * attention_norm
|
||||
|
||||
bsz, seqlen, _ = x.shape
|
||||
new_amaxs = []
|
||||
|
||||
if wqkv is not None:
|
||||
xqkv, *amaxs = matmul(x, wqkv, amax_x=amax_xqkv, amax_w=amax_wqkv)
|
||||
new_amaxs.extend(amaxs)
|
||||
xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim)
|
||||
xq = xqkv[:, :, :, :self.n_rep].reshape(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xk = xqkv[:, :, :, self.n_rep].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
|
||||
xv = xqkv[:, :, :, self.n_rep+1].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
|
||||
else:
|
||||
assert wq is not None and wk is not None and wv is not None
|
||||
xq, *amaxs = matmul(x, wq, amax_x=amax_xq, amax_w=amax_wq)
|
||||
new_amaxs.extend(amaxs)
|
||||
xq = xq.reshape(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xk, *amaxs = matmul(x, wk, amax_x=amax_xk, amax_w=amax_wk)
|
||||
new_amaxs.extend(amaxs)
|
||||
xk = xk.reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
|
||||
xv, *amaxs = matmul(x, wv, amax_x=amax_xv, amax_w=amax_wv)
|
||||
new_amaxs.extend(amaxs)
|
||||
xv = xv.reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
|
||||
xqkv, *amaxs = matmul(x, wqkv, amax_x=amax_xqkv, amax_w=amax_wqkv)
|
||||
new_amaxs.extend(amaxs)
|
||||
xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim)
|
||||
xq = xqkv[:, :, :, :self.n_rep].reshape(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xk = xqkv[:, :, :, self.n_rep].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
|
||||
xv = xqkv[:, :, :, self.n_rep+1].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
|
||||
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
|
||||
if FP8: xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
|
||||
xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
|
||||
if getenv("HK_FLASH_ATTENTION"):
|
||||
from extra.thunder.amd.fa import flash_attention
|
||||
attn = flash_attention(xq, xk, xv, is_causal=True)
|
||||
attn = flash_attention(xq, xk, xv, is_causal=True).transpose(1, 2)
|
||||
else:
|
||||
attn = xq.scaled_dot_product_attention(xk, xv, is_causal=True, enable_gqa=True).transpose(1, 2)
|
||||
attn = attn.reshape(bsz, seqlen, -1)
|
||||
@@ -158,20 +137,15 @@ class FlatTransformer:
|
||||
|
||||
@function(precompile=True, precompile_backward=True)
|
||||
def run_layer(self, x:Tensor, freqs_cis:Tensor,
|
||||
attention_norm:Tensor, wo:Tensor,
|
||||
attention_norm:Tensor, wqkv:Tensor, wo:Tensor,
|
||||
ffn_norm:Tensor, w1:Tensor, w2:Tensor, w3:Tensor,
|
||||
wqkv:Tensor|None=None, wq:Tensor|None=None, wk:Tensor|None=None, wv:Tensor|None=None,
|
||||
amax_xqkv=None, amax_wqkv=None, amax_xq=None, amax_wq=None, amax_xk=None, amax_wk=None,
|
||||
amax_xv=None, amax_wv=None, amax_xo=None, amax_wo=None,
|
||||
amax_xqkv=None, amax_wqkv=None, amax_xo=None, amax_wo=None,
|
||||
amax_x1=None, amax_w1=None, amax_x2=None, amax_w2=None, amax_x3=None, amax_w3=None):
|
||||
attn, *attn_amaxs = self.attention(x, freqs_cis, attention_norm, wo, wqkv=wqkv, wq=wq, wk=wk, wv=wv,
|
||||
amax_xqkv=amax_xqkv, amax_wqkv=amax_wqkv, amax_xq=amax_xq, amax_wq=amax_wq,
|
||||
amax_xk=amax_xk, amax_wk=amax_wk, amax_xv=amax_xv, amax_wv=amax_wv,
|
||||
amax_xo=amax_xo, amax_wo=amax_wo)
|
||||
attn, *attn_amaxs = self.attention(x, freqs_cis, attention_norm, wqkv, wo,
|
||||
amax_xqkv=amax_xqkv, amax_wqkv=amax_wqkv, amax_xo=amax_xo, amax_wo=amax_wo)
|
||||
h = x + attn
|
||||
ffn, *ffn_amaxs = self.feed_forward(h, ffn_norm, w1, w2, w3,
|
||||
amax_x1=amax_x1, amax_w1=amax_w1, amax_x2=amax_x2, amax_w2=amax_w2,
|
||||
amax_x3=amax_x3, amax_w3=amax_w3)
|
||||
amax_x1=amax_x1, amax_w1=amax_w1, amax_x2=amax_x2, amax_w2=amax_w2, amax_x3=amax_x3, amax_w3=amax_w3)
|
||||
h = h + ffn
|
||||
return (h, *attn_amaxs, *ffn_amaxs)
|
||||
|
||||
@@ -181,12 +155,7 @@ class FlatTransformer:
|
||||
for v in get_parameters(self): v.shard_(device, axis=None)
|
||||
else:
|
||||
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
|
||||
if WQKV:
|
||||
self.wqkv.shard_(device, axis=1).realize() # (n_layers, out, dim) shard out
|
||||
else:
|
||||
self.wq.shard_(device, axis=1).realize() # (n_layers, n_heads*head_dim, dim) shard out
|
||||
self.wk.shard_(device, axis=1).realize() # (n_layers, n_kv_heads*head_dim, dim) shard out
|
||||
self.wv.shard_(device, axis=1).realize() # (n_layers, n_kv_heads*head_dim, dim) shard out
|
||||
self.wqkv.shard_(device, axis=1).realize() # (n_layers, out, dim) shard out
|
||||
self.wo.shard_(device, axis=2).realize() # (n_layers, dim, in) shard in
|
||||
self.w1.shard_(device, axis=1).realize() # (n_layers, hidden, dim) shard out
|
||||
self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in
|
||||
@@ -203,25 +172,17 @@ class FlatTransformer:
|
||||
freqs_cis = self.freqs_cis.cast(h.dtype)[:, :tokens.shape[1], :, :, :]
|
||||
a = self._fp8_amax if FP8 else None
|
||||
for i in range(self.n_layers):
|
||||
if WQKV:
|
||||
attn_kwargs = {"wqkv": self.wqkv[i]}
|
||||
amax_attn = {"amax_xqkv": a["xqkv"][i], "amax_wqkv": a["wqkv"][i]} if a else {}
|
||||
else:
|
||||
attn_kwargs = {"wq": self.wq[i], "wk": self.wk[i], "wv": self.wv[i]}
|
||||
amax_attn = {"amax_xq": a["xq"][i], "amax_wq": a["wq"][i],
|
||||
"amax_xk": a["xk"][i], "amax_wk": a["wk"][i],
|
||||
"amax_xv": a["xv"][i], "amax_wv": a["wv"][i]} if a else {}
|
||||
amax_layer = {"amax_xo": a["xo"][i], "amax_wo": a["wo"][i],
|
||||
amax_layer = {"amax_xqkv": a["xqkv"][i], "amax_wqkv": a["wqkv"][i],
|
||||
"amax_xo": a["xo"][i], "amax_wo": a["wo"][i],
|
||||
"amax_x1": a["x1"][i], "amax_w1": a["w1"][i],
|
||||
"amax_x2": a["x2"][i], "amax_w2": a["w2"][i],
|
||||
"amax_x3": a["x3"][i], "amax_w3": a["w3"][i]} if a else {}
|
||||
h, *amaxs = self.run_layer(h, freqs_cis,
|
||||
self.attention_norm[i], self.wo[i],
|
||||
self.attention_norm[i], self.wqkv[i], self.wo[i],
|
||||
self.ffn_norm[i], self.w1[i], self.w2[i], self.w3[i],
|
||||
**attn_kwargs, **amax_attn, **amax_layer)
|
||||
**amax_layer)
|
||||
if a:
|
||||
if WQKV: amax_names = ["xqkv", "wqkv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
|
||||
else: amax_names = ["xq", "wq", "xk", "wk", "xv", "wv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
|
||||
amax_names = ["xqkv", "wqkv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
|
||||
for name, new_val in zip(amax_names, amaxs):
|
||||
a[name][i].assign(new_val)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user