diff --git a/examples/mlperf/models/flat_llama.py b/examples/mlperf/models/flat_llama.py
index e5c235e45e..b293c31203 100644
--- a/examples/mlperf/models/flat_llama.py
+++ b/examples/mlperf/models/flat_llama.py
@@ -2,7 +2,7 @@ import math, os
 if __name__ == "__main__":
   os.environ["DEFAULT_FLOAT"] = "bfloat16"
   os.environ["OPTIM_DTYPE"] = "bfloat16"
-  os.environ["DEV"] = "NULL"
+  if "DEV" not in os.environ: os.environ["DEV"] = "NULL"
   # CDNA
   os.environ["EMULATE"] = "AMD_CDNA4"
   os.environ["DEVICE_IN_FUNCTION_BUG"] = "1"
@@ -13,10 +13,27 @@ if __name__ == "__main__":
   if "ASM_GEMM" not in os.environ: os.environ["ASM_GEMM"] = "1"
 
 from tinygrad import Tensor, nn, function, getenv, dtypes, TinyJit
-from tinygrad.helpers import Timing, colored, GlobalCounters
+from tinygrad.helpers import Timing, colored, GlobalCounters, profile_marker
 from tinygrad.uop.ops import Ops, UOp
 from extra.models.llama import apply_rotary_emb, precompute_freqs_cis
 
+FP8 = getenv("FP8", 0)
+
+FP8_DTYPE = dtypes.fp8e4m3
+FP8_MAX = 448.0
+
+def quantize_fp8(x:Tensor):
+  scale = FP8_MAX / (x.abs().max().detach() + 1e-8)
+  x_scaled = x * scale
+  x_clamped = x_scaled + (x_scaled.detach().clamp(-FP8_MAX, FP8_MAX) - x_scaled.detach())  # STE
+  return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal()
+
+def matmul(x:Tensor, w:Tensor) -> Tensor:
+  if not FP8: return x @ w.T
+  # weights are already FP8, just quantize activations
+  x_fp8, x_scale = quantize_fp8(x)
+  return x_fp8.dot(w.T, dtype=dtypes.float) * x_scale
+
 def rmsnorm(x_in:Tensor, eps:float):
   x = x_in.float()
   x = x * (x.square().mean(-1, keepdim=True) + eps).rsqrt()
@@ -53,12 +70,13 @@ class FlatTransformer:
 
   def lin_per_layer(self, in_features:int, out_features:int):
     bound = 1 / math.sqrt(in_features)
-    if getenv("ZEROS"): return Tensor.zeros(self.n_layers, out_features, in_features)
-    return Tensor.uniform(self.n_layers, out_features, in_features, low=-bound, high=bound)
+    dt = FP8_DTYPE if FP8 else None
+    if getenv("ZEROS"): return Tensor.zeros(self.n_layers, out_features, in_features, dtype=dt)
+    return Tensor.uniform(self.n_layers, out_features, in_features, low=-bound, high=bound, dtype=dt)
 
   def attention(self, x:Tensor, freqs_cis:Tensor, attention_norm:Tensor, wqkv:Tensor, wo:Tensor):
     x = rmsnorm(x, self.norm_eps) * attention_norm
-    xqkv = x @ wqkv.T
+    xqkv = matmul(x, wqkv)
     bsz, seqlen, _ = xqkv.shape
 
     # interleaved layout: each kv group has [n_rep q heads, 1 k head, 1 v head] for clean MP sharding
@@ -71,13 +89,13 @@ class FlatTransformer:
     xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
     attn = xq.scaled_dot_product_attention(xk, xv, is_causal=True, enable_gqa=True).transpose(1, 2)
     attn = attn.reshape(bsz, seqlen, -1)
-    return attn @ wo.T
+    return matmul(attn, wo)
 
   def feed_forward(self, x:Tensor, ffn_norm:Tensor, w1:Tensor, w2:Tensor, w3:Tensor):
     x = rmsnorm(x, self.norm_eps) * ffn_norm
