diff --git a/examples/mlperf/models/flat_llama.py b/examples/mlperf/models/flat_llama.py
index 6373888088..300d86b6a6 100644
--- a/examples/mlperf/models/flat_llama.py
+++ b/examples/mlperf/models/flat_llama.py
@@ -25,28 +25,25 @@ FP8_GRAD_DTYPE = dtypes.fp8e5m2
 FP8_MAX = 448.0
 
 def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
-  if amax_state is not None:
-    scale = FP8_MAX / (amax_state + 1e-8)
-    amax_state.assign(x.abs().max().detach())
-  else:
-    scale = FP8_MAX / (x.abs().max().detach() + 1e-8)
+  new_amax = x.abs().max().detach()
+  scale = FP8_MAX / ((amax_state if amax_state is not None else new_amax) + 1e-8)
   x_scaled = x * scale
   x_clamped = x_scaled + (x_scaled.detach().clamp(-FP8_MAX, FP8_MAX) - x_scaled.detach()) # STE
-  return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal()
+  return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal(), new_amax
 
-def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, amax_w:Tensor|None=None) -> Tensor:
+def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, amax_w:Tensor|None=None) -> tuple[Tensor,...]:
   if not fp8:
     if getenv("ASM_GEMM"):
       from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
-      if can_use_asm_gemm(x, w.T): return asm_gemm(x, w.T)
-    return x @ w.T
-  x_fp8, x_scale = quantize_fp8(x, amax_state=amax_x)
-  w_fp8, w_scale = quantize_fp8(w, amax_state=amax_w)
+      if can_use_asm_gemm(x, w.T): return (asm_gemm(x, w.T),)
+    return (x @ w.T,)
+  x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
+  w_fp8, w_scale, w_new_amax = quantize_fp8(w, amax_state=amax_w)
   combined_scale = x_scale * w_scale
   if getenv("ASM_GEMM"):
     from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
-    if can_use_asm_gemm(x_fp8, w_fp8.T): return asm_gemm(x_fp8, w_fp8.T, combined_scale=combined_scale)
-  return x_fp8.dot(w_fp8.T, dtype=dtypes.float) * combined_scale
+    if can_use_asm_gemm(x_fp8, w_fp8.T): return asm_gemm(x_fp8, w_fp8.T, combined_scale=combined_scale), x_new_amax, w_new_amax
+  return x_fp8.dot(w_fp8.T, dtype=dtypes.float) * combined_scale, x_new_amax, w_new_amax
 
 def rmsnorm(x_in:Tensor, eps:float):
   x = x_in.float()
@@ -108,19 +105,28 @@ class FlatTransformer:
                 amax_xqkv=None, amax_wqkv=None, amax_xq=None, amax_wq=None,
                 amax_xk=None, amax_wk=None, amax_xv=None, amax_wv=None, amax_xo=None, amax_wo=None):
     x = rmsnorm(x, self.norm_eps) * attention_norm
+    bsz, seqlen, _ = x.shape
+    new_amaxs = []
     if wqkv is not None:
-      xqkv = matmul(x, wqkv, amax_x=amax_xqkv, amax_w=amax_wqkv)
+      xqkv, *amaxs = matmul(x, wqkv, amax_x=amax_xqkv, amax_w=amax_wqkv)
+      new_amaxs.extend(amaxs)
       xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim)
       xq = xqkv[:, :, :, :self.n_rep].reshape(bsz, seqlen, self.n_heads, self.head_dim)
       xk = xqkv[:, :, :, self.n_rep].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
       xv = xqkv[:, :, :, self.n_rep+1].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
     else:
       assert wq is not None and wk is not None and wv is not None
-      xq = matmul(x, wq, amax_x=amax_xq, amax_w=amax_wq).reshape(bsz, seqlen, self.n_heads, self.head_dim)
-      xk = matmul(x, wk, amax_x=amax_xk, amax_w=amax_wk).reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
-      xv = matmul(x, wv, amax_x=amax_xv, amax_w=amax_wv).reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
+      xq, *amaxs = matmul(x, wq, amax_x=amax_xq, amax_w=amax_wq)
+      new_amaxs.extend(amaxs)
+      xq = xq.reshape(bsz, seqlen, self.n_heads, self.head_dim)
+      xk, *amaxs = matmul(x, wk, amax_x=amax_xk, amax_w=amax_wk)
+      new_amaxs.extend(amaxs)
+      xk = xk.reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
+      xv, *amaxs = matmul(x, wv, amax_x=amax_xv, amax_w=amax_wv)
+      new_amaxs.extend(amaxs)
+      xv = xv.reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
     xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
     if FP8: xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
@@ -131,14 +137,24 @@ class FlatTransformer:
     else:
       attn = xq.scaled_dot_product_attention(xk, xv, is_causal=True, enable_gqa=True).transpose(1, 2)
     attn = attn.reshape(bsz, seqlen, -1)
-    return matmul(attn, wo, amax_x=amax_xo, amax_w=amax_wo)
+
+    out, *amaxs = matmul(attn, wo, amax_x=amax_xo, amax_w=amax_wo)
+    new_amaxs.extend(amaxs)
+    return (out, *new_amaxs)
 
   def feed_forward(self, x:Tensor, ffn_norm:Tensor, w1:Tensor, w2:Tensor, w3:Tensor,
                    amax_x1=None, amax_w1=None, amax_x2=None, amax_w2=None, amax_x3=None, amax_w3=None):
     x = rmsnorm(x, self.norm_eps) * ffn_norm
-    x_w1 = matmul(x, w1, amax_x=amax_x1, amax_w=amax_w1).silu()
-    x_w3 = matmul(x.contiguous_backward(), w3, amax_x=amax_x3, amax_w=amax_w3)
-    return matmul(x_w1 * x_w3, w2, amax_x=amax_x2, amax_w=amax_w2)
+
+    new_amaxs = []
+
+    x_w1, *amaxs = matmul(x, w1, amax_x=amax_x1, amax_w=amax_w1)
+    new_amaxs.extend(amaxs)
+    x_w3, *amaxs = matmul(x.contiguous_backward(), w3, amax_x=amax_x3, amax_w=amax_w3)
+    new_amaxs.extend(amaxs)
+    out, *amaxs = matmul(x_w1.silu() * x_w3, w2, amax_x=amax_x2, amax_w=amax_w2)
+    new_amaxs.extend(amaxs)
+    return (out, *new_amaxs)
 
   @function(precompile=True, precompile_backward=True)
   def run_layer(self, x:Tensor, freqs_cis:Tensor,
@@ -148,13 +164,16 @@ class FlatTransformer:
                 amax_xqkv=None, amax_wqkv=None, amax_xq=None, amax_wq=None,
                 amax_xk=None, amax_wk=None, amax_xv=None, amax_wv=None, amax_xo=None, amax_wo=None,
                 amax_x1=None, amax_w1=None, amax_x2=None, amax_w2=None, amax_x3=None, amax_w3=None):
-    h = x + self.attention(x, freqs_cis, attention_norm, wo, wqkv=wqkv, wq=wq, wk=wk, wv=wv,
-                           amax_xqkv=amax_xqkv, amax_wqkv=amax_wqkv, amax_xq=amax_xq, amax_wq=amax_wq,
-                           amax_xk=amax_xk, amax_wk=amax_wk, amax_xv=amax_xv, amax_wv=amax_wv,
-                           amax_xo=amax_xo, amax_wo=amax_wo)
-    return h + self.feed_forward(h, ffn_norm, w1, w2, w3,
-                                 amax_x1=amax_x1, amax_w1=amax_w1, amax_x2=amax_x2, amax_w2=amax_w2,
-                                 amax_x3=amax_x3, amax_w3=amax_w3)
+    attn, *attn_amaxs = self.attention(x, freqs_cis, attention_norm, wo, wqkv=wqkv, wq=wq, wk=wk, wv=wv,
+                                       amax_xqkv=amax_xqkv, amax_wqkv=amax_wqkv, amax_xq=amax_xq, amax_wq=amax_wq,
+                                       amax_xk=amax_xk, amax_wk=amax_wk, amax_xv=amax_xv, amax_wv=amax_wv,
+                                       amax_xo=amax_xo, amax_wo=amax_wo)
+    h = x + attn
+    ffn, *ffn_amaxs = self.feed_forward(h, ffn_norm, w1, w2, w3,
+                                        amax_x1=amax_x1, amax_w1=amax_w1, amax_x2=amax_x2, amax_w2=amax_w2,
+                                        amax_x3=amax_x3, amax_w3=amax_w3)
+    h = h + ffn
+    return (h, *attn_amaxs, *ffn_amaxs)
 
   def shard(self, device:tuple[str, ...], mp:bool=False):
     from tinygrad.nn.state import get_parameters
@@ -196,11 +215,17 @@ class FlatTransformer:
                     "amax_x1": a["x1"][i], "amax_w1": a["w1"][i],
                     "amax_x2": a["x2"][i], "amax_w2": a["w2"][i],
                     "amax_x3": a["x3"][i], "amax_w3": a["w3"][i]} if a else {}
-      h = self.run_layer(h, freqs_cis,
-                         self.attention_norm[i], self.wo[i],
-                         self.ffn_norm[i], self.w1[i], self.w2[i], self.w3[i],
-                         **attn_kwargs, **amax_attn, **amax_layer)
-    logits = matmul(self.norm(h).contiguous().contiguous_backward(), self.output[0], fp8=False).contiguous_backward()
+      h, *amaxs = self.run_layer(h, freqs_cis,
+                                 self.attention_norm[i], self.wo[i],
+                                 self.ffn_norm[i], self.w1[i], self.w2[i], self.w3[i],
+                                 **attn_kwargs, **amax_attn, **amax_layer)
+      if a:
+        if WQKV: amax_names = ["xqkv", "wqkv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
+        else: amax_names = ["xq", "wq", "xk", "wk", "xv", "wv", "xo", "wo", "x1", "w1", "x3", "w3", "x2", "w2"]
+        for name, new_val in zip(amax_names, amaxs):
+          a[name][i].assign(new_val)
+
+    logits = matmul(self.norm(h).contiguous().contiguous_backward(), self.output[0], fp8=False)[0].contiguous_backward()
     return logits
 
 def _get_pads(uop:UOp) -> list[UOp]:
diff --git a/extra/gemm/cdna_asm_gemm.py b/extra/gemm/cdna_asm_gemm.py
index cbeb8995df..810d261c37 100644
--- a/extra/gemm/cdna_asm_gemm.py
+++ b/extra/gemm/cdna_asm_gemm.py
@@ -2704,7 +2704,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp):
   a_t, b_t, g_t, s_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device), Tensor(scale, device=a.device)
   g_t = g_t[:a.shape[0]]
   # backward GEMMs in fp8 with scale applied inside kernel to prevent bf16 overflow
-  g_fp8, g_scale = quantize_fp8(g_t)
+  g_fp8, g_scale, _ = quantize_fp8(g_t)
   bw_scale = g_scale * s_t
   # dgrad: g_fp8 @ weight (asm_gemm computes a@b)
   grad_a = asm_gemm(g_fp8, b_t, combined_scale=bw_scale)
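
A minimal sketch (NumPy, illustrative names only, not the tinygrad API) of the delayed-scaling flow this patch moves to: quantize_fp8 scales by the amax recorded on the previous step when one is supplied, falls back to the current amax otherwise, and hands the freshly observed amax back to the caller; matmul forwards both new amaxes so the training loop can write them into its persistent state. fp8 is emulated here with a clipped float.

import numpy as np

FP8_MAX = 448.0  # e4m3 max, same constant the patch uses

def quantize_fp8(x, amax_state=None):
  new_amax = np.abs(x).max()                       # amax observed this step
  amax = amax_state if amax_state is not None else new_amax
  scale = FP8_MAX / (amax + 1e-8)
  x_q = np.clip(x * scale, -FP8_MAX, FP8_MAX)      # stand-in for the fp8 cast
  return x_q, 1.0 / scale, new_amax                # value, dequant scale, new amax

def matmul(x, w, amax_x=None, amax_w=None):
  x_q, x_s, x_new = quantize_fp8(x, amax_x)
  w_q, w_s, w_new = quantize_fp8(w, amax_w)
  return x_q @ w_q.T * (x_s * w_s), x_new, w_new   # output plus both new amaxes

# the caller owns the amax state and writes the returned values back each step,
# mirroring the a[name][i].assign(new_val) loop added in run()
amax = {"x": None, "w": None}
x, w = np.random.randn(4, 8), np.random.randn(16, 8)
for _ in range(3):
  out, amax["x"], amax["w"] = matmul(x, w, amax_x=amax["x"], amax_w=amax["w"])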
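
The STE line in quantize_fp8 is unchanged by the patch but is what keeps the clamp trainable: the forward value is the clamped tensor, while the detached difference contributes nothing to the gradient. A small runnable check of that identity; the limit and toy values are made up, the real code clamps at FP8_MAX.

from tinygrad import Tensor

LIMIT = 3.0  # illustrative threshold
x = Tensor([1.0, 5.0, -7.0], requires_grad=True)
# same trick as the STE line in quantize_fp8: forward equals the clamped value,
# backward sees only the identity term because the clamp lives in a detached branch
y = x + (x.detach().clamp(-LIMIT, LIMIT) - x.detach())
y.sum().backward()
print(y.numpy())       # [ 1.  3. -3.] -> clamped forward
print(x.grad.numpy())  # [ 1.  1.  1.] -> gradient as if no clamp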