mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)
flat llama step work (#15355)
* flat llama step work
* fp8 support
* blacklisted matmul
* chestertons fence
@@ -2,7 +2,7 @@ import math, os
 if __name__ == "__main__":
   os.environ["DEFAULT_FLOAT"] = "bfloat16"
   os.environ["OPTIM_DTYPE"] = "bfloat16"
-  os.environ["DEV"] = "NULL"
+  if "DEV" not in os.environ: os.environ["DEV"] = "NULL"
   # CDNA
   os.environ["EMULATE"] = "AMD_CDNA4"
   os.environ["DEVICE_IN_FUNCTION_BUG"] = "1"
@@ -13,10 +13,27 @@ if __name__ == "__main__":
   if "ASM_GEMM" not in os.environ:
     os.environ["ASM_GEMM"] = "1"
 from tinygrad import Tensor, nn, function, getenv, dtypes, TinyJit
-from tinygrad.helpers import Timing, colored, GlobalCounters
+from tinygrad.helpers import Timing, colored, GlobalCounters, profile_marker
 from tinygrad.uop.ops import Ops, UOp
 from extra.models.llama import apply_rotary_emb, precompute_freqs_cis
 
+FP8 = getenv("FP8", 0)
+
+FP8_DTYPE = dtypes.fp8e4m3
+FP8_MAX = 448.0
+
+def quantize_fp8(x:Tensor):
+  scale = FP8_MAX / (x.abs().max().detach() + 1e-8)
+  x_scaled = x * scale
+  x_clamped = x_scaled + (x_scaled.detach().clamp(-FP8_MAX, FP8_MAX) - x_scaled.detach()) # STE
+  return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal()
+
+def matmul(x:Tensor, w:Tensor) -> Tensor:
+  if not FP8: return x @ w.T
+  # weights are already FP8, just quantize activations
+  x_fp8, x_scale = quantize_fp8(x)
+  return x_fp8.dot(w.T, dtype=dtypes.float) * x_scale
+
 def rmsnorm(x_in:Tensor, eps:float):
   x = x_in.float()
   x = x * (x.square().mean(-1, keepdim=True) + eps).rsqrt()
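The clamp line above is a straight-through estimator (STE): the forward value equals the clamped x_scaled, but because the clamp appears only inside detach()ed terms, the gradient flows through the raw x_scaled, so activations at the fp8 boundary still receive gradient. A minimal NumPy sketch of the forward scaling (names are illustrative, not from the diff; the fp8 cast itself is elided):

import numpy as np

FP8_MAX = 448.0  # max finite value of fp8e4m3

def quantize_fp8_forward(x: np.ndarray):
  # per-tensor scale mapping max |x| onto FP8_MAX, with eps to avoid div-by-zero
  scale = FP8_MAX / (np.abs(x).max() + 1e-8)
  x_q = np.clip(x * scale, -FP8_MAX, FP8_MAX)  # the .cast(FP8_DTYPE) would round here
  return x_q, 1.0 / scale                      # reciprocal scale to dequantize later

x = np.random.randn(4, 8).astype(np.float32)
x_q, inv_scale = quantize_fp8_forward(x)
assert np.allclose(x_q * inv_scale, x, atol=1e-5)  # round trip, up to fp8 rounding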
@@ -53,12 +70,13 @@ class FlatTransformer:
 
   def lin_per_layer(self, in_features:int, out_features:int):
     bound = 1 / math.sqrt(in_features)
-    if getenv("ZEROS"): return Tensor.zeros(self.n_layers, out_features, in_features)
-    return Tensor.uniform(self.n_layers, out_features, in_features, low=-bound, high=bound)
+    dt = FP8_DTYPE if FP8 else None
+    if getenv("ZEROS"): return Tensor.zeros(self.n_layers, out_features, in_features, dtype=dt)
+    return Tensor.uniform(self.n_layers, out_features, in_features, low=-bound, high=bound, dtype=dt)
 
   def attention(self, x:Tensor, freqs_cis:Tensor, attention_norm:Tensor, wqkv:Tensor, wo:Tensor):
     x = rmsnorm(x, self.norm_eps) * attention_norm
-    xqkv = x @ wqkv.T
+    xqkv = matmul(x, wqkv)
 
     bsz, seqlen, _ = xqkv.shape
     # interleaved layout: each kv group has [n_rep q heads, 1 k head, 1 v head] for clean MP sharding
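Storing the per-layer weight stacks directly in fp8e4m3 (1 byte/element, vs 2 for bfloat16) halves weight memory. Rough arithmetic with assumed dimensions, not taken from the diff:

n_layers, out_features, in_features = 32, 4096, 4096
elems = n_layers * out_features * in_features
print(f"bf16: {elems*2/1e9:.2f} GB, fp8: {elems*1/1e9:.2f} GB")  # 1.07 vs 0.54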
@@ -71,13 +89,13 @@ class FlatTransformer:
     xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
     attn = xq.scaled_dot_product_attention(xk, xv, is_causal=True, enable_gqa=True).transpose(1, 2)
     attn = attn.reshape(bsz, seqlen, -1)
-    return attn @ wo.T
+    return matmul(attn, wo)
 
   def feed_forward(self, x:Tensor, ffn_norm:Tensor, w1:Tensor, w2:Tensor, w3:Tensor):
     x = rmsnorm(x, self.norm_eps) * ffn_norm
-    x_w1 = (x @ w1.T).silu()
-    x_w3 = x.contiguous_backward() @ w3.T
-    return (x_w1 * x_w3) @ w2.T
+    x_w1 = matmul(x, w1).silu()
+    x_w3 = matmul(x.contiguous_backward(), w3)
+    return matmul(x_w1 * x_w3, w2)
 
   @function(precompile=True, precompile_backward=True)
   def run_layer(self, x:Tensor, freqs_cis:Tensor,
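With FP8=0 the new matmul helper reduces to x @ w.T, so these call-site swaps in attention and feed_forward are behavior-preserving; with FP8=1 only the numerics change, not the shapes. A quick shape sanity sketch with assumed dimensions:

import numpy as np

x = np.zeros((2, 16, 512))  # (batch, seqlen, in_features)
w = np.zeros((1024, 512))   # (out_features, in_features), nn.Linear-style layout
assert (x @ w.T).shape == (2, 16, 1024)  # what matmul(x, w) returns either way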
@@ -139,7 +157,7 @@ if __name__ == "__main__":
   # print model size
   sz = 0
   for k,v in state.items():
-    print(f"{colored(k, 'green' if v in grads else 'white'):30s} {str(v.shape):30s} {v.dtype} {v.device} {v.nbytes()/1e9:.2f} GB")
+    print(f"{colored(k, 'green' if v in grads else 'white'):30s} {str(v.shape):30s} {str(v.dtype):20s} {v.device} {v.nbytes()/1e9:.2f} GB")
     sz += v.nbytes()
   print(f"total sz: {sz/1e9:.2f} GB")
 
@@ -161,15 +179,15 @@ if __name__ == "__main__":
 
   @TinyJit
   def jit_step(tokens:Tensor):
-    GlobalCounters.reset()
-    print(colored("*** step", "red"))
     with Timing("python forward: "): loss = model(tokens[:, :-1]).sparse_categorical_crossentropy(tokens[:, 1:])
     with Timing("python backward: "):
       for t,g in zip(grads, loss.gradient(*grads)):
         grads[t] = Tensor(grads[t].uop.after(UOp.group(*apply_grad(grads[t].uop, g.uop))), device=t.device)
     with Timing("run step: "): loss.realize(*grads.values())
 
-  jit_step(tokens)
-  jit_step(tokens)
-  jit_step(tokens)
+  for i in range(6):
+    GlobalCounters.reset()
+    profile_marker(f"step {i}")
+    with Timing(colored(f"*** step {i}: ", "red")):
+      jit_step(tokens)
   print("mem per device: " + ', '.join(f"{dev}: {mem/1e9:.2f} GB" for dev, mem in sorted(GlobalCounters.mem_used_per_device.items())))
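Moving GlobalCounters.reset() and the step banner out of jit_step also matters for correctness: a TinyJit function replays its captured kernels after the first calls, so Python-level side effects inside it only execute while tracing. A generic sketch of that capture/replay behavior (illustrative only, not the tinygrad implementation):

def capture_and_replay(fn):
  trace = {}
  def wrapper(*args):
    if "out" not in trace:
      trace["out"] = fn(*args)  # first call: Python body runs, prints/resets happen
    return trace["out"]         # later calls: replay only, Python body is skipped
  return wrapper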
@@ -2,8 +2,9 @@ import os
 os.environ["WQKV"] = "1"
 import unittest
 import numpy as np
-from tinygrad import Tensor, nn
+from tinygrad import Tensor, nn, dtypes
 from tinygrad.nn.state import get_parameters
+from tinygrad.device import is_dtype_supported
 from examples.mlperf.models.llama import Transformer
 from examples.mlperf.models.flat_llama import FlatTransformer
 
@@ -113,5 +114,27 @@ class TestFlatLlama(unittest.TestCase):
     self.assertEqual(ref_logits.shape, flat_logits.shape)
     np.testing.assert_allclose(flat_logits, ref_logits, atol=1e-4, rtol=1e-4)
 
+  @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), "fp8 not supported on this device")
+  def test_forward_fp8(self):
+    import examples.mlperf.models.flat_llama as flat_llama_mod
+    old_fp8 = flat_llama_mod.FP8
+    try:
+      flat_llama_mod.FP8 = 1
+      Tensor.manual_seed(42)
+      params = dict(dim=128, hidden_dim=256, n_heads=4, n_kv_heads=2, n_layers=2, norm_eps=1e-5, vocab_size=1024, rope_theta=10000, max_context=64)
+      ref = Transformer(**params)
+      flat = FlatTransformer(**params)
+      copy_weights(flat, ref)
+      Tensor.realize(*nn.state.get_state_dict(flat).values())
+
+      tokens = Tensor([[1, 50, 100, 999, 2]])
+      ref_logits = ref(tokens).numpy()
+      flat_logits = flat(tokens).numpy()
+      self.assertEqual(ref_logits.shape, flat_logits.shape)
+      # FP8 has lower precision, allow larger tolerance
+      np.testing.assert_allclose(flat_logits, ref_logits, atol=1.0, rtol=0.1)
+    finally:
+      flat_llama_mod.FP8 = old_fp8
 
 if __name__ == "__main__":
   unittest.main()
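Note the test patches the module global rather than setting an environment variable: FP8 = getenv("FP8", 0) is evaluated once when flat_llama is imported, so flipping os.environ afterwards would have no effect. A sketch of the distinction (assumes the module imports cleanly in this environment):

import os
import examples.mlperf.models.flat_llama as flat_llama_mod

os.environ["FP8"] = "1"   # too late: the module already bound FP8 at import time
flat_llama_mod.FP8 = 1    # effective: matmul() reads this global on every call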
@@ -2649,6 +2649,9 @@ def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
   else: dname = a.device
   arch = getattr(Device[dname].renderer, "arch", "")
   if batch not in {1, 2}: return todo(f"GEMM batch size {batch}")
+  # blacklist slow matmul
+  # TODO: why is this slow?
+  if (M,N,K) == (8192, 2304, 16384): return todo("blacklisted slow matmul")
   if (M % TILE_M != 0 or N % TILE_N != 0 or K % TILE_K != 0) and arch == "gfx950":
     return todo(f"GEMM shape ({M},{N},{K}) not a multiple of ({TILE_M},{TILE_N},{TILE_K})")
   return True
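Each failed check returns through todo(), which presumably records the reason and returns False, declining the hand-written ASM GEMM so the generic codegen path handles the shape; blacklisting is therefore safe, just slower. A standalone sketch of this guard pattern (tile sizes and names assumed, not from the diff):

SLOW_SHAPES = {(8192, 2304, 16384)}  # (M, N, K) combos the ASM kernel handles poorly

def can_use_fast_gemm(M: int, N: int, K: int, tiles=(256, 256, 64)) -> bool:
  if (M, N, K) in SLOW_SHAPES: return False  # fall back to generic codegen
  tm, tn, tk = tiles
  return M % tm == 0 and N % tn == 0 and K % tk == 0  # tiles must divide the problem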