-    x_w1 = (x @ w1.T).silu()
-    x_w3 = x.contiguous_backward() @ w3.T
-    return (x_w1 * x_w3) @ w2.T
+    x_w1 = matmul(x, w1).silu()
+    x_w3 = matmul(x.contiguous_backward(), w3)
+    return matmul(x_w1 * x_w3, w2)
 
   @function(precompile=True, precompile_backward=True)
   def run_layer(self, x:Tensor, freqs_cis:Tensor,
@@ -139,7 +157,7 @@ if __name__ == "__main__":
   # print model size
   sz = 0
   for k,v in state.items():
-    print(f"{colored(k, 'green' if v in grads else 'white'):30s} {str(v.shape):30s} {v.dtype} {v.device} {v.nbytes()/1e9:.2f} GB")
+    print(f"{colored(k, 'green' if v in grads else 'white'):30s} {str(v.shape):30s} {str(v.dtype):20s} {v.device} {v.nbytes()/1e9:.2f} GB")
     sz += v.nbytes()
   print(f"total sz: {sz/1e9:.2f} GB")
@@ -161,15 +179,15 @@ if __name__ == "__main__":
   @TinyJit
   def jit_step(tokens:Tensor):
-    GlobalCounters.reset()
-    print(colored("*** step", "red"))
     with Timing("python forward: "):
       loss = model(tokens[:, :-1]).sparse_categorical_crossentropy(tokens[:, 1:])
     with Timing("python backward: "):
       for t,g in zip(grads, loss.gradient(*grads)): grads[t] = Tensor(grads[t].uop.after(UOp.group(*apply_grad(grads[t].uop, g.uop))), device=t.device)
     with Timing("run step: "):
       loss.realize(*grads.values())
 
-  jit_step(tokens)
-  jit_step(tokens)
-  jit_step(tokens)
+  for i in range(6):
+    GlobalCounters.reset()
+    profile_marker(f"step {i}")
+    with Timing(colored(f"*** step {i}: ", "red")):
+      jit_step(tokens)
   print("mem per device: " + ', '.join(f"{dev}: {mem/1e9:.2f} GB" for dev, mem in sorted(GlobalCounters.mem_used_per_device.items())))
diff --git a/examples/mlperf/models/test_flat_llama.py b/examples/mlperf/models/test_flat_llama.py
index fe99de4ac7..d87a94edd1 100644
--- a/examples/mlperf/models/test_flat_llama.py
+++ b/examples/mlperf/models/test_flat_llama.py
@@ -2,8 +2,9 @@ import os
 os.environ["WQKV"] = "1"
 import unittest
 import numpy as np
-from tinygrad import Tensor, nn
+from tinygrad import Tensor, nn, dtypes
 from tinygrad.nn.state import get_parameters
+from tinygrad.device import is_dtype_supported
 
 from examples.mlperf.models.llama import Transformer
 from examples.mlperf.models.flat_llama import FlatTransformer
@@ -113,5 +114,27 @@ class TestFlatLlama(unittest.TestCase):
     self.assertEqual(ref_logits.shape, flat_logits.shape)
     np.testing.assert_allclose(flat_logits, ref_logits, atol=1e-4, rtol=1e-4)
 
+  @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), "fp8 not supported on this device")
+  def test_forward_fp8(self):
+    import examples.mlperf.models.flat_llama as flat_llama_mod
+    old_fp8 = flat_llama_mod.FP8
+    try:
+      flat_llama_mod.FP8 = 1
+      Tensor.manual_seed(42)
+      params = dict(dim=128, hidden_dim=256, n_heads=4, n_kv_heads=2, n_layers=2, norm_eps=1e-5, vocab_size=1024, rope_theta=10000, max_context=64)
+      ref = Transformer(**params)
+      flat = FlatTransformer(**params)
+      copy_weights(flat, ref)
+      Tensor.realize(*nn.state.get_state_dict(flat).values())
+
+      tokens = Tensor([[1, 50, 100, 999, 2]])
+      ref_logits = ref(tokens).numpy()
+      flat_logits = flat(tokens).numpy()
+      self.assertEqual(ref_logits.shape, flat_logits.shape)
+      # FP8 has lower precision, allow larger tolerance
+      np.testing.assert_allclose(flat_logits, ref_logits, atol=1.0, rtol=0.1)
+    finally:
+      flat_llama_mod.FP8 = old_fp8
+
 if __name__ == "__main__":
   unittest.main()
diff --git a/extra/gemm/cdna_asm_gemm.py b/extra/gemm/cdna_asm_gemm.py
index 78947853f7..6b9f373154 100644
--- a/extra/gemm/cdna_asm_gemm.py
+++ b/extra/gemm/cdna_asm_gemm.py
@@ -2649,6 +2649,9 @@ def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
   else: dname = a.device
   arch = getattr(Device[dname].renderer, "arch", "")
   if batch not in {1, 2}: return todo(f"GEMM batch size {batch}")
+  # blacklist slow matmul
+  # TODO: why is this slow?
+  if (M,N,K) == (8192, 2304, 16384): return todo("blacklisted slow matmul")
   if (M % TILE_M != 0 or N % TILE_N != 0 or K % TILE_K != 0) and arch == "gfx950": return todo(f"GEMM shape ({M},{N},{K}) not a multiple of ({TILE_M},{TILE_N},{TILE_K})")
   return True