flat llama step work (#15355)

* flat llama step work

* fp8 support

* blacklisted matmul

* chestertons fence
This commit is contained in:
George Hotz
2026-03-20 09:06:12 +08:00
committed by GitHub
parent 176ad47d7d
commit 4091d37e8e
3 changed files with 60 additions and 16 deletions

View File

@@ -2,7 +2,7 @@ import math, os
if __name__ == "__main__":
os.environ["DEFAULT_FLOAT"] = "bfloat16"
os.environ["OPTIM_DTYPE"] = "bfloat16"
os.environ["DEV"] = "NULL"
if "DEV" not in os.environ: os.environ["DEV"] = "NULL"
# CDNA
os.environ["EMULATE"] = "AMD_CDNA4"
os.environ["DEVICE_IN_FUNCTION_BUG"] = "1"
@@ -13,10 +13,27 @@ if __name__ == "__main__":
if "ASM_GEMM" not in os.environ:
os.environ["ASM_GEMM"] = "1"
from tinygrad import Tensor, nn, function, getenv, dtypes, TinyJit
from tinygrad.helpers import Timing, colored, GlobalCounters
from tinygrad.helpers import Timing, colored, GlobalCounters, profile_marker
from tinygrad.uop.ops import Ops, UOp
from extra.models.llama import apply_rotary_emb, precompute_freqs_cis
FP8 = getenv("FP8", 0)
FP8_DTYPE = dtypes.fp8e4m3
FP8_MAX = 448.0
def quantize_fp8(x:Tensor):
  """Dynamically quantize x to fp8e4m3; returns (fp8 tensor, dequant scale).

  The per-tensor scale maps the largest |x| to FP8_MAX; detach keeps the
  scale computation out of the autodiff graph.
  """
  scale = FP8_MAX / (x.abs().max().detach() + 1e-8)
  scaled = x * scale
  # straight-through estimator: clamp applies on the forward pass only,
  # while the backward pass sees an identity through `scaled`
  clamped_fwd = scaled.detach().clamp(-FP8_MAX, FP8_MAX)
  ste = scaled + (clamped_fwd - scaled.detach())
  # hand back the reciprocal so callers multiply (not divide) to dequantize
  return ste.cast(FP8_DTYPE), scale.float().reciprocal()
def matmul(x:Tensor, w:Tensor) -> Tensor:
  """Linear-layer matmul x @ w.T; with FP8 enabled, quantize activations first."""
  if FP8:
    # weights are already FP8, just quantize activations
    x_fp8, inv_scale = quantize_fp8(x)
    # accumulate in fp32, then undo the activation scale
    return x_fp8.dot(w.T, dtype=dtypes.float) * inv_scale
  return x @ w.T
def rmsnorm(x_in:Tensor, eps:float):
x = x_in.float()
x = x * (x.square().mean(-1, keepdim=True) + eps).rsqrt()
@@ -53,12 +70,13 @@ class FlatTransformer:
def lin_per_layer(self, in_features:int, out_features:int):
bound = 1 / math.sqrt(in_features)
if getenv("ZEROS"): return Tensor.zeros(self.n_layers, out_features, in_features)
return Tensor.uniform(self.n_layers, out_features, in_features, low=-bound, high=bound)
dt = FP8_DTYPE if FP8 else None
if getenv("ZEROS"): return Tensor.zeros(self.n_layers, out_features, in_features, dtype=dt)
return Tensor.uniform(self.n_layers, out_features, in_features, low=-bound, high=bound, dtype=dt)
def attention(self, x:Tensor, freqs_cis:Tensor, attention_norm:Tensor, wqkv:Tensor, wo:Tensor):
x = rmsnorm(x, self.norm_eps) * attention_norm
xqkv = x @ wqkv.T
xqkv = matmul(x, wqkv)
bsz, seqlen, _ = xqkv.shape
# interleaved layout: each kv group has [n_rep q heads, 1 k head, 1 v head] for clean MP sharding
@@ -71,13 +89,13 @@ class FlatTransformer:
xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
attn = xq.scaled_dot_product_attention(xk, xv, is_causal=True, enable_gqa=True).transpose(1, 2)
attn = attn.reshape(bsz, seqlen, -1)
return attn @ wo.T
return matmul(attn, wo)
def feed_forward(self, x:Tensor, ffn_norm:Tensor, w1:Tensor, w2:Tensor, w3:Tensor):
x = rmsnorm(x, self.norm_eps) * ffn_norm
x_w1 = (x @ w1.T).silu()
x_w3 = x.contiguous_backward() @ w3.T
return (x_w1 * x_w3) @ w2.T
x_w1 = matmul(x, w1).silu()
x_w3 = matmul(x.contiguous_backward(), w3)
return matmul(x_w1 * x_w3, w2)
@function(precompile=True, precompile_backward=True)
def run_layer(self, x:Tensor, freqs_cis:Tensor,
@@ -139,7 +157,7 @@ if __name__ == "__main__":
# print model size
sz = 0
for k,v in state.items():
print(f"{colored(k, 'green' if v in grads else 'white'):30s} {str(v.shape):30s} {v.dtype} {v.device} {v.nbytes()/1e9:.2f} GB")
print(f"{colored(k, 'green' if v in grads else 'white'):30s} {str(v.shape):30s} {str(v.dtype):20s} {v.device} {v.nbytes()/1e9:.2f} GB")
sz += v.nbytes()
print(f"total sz: {sz/1e9:.2f} GB")
@@ -161,15 +179,15 @@ if __name__ == "__main__":
@TinyJit
def jit_step(tokens:Tensor):
  # One training step, captured by TinyJit so later calls replay the compiled graph.
  GlobalCounters.reset()
  print(colored("*** step", "red"))
  # forward: next-token cross-entropy — predict tokens[1:] from tokens[:-1]
  with Timing("python forward: "): loss = model(tokens[:, :-1]).sparse_categorical_crossentropy(tokens[:, 1:])
  with Timing("python backward: "):
    # splice each parameter's gradient UOps into its update graph;
    # NOTE(review): apply_grad is defined elsewhere — presumably the optimizer update, confirm
    for t,g in zip(grads, loss.gradient(*grads)):
      grads[t] = Tensor(grads[t].uop.after(UOp.group(*apply_grad(grads[t].uop, g.uop))), device=t.device)
  # realize loss and every updated grad tensor in one scheduled run
  with Timing("run step: "): loss.realize(*grads.values())
jit_step(tokens)
jit_step(tokens)
jit_step(tokens)
for i in range(6):
GlobalCounters.reset()
profile_marker(f"step {i}")
with Timing(colored(f"*** step {i}: ", "red")):
jit_step(tokens)
print("mem per device: " + ', '.join(f"{dev}: {mem/1e9:.2f} GB" for dev, mem in sorted(GlobalCounters.mem_used_per_device.items())))

View File

@@ -2,8 +2,9 @@ import os
os.environ["WQKV"] = "1"
import unittest
import numpy as np
from tinygrad import Tensor, nn
from tinygrad import Tensor, nn, dtypes
from tinygrad.nn.state import get_parameters
from tinygrad.device import is_dtype_supported
from examples.mlperf.models.llama import Transformer
from examples.mlperf.models.flat_llama import FlatTransformer
@@ -113,5 +114,27 @@ class TestFlatLlama(unittest.TestCase):
self.assertEqual(ref_logits.shape, flat_logits.shape)
np.testing.assert_allclose(flat_logits, ref_logits, atol=1e-4, rtol=1e-4)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), "fp8 not supported on this device")
def test_forward_fp8(self):
  """FP8 forward pass of FlatTransformer matches the reference Transformer within loose tolerances."""
  import examples.mlperf.models.flat_llama as flat_llama_mod
  # FP8 is a module-level flag read by flat_llama's matmul; patch it for the
  # duration of this test and restore it afterwards
  saved_fp8 = flat_llama_mod.FP8
  flat_llama_mod.FP8 = 1
  try:
    Tensor.manual_seed(42)
    params = dict(dim=128, hidden_dim=256, n_heads=4, n_kv_heads=2, n_layers=2, norm_eps=1e-5, vocab_size=1024, rope_theta=10000, max_context=64)
    reference, flat = Transformer(**params), FlatTransformer(**params)
    copy_weights(flat, reference)
    Tensor.realize(*nn.state.get_state_dict(flat).values())
    tokens = Tensor([[1, 50, 100, 999, 2]])
    expected, actual = reference(tokens).numpy(), flat(tokens).numpy()
    self.assertEqual(expected.shape, actual.shape)
    # FP8 has lower precision, allow larger tolerance
    np.testing.assert_allclose(actual, expected, atol=1.0, rtol=0.1)
  finally:
    flat_llama_mod.FP8 = saved_fp8
if __name__ == "__main__":
unittest.main()

View File

@@ -2649,6 +2649,9 @@ def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
else: dname = a.device
arch = getattr(Device[dname].renderer, "arch", "")
if batch not in {1, 2}: return todo(f"GEMM batch size {batch}")
# blacklist slow matmul
# TODO: why is this slow?
if (M,N,K) == (8192, 2304, 16384): return todo("blacklisted slow matmul")
if (M % TILE_M != 0 or N % TILE_N != 0 or K % TILE_K != 0) and arch == "gfx950":
return todo(f"GEMM shape ({M},{N},{K}) not a multiple of ({TILE_M},{TILE_N},{TILE_K})")
return True