diff --git a/examples/mlperf/models/flat_llama.py b/examples/mlperf/models/flat_llama.py
index 300d86b6a6..f164baa492 100644
--- a/examples/mlperf/models/flat_llama.py
+++ b/examples/mlperf/models/flat_llama.py
@@ -18,7 +18,6 @@ from tinygrad.uop.ops import Ops, UOp
 from extra.models.llama import apply_rotary_emb, precompute_freqs_cis
 
 FP8 = getenv("FP8", 0)
-WQKV = getenv("WQKV", 0)
 FP8_DTYPE = dtypes.fp8e4m3
 FP8_GRAD_DTYPE = dtypes.fp8e5m2
 
@@ -63,12 +62,7 @@ class FlatTransformer:
     scaled_std = 0.02 / math.sqrt(2 * n_layers)
 
     # Attention
-    if WQKV:
-      self.wqkv = self.lin_per_layer(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2)
-    else:
-      self.wq = self.lin_per_layer(dim, self.n_heads * self.head_dim)
-      self.wk = self.lin_per_layer(dim, self.n_kv_heads * self.head_dim)
-      self.wv = self.lin_per_layer(dim, self.n_kv_heads * self.head_dim)
+    self.wqkv = self.lin_per_layer(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2)
     self.wo = self.lin_per_layer(self.n_heads * self.head_dim, dim, std=scaled_std)
 
     # FeedForward
@@ -89,8 +83,7 @@ class FlatTransformer:
 
     if FP8:
       def _amax(): return Tensor.full((), FP8_MAX).contiguous().requires_grad_(False)
-      names = (["xqkv", "wqkv"] if WQKV else ["xq", "wq", "xk", "wk", "xv", "wv"]) + \
-              ["xo", "wo", "x1", "w1", "x2", "w2", "x3", "w3"]
+      names = ["xqkv", "wqkv", "xo", "wo", "x1", "w1", "x2", "w2", "x3", "w3"]
       # _fp8_amax[name][layer_idx] = scalar amax tensor
       self._fp8_amax = {name: [_amax() for _ in range(n_layers)] for name in names}
       self._fp8_amax["xout"] = [_amax()]
@@ -100,40 +93,26 @@ class FlatTransformer:
     if getenv("ZEROS"): return Tensor.zeros(self.n_layers, out_features, in_features)
     return Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
 
-  def attention(self, x:Tensor, freqs_cis:Tensor, attention_norm:Tensor, wo:Tensor, wqkv:Tensor|None=None,
-                wq:Tensor|None=None, wk:Tensor|None=None, wv:Tensor|None=None,
-                amax_xqkv=None, amax_wqkv=None, amax_xq=None, amax_wq=None, amax_xk=None, amax_wk=None,
-                amax_xv=None, amax_wv=None, amax_xo=None, amax_wo=None):
+  def attention(self, x:Tensor, freqs_cis:Tensor, attention_norm:Tensor, wqkv:Tensor, wo:Tensor,
+                amax_xqkv=None, amax_wqkv=None, amax_xo=None, amax_wo=None):
     x = rmsnorm(x, self.norm_eps) * attention_norm
     bsz, seqlen, _ = x.shape
     new_amaxs = []
-    if wqkv is not None:
-      xqkv, *amaxs = matmul(x, wqkv, amax_x=amax_xqkv, amax_w=amax_wqkv)
-      new_amaxs.extend(amaxs)
-      xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim)
-      xq = xqkv[:, :, :, :self.n_rep].reshape(bsz, seqlen, self.n_heads, self.head_dim)
-      xk = xqkv[:, :, :, self.n_rep].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
-      xv = xqkv[:, :, :, self.n_rep+1].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
-    else:
-      assert wq is not None and wk is not None and wv is not None
-      xq, *amaxs = matmul(x, wq, amax_x=amax_xq, amax_w=amax_wq)
-      new_amaxs.extend(amaxs)
-      xq = xq.reshape(bsz, seqlen, self.n_heads, self.head_dim)
-      xk, *amaxs = matmul(x, wk, amax_x=amax_xk, amax_w=amax_wk)
-      new_amaxs.extend(amaxs)
-      xk = xk.reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
-      xv, *amaxs = matmul(x, wv, amax_x=amax_xv, amax_w=amax_wv)
-      new_amaxs.extend(amaxs)
-      xv = xv.reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
+    xqkv, *amaxs = matmul(x, wqkv, amax_x=amax_xqkv, amax_w=amax_wqkv)
+    new_amaxs.extend(amaxs)
+    xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim)
+    xq = xqkv[:, :, :, :self.n_rep].reshape(bsz, seqlen, self.n_heads, self.head_dim)
+    xk = xqkv[:, :, :, self.n_rep].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
+    xv = xqkv[:, :, :, self.n_rep+1].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
     xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
    if FP8: xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
     xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
     if getenv("HK_FLASH_ATTENTION"):
       from extra.thunder.amd.fa import flash_attention
-      attn = flash_attention(xq, xk, xv, is_causal=True)
+      attn = flash_attention(xq, xk, xv, is_causal=True).transpose(1, 2)
     else:
       attn = xq.scaled_dot_product_attention(xk, xv, is_causal=True, enable_gqa=True).transpose(1, 2)
     attn = attn.reshape(bsz, seqlen, -1)
