flat llama (#15324)

* FlatTransformer

* works

* pass in buffer views

* print stuff

* print

* bugfixes
Author: George Hotz, committed by GitHub
Date: 2026-03-17 19:39:55 +08:00
Parent: 0a641ce17d
Commit: 2605840ee2


@@ -0,0 +1,125 @@
import math, os
if __name__ == "__main__":
  os.environ["DEFAULT_FLOAT"] = "bfloat16"
  os.environ["OPTIM_DTYPE"] = "bfloat16"
  os.environ["DEV"] = "NULL"
from tinygrad import Tensor, nn, function, getenv, dtypes, TinyJit
from tinygrad.helpers import Timing, colored, GlobalCounters
from extra.models.llama import apply_rotary_emb, precompute_freqs_cis

def rmsnorm(x_in:Tensor, eps:float):
  x = x_in.float()
  x = x * (x.square().mean(-1, keepdim=True) + eps).rsqrt()
  return x.cast(x_in.dtype)

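# "Flat" here means every per-layer weight lives in a single stacked tensor of shape
# (n_layers, out_features, in_features); each layer then uses a view into that tensor,
# so the layer body below can be written once and fed a different weight slice per layer.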
class FlatTransformer:
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size:int, n_kv_heads:int|None=None,
               rope_theta:int=10000, max_context:int=1024):
    self.vocab_size = vocab_size
    self.n_layers = n_layers
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
    self.head_dim = dim // n_heads
    self.n_rep = self.n_heads // self.n_kv_heads

    # Attention
    self.wqkv = self.lin_per_layer(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2)
    self.wo = self.lin_per_layer(self.n_heads * self.head_dim, dim)

    # FeedForward
    self.w1 = self.lin_per_layer(dim, hidden_dim)
    self.w2 = self.lin_per_layer(hidden_dim, dim)
    self.w3 = self.lin_per_layer(dim, hidden_dim)

    self.norm_eps = norm_eps
    self.attention_norm = Tensor.ones(n_layers, dim)
    self.ffn_norm = Tensor.ones(n_layers, dim)

    # output
    self.norm = nn.RMSNorm(dim, norm_eps)
    self.tok_embeddings = nn.Embedding(vocab_size, dim)
    self.output = nn.Linear(dim, vocab_size, bias=False)
    self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_context * 2, rope_theta).contiguous().requires_grad_(False)

  def lin_per_layer(self, in_features:int, out_features:int):
    bound = 1 / math.sqrt(in_features)
    if getenv("ZEROS"): return Tensor.zeros(self.n_layers, out_features, in_features)
    return Tensor.uniform(self.n_layers, out_features, in_features, low=-bound, high=bound)

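  # Q, K, and V projections are fused into the single wqkv matmul below. As a worked example
  # (assuming the 8B config uses dim=4096, n_heads=32, n_kv_heads=8, hence head_dim=128 and
  # n_rep=4), wqkv produces 32*128 + 2*8*128 = 6144 features per token, which attention()
  # unpacks as (n_kv_heads, n_rep + 2, head_dim) = (8, 6, 128): n_rep query heads plus one
  # key and one value head per KV group.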
  def attention(self, x:Tensor, freqs_cis:Tensor, attention_norm:Tensor, wqkv:Tensor, wo:Tensor):
    x = rmsnorm(x, self.norm_eps) * attention_norm
    xqkv = x @ wqkv.T

    # reshapes
    xqkv = xqkv.reshape(xqkv.shape[0], xqkv.shape[1], self.n_kv_heads, self.n_rep + 2, self.head_dim)
    xq = xqkv[:, :, :, :self.n_rep].reshape(xqkv.shape[0], xqkv.shape[1], -1)
    xk = xqkv[:, :, :, self.n_rep:self.n_rep+1].reshape(xqkv.shape[0], xqkv.shape[1], -1)
    xv = xqkv[:, :, :, self.n_rep+1:self.n_rep+2].reshape(xqkv.shape[0], xqkv.shape[1], -1)
    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)

    xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
    bsz, seqlen, _, _ = xq.shape
    xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
    attn = xq.scaled_dot_product_attention(xk, xv, is_causal=True, enable_gqa=True).transpose(1, 2)
    attn = attn.reshape(bsz, seqlen, -1)
    return attn @ wo.T

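  # SwiGLU feed-forward: (silu(x @ w1.T) * (x @ w3.T)) @ w2.T, the standard Llama MLP.
  # contiguous_backward on the w3 branch appears to force the gradient flowing into that
  # matmul to be contiguous, which reads like a scheduling hint for the backward pass.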
  def feed_forward(self, x:Tensor, ffn_norm:Tensor, w1:Tensor, w2:Tensor, w3:Tensor):
    x = rmsnorm(x, self.norm_eps) * ffn_norm
    x_w1 = (x @ w1.T).silu()
    x_w3 = x.contiguous_backward() @ w3.T
    return (x_w1 * x_w3) @ w2.T

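  # A full pre-norm transformer block: a residual add around attention, then another around the
  # SwiGLU MLP. The @function(precompile=True, precompile_backward=True) wrapper presumably lets
  # the forward and backward of this body be compiled once and reused for every layer, since each
  # call sees identical shapes and only the weight views change.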
  @function(precompile=True, precompile_backward=True)
  def run_layer(self, x:Tensor, freqs_cis:Tensor,
                attention_norm:Tensor, wqkv:Tensor, wo:Tensor,
                ffn_norm:Tensor, w1:Tensor, w2:Tensor, w3:Tensor):
    h = x + self.attention(x, freqs_cis, attention_norm, wqkv, wo)
    return h + self.feed_forward(h, ffn_norm, w1, w2, w3)

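  # Whole-model forward: embed tokens, slice freqs_cis to the current sequence length, then run
  # the single layer function n_layers times, passing in views of the flat weight buffers
  # (the "pass in buffer views" part of the commit message).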
  def __call__(self, tokens:Tensor):
    h = self.tok_embeddings(tokens)
    freqs_cis = self.freqs_cis.cast(h.dtype)[:, :tokens.shape[1], :, :, :]
    for i in range(self.n_layers):
      h = self.run_layer(h, freqs_cis,
                         self.attention_norm[i], self.wqkv[i], self.wo[i],
                         self.ffn_norm[i], self.w1[i], self.w2[i], self.w3[i])
    logits = self.output(self.norm(h))
    return logits

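# Benchmark/smoke-test driver. Tunables read from the environment (all optional):
#   BS, SEQLEN        batch size and sequence length for the fake training batch
#   LLAMA3_SIZE       which examples.llama3 MODEL_PARAMS entry to use (default "8B")
#   LLAMA_LAYERS      override n_layers (0 keeps the config's layer count)
#   ZEROS             initialize the flat weights to zeros instead of uniform
#   DEFAULT_FLOAT, OPTIM_DTYPE, DEV   set at the top of the file when run as a script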
if __name__ == "__main__":
  config = {}
  BS = config["BS"] = getenv("BS", 16)
  SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)

  from examples.llama3 import MODEL_PARAMS
  model_params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]
  if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: model_params['n_layers'] = llama_layers
  model = FlatTransformer(**model_params, max_context=SEQLEN)

  state = nn.state.get_state_dict(model)
  print("tensor count:", len(state))
  sz = 0
  for k,v in state.items():
    if v.requires_grad is None: v.requires_grad_(True)
    print(f"{colored(k, 'green' if v.requires_grad else 'white'):30s} {str(v.shape):30s} {v.dtype} {v.device}")
    sz += v.nbytes()
  print(f"total sz: {sz/1e9:.2f} GB")

  with Timing("realize weights: "): Tensor.realize(*state.values())
  with Timing("fake data: "): tokens = Tensor.randint(BS, SEQLEN+1, low=0, high=model.vocab_size, dtype=dtypes.int).realize()

  @TinyJit
  def jit_step(tokens:Tensor):
    GlobalCounters.reset()
    print(colored("*** step", "red"))
    with Timing("python forward: "): loss = model(tokens[:, :-1]).sparse_categorical_crossentropy(tokens[:, 1:])
    with Timing("python backward: "): loss.backward()
    with Timing("run step: "): loss.realize(*[x.grad for x in state.values() if x.requires_grad])

  jit_step(tokens)
  jit_step(tokens)
  jit_step(tokens)
  print(f"mem used: {GlobalCounters.mem_used/1e9:.2f} GB")