@@ -158,20 +137,15 @@ class FlatTransformer:
 
   @function(precompile=True, precompile_backward=True)
   def run_layer(self, x:Tensor, freqs_cis:Tensor,
-                attention_norm:Tensor, wo:Tensor,
+                attention_norm:Tensor, wqkv:Tensor, wo:Tensor,
                 ffn_norm:Tensor, w1:Tensor, w2:Tensor, w3:Tensor,
-                wqkv:Tensor|None=None, wq:Tensor|None=None, wk:Tensor|None=None, wv:Tensor|None=None,
-                amax_xqkv=None, amax_wqkv=None, amax_xq=None, amax_wq=None, amax_xk=None, amax_wk=None,
-                amax_xv=None, amax_wv=None, amax_xo=None, amax_wo=None,
+                amax_xqkv=None, amax_wqkv=None, amax_xo=None, amax_wo=None,
                 amax_x1=None, amax_w1=None, amax_x2=None, amax_w2=None, amax_x3=None, amax_w3=None):
-    attn, *attn_amaxs = self.attention(x, freqs_cis, attention_norm, wo, wqkv=wqkv, wq=wq, wk=wk, wv=wv,
-                                       amax_xqkv=amax_xqkv, amax_wqkv=amax_wqkv, amax_xq=amax_xq, amax_wq=amax_wq,
-                                       amax_xk=amax_xk, amax_wk=amax_wk, amax_xv=amax_xv, amax_wv=amax_wv,
-                                       amax_xo=amax_xo, amax_wo=amax_wo)
+    attn, *attn_amaxs = self.attention(x, freqs_cis, attention_norm, wqkv, wo,
+                                       amax_xqkv=amax_xqkv, amax_wqkv=amax_wqkv, amax_xo=amax_xo, amax_wo=amax_wo)
     h = x + attn
     ffn, *ffn_amaxs = self.feed_forward(h, ffn_norm, w1, w2, w3,
-                                        amax_x1=amax_x1, amax_w1=amax_w1, amax_x2=amax_x2, amax_w2=amax_w2,
-                                        amax_x3=amax_x3, amax_w3=amax_w3)
+                                        amax_x1=amax_x1, amax_w1=amax_w1, amax_x2=amax_x2, amax_w2=amax_w2, amax_x3=amax_x3, amax_w3=amax_w3)
     h = h + ffn
     return (h, *attn_amaxs, *ffn_amaxs)
 
@@ -181,12 +155,7 @@ class FlatTransformer:
       for v in get_parameters(self): v.shard_(device, axis=None)
     else:
       # flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
-      if WQKV:
-        self.wqkv.shard_(device, axis=1).realize() # (n_layers, out, dim) shard out
-      else:
-        self.wq.shard_(device, axis=1).realize() # (n_layers, n_heads*head_dim, dim) shard out
-        self.wk.shard_(device, axis=1).realize() # (n_layers, n_kv_heads*head_dim, dim) shard out
-        self.wv.shard_(device, axis=1).realize() # (n_layers, n_kv_heads*head_dim, dim) shard out
+      self.wqkv.shard_(device, axis=1).realize() # (n_layers, out, dim) shard out
       self.wo.shard_(device, axis=2).realize() # (n_layers, dim, in) shard in
       self.w1.shard_(device, axis=1).realize() # (n_layers, hidden, dim) shard out
       self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in
@@ -203,25 +172,17 @@ class FlatTransformer:
     freqs_cis = self.freqs_cis.cast(h.dtype)[:, :tokens.shape[1], :, :, :]
     a = self._fp8_amax if FP8 else None
     for i in range(self.n_layers):
-      if WQKV:
-        attn_kwargs = {"wqkv": self.wqkv[i]}
-        amax_attn = {"amax_xqkv": a["xqkv"][i], "amax_wqkv": a["wqkv"][i]} if a else {}
-      else:
-        attn_kwargs = {"wq": self.wq[i], "wk": self.wk[i], "wv": self.wv[i]}
-        amax_attn = {"amax_xq": a["xq"][i], "amax_wq": a["wq"][i],
-                     "amax_xk": a["xk"][i], "amax_wk": a["wk"][i],
-                     "amax_xv": a["xv"][i], "amax_wv": a["wv"][i]} if a else {}
-      amax_layer = {"amax_xo": a["xo"][i], "amax_wo": a["wo"][i],
+      amax_layer = {"amax_xqkv": a["xqkv"][i], "amax_wqkv": a["wqkv"][i],
+                    "amax_xo": a["xo"][i], "amax_wo": a["wo"][i],
                     "amax_x1": a["x1"][i], "amax_w1": a["w1"][i],
                     "amax_x2": a["x2"][i], "amax_w2": a["w2"][i],
                     "amax_x3": a["x3"][i], "amax_w3": a["w3"][i]} if a else {}
       h, *amaxs = self.run_layer(h, freqs_cis,
-                                 self.attention_norm[i], self.wo[i],
+                                 self.attention_norm[i], self.wqkv[i], self.wo[i],
                                  self.ffn_norm[i], self.w1[i], self.w2[i], self.w3[i],
-                                 **attn_kwargs, **amax_attn, **amax_layer)
+                                 **amax_layer)
       if a:
-        if WQKV: amax_names = ["xqkv", "wqkv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
-        else: amax_names = ["xq", "wq", "xk", "wk", "xv", "wv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
+        amax_names = ["xqkv", "wqkv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
         for name, new_val in zip(amax_names, amaxs): a[name][i].assign(new_val)