examples of new GPT2 and JIT change (#2261)

* var_vals are global

* working with global ish

* better

* fix export model

* fix tests

* better kv cache

* does it run?

* use where for kvmask

* fix excessive var_vals

* fix import

* how does multigpu use this?

* llama kinda works

* faster and simpler

* cleanup

* fix conversation mode

* test cleanups

* fix one more test

* test cleanup

---------

Co-authored-by: George Hotz <geohot@gmail.com>
Authored by chenyu on 2023-11-10 15:07:02 -05:00, committed by GitHub
parent b6aaf12df7
commit a753c8e071
15 changed files with 189 additions and 249 deletions
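
For context, the calling convention this diff converges on looks roughly like the sketch below: callers bind a symbolic start_pos Variable on every model call, the KV cache lives on the model itself, and the jitted forward is used for single-token steps. This is a hedged sketch, not the full example: the module path, the 50256 start token, and the greedy argmax are assumptions (the real example samples from the returned probabilities).

```python
from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable
from examples.gpt2 import GPT2, MAX_CONTEXT   # assumed module path

gpt2 = GPT2.build("gpt2")
toks, start_pos = [50256], 0                  # assumption: begin from the endoftext token
for _ in range(10):
  probs = gpt2.model(Tensor([toks[start_pos:]]),
                     Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos))
  tok = int(probs.numpy().argmax())           # greedy pick for brevity
  start_pos = len(toks)
  toks.append(tok)
```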

View File

@@ -300,10 +300,9 @@ cache_saved = CacheCollector.finish() # disable the cache
# there's one ASTRunner in the cache
assert len(cache_saved) == 1
prg, bufs, _ = cache_saved[0]
# print the C Program :)
print(prg.prg)
print(cache_saved[0].prg.prg)
# after some formatting (the compiler doesn't care)
# NOTE: the 2 and 3 are constant folded
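
The hunk above switches the docs example from tuple unpacking to attribute access: the captured cache now holds JitItem objects (defined in the tinygrad/jit diff further down) instead of (prg, bufs, var_vals) tuples. A small sketch, assuming a cache was captured exactly as in the surrounding example:

```python
ji = cache_saved[0]     # a JitItem, no longer a (prg, bufs, var_vals) tuple
print(ji.prg.prg)       # the generated kernel source
print(len(ji.rawbufs))  # the raw buffers the kernel was captured with
```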

View File

@@ -1,137 +1,96 @@
#!/usr/bin/env python3
# pip3 install tiktoken
import argparse
import numpy as np
from tqdm import trange
np.set_printoptions(linewidth=200)
from typing import Optional, Dict
from tinygrad.helpers import Timing, getenv, dtypes, DEBUG
from tinygrad.helpers import GlobalCounters
import numpy as np
from tinygrad.ops import Device
from typing import Optional
from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding, Linear
from tinygrad.jit import TinyJit
from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.shape.symbolic import Variable
from tinygrad.jit import TinyJit
import tiktoken
from tinygrad.nn.state import torch_load, load_state_dict
from extra.utils import fetch_as_file
from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv
MAX_CONTEXT = 128
class LayerNorm:
def __init__(self, dim, eps=1e-5):
self.eps = eps
self.weight = Tensor.ones(dim)
self.bias = Tensor.zeros(dim)
def __call__(self, x:Tensor):
return (x.layernorm(eps=self.eps)) * self.weight + self.bias
class Attention:
def __init__(self, dim, n_heads, linear=Linear):
self.c_attn = linear(dim, 3*dim, bias=True)
self.c_proj = linear(dim, dim, bias=True)
def __init__(self, dim, n_heads):
self.c_attn = Linear(dim, 3*dim, bias=True)
self.c_proj = Linear(dim, dim, bias=True)
self.n_heads = n_heads
self.dim = dim
self.head_dim = dim // n_heads
def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor]) -> Tensor:
def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
if mask is not None:
# no symbolic shape qkv when consuming prompts
start_pos = start_pos.val
xqkv = self.c_attn(x)
xq, xk, xv = [xqkv.slice([None, None, (i*self.dim, (i+1)*self.dim)]) for i in range(3)]
xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], self.n_heads, self.head_dim) for x in (xq, xk, xv)]
xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim) for i in range(3)]
bsz, seqlen, n_heads, head_dim = xq.shape
bsz, seqlen, _, _ = xq.shape
# kv caching!
if start_pos == 0:
keys, values = xk, xv
else:
assert cache_k, "no cache"
#assert start_pos == cache_k.shape[1] and start_pos == cache_v.shape[1], "cache is wrong shape"
assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?"
keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)
# create kv cache
if not hasattr(self, "cache_k"):
self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
# update the cache
self.cache_k.assign(keys.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
self.cache_v.assign(values.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
# save the cache
cache_k, cache_v = keys.realize(), values.realize()
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
output = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
return self.c_proj(output), cache_k, cache_v
return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1))
class FeedForward:
def __init__(self, dim, hidden_dim, linear=Linear):
self.c_fc = linear(dim, hidden_dim, bias=True)
self.c_proj = linear(hidden_dim, dim, bias=True)
def __init__(self, dim, hidden_dim):
self.c_fc = Linear(dim, hidden_dim, bias=True)
self.c_proj = Linear(hidden_dim, dim, bias=True)
def __call__(self, x:Tensor) -> Tensor:
return self.c_proj(self.c_fc(x).gelu())
class TransformerBlock:
def __init__(self, dim, n_heads, norm_eps, linear=Linear):
self.attn = Attention(dim, n_heads, linear)
self.mlp = FeedForward(dim, 4*dim, linear)
def __init__(self, dim, n_heads, norm_eps):
self.attn = Attention(dim, n_heads)
self.mlp = FeedForward(dim, 4*dim)
self.ln_1 = LayerNorm(dim, norm_eps)
self.ln_2 = LayerNorm(dim, norm_eps)
def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor]):
output, cache_k, cache_v = self.attn(self.ln_1(x), cache_k, cache_v, start_pos, mask)
h = x + output
h = (h + self.mlp(self.ln_2(h)))
return h, cache_k, cache_v
def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]):
h = x + self.attn(self.ln_1(x), start_pos, mask)
return (h + self.mlp(self.ln_2(h)))
class Transformer:
def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_seq_len=1024):
def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
self.wte = Embedding(vocab_size, dim)
self.wpe = Embedding(max_seq_len, dim)
self.h = [TransformerBlock(dim, n_heads, norm_eps, linear) for _ in range(n_layers)]
self.kv_caches = None
self.n_layers = n_layers
self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)]
self.ln_f = LayerNorm(dim, norm_eps)
self.lm_head = linear(dim, vocab_size, bias=False)
self.lm_head = Linear(dim, vocab_size, bias=False)
self.forward_jit = TinyJit(self.forward)
def embed(self, tokens:Tensor, pos:Tensor):
tok_emb = self.wte(tokens)
pos_emb = self.wpe(pos)
h = tok_emb + pos_emb
if getenv("FP16"): h = h.half()
return h
def postprocess(self, x, temperature:Optional[float]):
logits = self.lm_head(self.ln_f(x))
if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten()
return logits
@TinyJit
def run_all_layers(self, tokens:Tensor, pos:Tensor, start_pos:int, temperature:float, **kv_cache):
h = self.embed(tokens, pos)
for i, hi in enumerate(self.h):
h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'] = hi(h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'], start_pos=start_pos, mask=None)
# don't realize until here
for v in kv_cache.values(): v.realize()
return self.postprocess(h, temperature).realize(), kv_cache
def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
_bsz, seqlen = tokens.shape
if seqlen == 1 and start_pos > 0 and getenv("JIT"):
start_pos_var = Variable("start_pos", 1, MAX_CONTEXT).bind(start_pos)
pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos_var, start_pos_var + seqlen)))
for k,v in self.kv_caches.items():
self.kv_caches[k] = v.reshape(v.shape[0], start_pos_var, v.shape[2], v.shape[3])
logit_or_softmax, self.kv_caches = self.run_all_layers(tokens, pos, start_pos=start_pos, temperature=temperature, **self.kv_caches)
return logit_or_softmax
else:
if start_pos == 0:
def forward(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
self.kv_caches = {**{f'cache_k{i}':None for i in range(self.n_layers)}, **{f'cache_v{i}':None for i in range(self.n_layers)}}
_bsz, seqlen = tokens.shape
pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos, start_pos+seqlen)))
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float16 if getenv('FP16') else dtypes.float32).triu(start_pos+1).realize()
h = self.embed(tokens, pos)
for i, hi in enumerate(self.h):
h, self.kv_caches[f'cache_k{i}'], self.kv_caches[f'cache_v{i}'] = hi(h, self.kv_caches[f'cache_k{i}'], self.kv_caches[f'cache_v{i}'], start_pos=start_pos, mask=mask)
for v in self.kv_caches.values(): v.realize()
return self.postprocess(h, temperature).realize()
tok_emb = self.wte(tokens)
pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos+seqlen))))
h = tok_emb + pos_emb
# **** files and arguments ****
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None
for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask)
logits = self.lm_head(self.ln_f(h))
return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
# TODO: fix empty token
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
return (self.forward_jit if tokens.shape[0:2] == (1,1) and getenv("JIT") else self.forward)(tokens, start_pos, temperature)
VOCAB_SIZE = 50257
MODEL_PARAMS = {
@@ -141,19 +100,13 @@ MODEL_PARAMS = {
'gpt2-xl': dict(n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 1558M params
}
def get_url(model_size): return f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'
class GPT2:
@staticmethod
def build(model_size="gpt2"):
import tiktoken
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from extra.utils import fetch_as_file
tokenizer = tiktoken.get_encoding("gpt2")
params = MODEL_PARAMS[model_size]
model = Transformer(**params)
weights = torch_load(fetch_as_file(get_url(model_size)))
model = Transformer(**MODEL_PARAMS[model_size])
weights = torch_load(fetch_as_file(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
# special treatment for the Conv1D weights we need to transpose
transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
for k in weights.keys():
@@ -162,9 +115,6 @@ class GPT2:
# lm head and wte are tied
weights['lm_head.weight'] = Tensor(weights['wte.weight'].numpy())
if getenv("FP16"):
for k, v in weights.items():
weights[k] = v.cpu().half().realize()
load_state_dict(model, weights)
return GPT2(model, tokenizer)
@@ -183,7 +133,7 @@ class GPT2:
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
probs = self.model(Tensor([toks[start_pos:]]), start_pos, temperature)
probs = self.model(Tensor([toks[start_pos:]]), Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature)
probs_np = probs.numpy()
tok = int(np.random.choice(len(probs_np), p=probs_np))
start_pos = len(toks)
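
The KV cache above is now preallocated to MAX_CONTEXT on the Attention module and updated by shrinking to the live prefix, concatenating the new keys/values, then padding back out and assigning in place. A plain-NumPy sketch of that update pattern (shapes and sizes illustrative, not tinygrad code):

```python
import numpy as np

MAX_CONTEXT, bsz, n_heads, head_dim = 8, 1, 2, 4   # illustrative sizes
cache_k = np.zeros((bsz, MAX_CONTEXT, n_heads, head_dim), dtype=np.float32)  # like self.cache_k

def update_cache(cache_k, xk, start_pos):
  seqlen = xk.shape[1]
  # shrink to the live part of the cache, then cat the new keys
  keys = np.concatenate([cache_k[:, :start_pos], xk], axis=1)
  # pad back out to MAX_CONTEXT and assign in place
  cache_k[:] = np.pad(keys, ((0, 0), (0, MAX_CONTEXT - start_pos - seqlen), (0, 0), (0, 0)))
  return keys

xk = np.random.randn(bsz, 3, n_heads, head_dim).astype(np.float32)   # 3-token prompt
keys = update_cache(cache_k, xk, start_pos=0)
xk = np.random.randn(bsz, 1, n_heads, head_dim).astype(np.float32)   # one generated token
keys = update_cache(cache_k, xk, start_pos=3)
assert keys.shape == (bsz, 4, n_heads, head_dim)
```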

View File

@@ -18,10 +18,11 @@ from tinygrad.helpers import GlobalCounters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.shape.symbolic import Variable, sym_infer
MAX_CONTEXT = 1024
JIT = getenv("JIT", 0 if CI else int(Device.DEFAULT in JIT_SUPPORTED_DEVICE))
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2)
@@ -69,27 +70,34 @@ class Attention:
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None) -> Tuple[Tensor, Tensor, Tensor]:
bsz, seqlen, _ = x.shape
def __call__(self, x:Tensor, start_pos:Variable, freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
if mask is not None:
# no symbolic shape qkv when consuming prompts
start_pos = start_pos.val
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
bsz, seqlen, n_heads, head_dim = xq.shape
# kv caching!
if start_pos == 0:
keys, values = xk, xv
else:
assert cache_k is not None and cache_v is not None, "no cache"
assert start_pos == (cache_k.shape[1].val if isinstance(cache_k.shape[1], Variable) else cache_k.shape[1]) == (cache_v.shape[1].val if isinstance(cache_v.shape[1], Variable) else cache_v.shape[1]), f"cache has wrong shape, {start_pos=}, {cache_k.shape[1]=}, {cache_v.shape[1]=}"
assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?"
keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)
# create kv cache
if not hasattr(self, "cache_k"):
self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
cache_k, cache_v = keys, values
keys, values = repeat_kv(keys, self.n_rep).realize(), repeat_kv(values, self.n_rep).realize()
attn = Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask).transpose(1, 2).reshape(bsz, seqlen, -1)
return self.wo(attn).realize(), cache_k.realize(), cache_v.realize()
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
# update the cache
self.cache_k.assign(keys.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
self.cache_v.assign(values.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
return self.wo(attn)
class FeedForward:
def __init__(self, dim, hidden_dim, multiple_of, linear=Linear, ffn_dim_multiplier=None):
@@ -113,60 +121,31 @@ class TransformerBlock:
self.attention_norm = RMSNorm(dim, norm_eps)
self.ffn_norm = RMSNorm(dim, norm_eps)
def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None):
bsz, seqlen, _ = x.shape
if JIT and mask is None:
assert cache_k is not None and cache_v is not None, "no cache"
pos = Variable("pos", 1, 1024).bind(start_pos)
cache_k = cache_k.reshape(cache_k.shape[0], pos, cache_k.shape[2], cache_k.shape[3])
cache_v = cache_v.reshape(cache_v.shape[0], pos, cache_v.shape[2], cache_v.shape[3])
output, cache_k, cache_v = self.attention(self.attention_norm(x), cache_k, cache_v, start_pos, freqs_cis, mask, jit_ctx=jit_ctx)
h = x + output
return (h + self.feed_forward(self.ffn_norm(h))).realize(), cache_k.realize(), cache_v.realize()
def __call__(self, x:Tensor, start_pos:Variable, freqs_cis:Tensor, mask:Optional[Tensor]):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
return (h + self.feed_forward(self.ffn_norm(h))).realize()
class Transformer:
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None, rope_theta=10000):
self.layers = [TransformerBlock(dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
self.kv_caches = [(None, None) for _ in range(n_layers)]
self.norm = RMSNorm(dim, norm_eps)
self.tok_embeddings = Embedding(vocab_size, dim)
self.output = linear(dim, vocab_size, bias=False)
self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_seq_len * 2, rope_theta)
self.norm_output = lambda x: self.output(self.norm(x))
self.forward_jit = TinyJit(self.forward)
self.tok_embeddings_jitted = TinyJit(lambda x: self.tok_embeddings(x).realize())
self.postprocess_jitted = TinyJit(self.postprocess)
self.layers_jitted = [TinyJit(layer.__call__) for layer in self.layers]
def postprocess(self, x, temperature:Optional[float]):
logits = self.output(self.norm(x))
if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
return logits.realize()
def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
def forward(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
_bsz, seqlen = tokens.shape
if seqlen == 1 and start_pos > 0 and JIT:
pos = Variable("pos", 1, 1024).bind(start_pos)
# get only the part of freqs_cis that we are using.
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
h = self.tok_embeddings_jitted(tokens)
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers_jitted, self.kv_caches)):
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=None, jit_ctx={pos.unbind()[0]: start_pos})
self.kv_caches[i] = (cache_k, cache_v)
return self.postprocess_jitted(h, temperature)
else:
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (start_pos, start_pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize()
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos.val+1).realize() if seqlen > 1 else None
h = self.tok_embeddings(tokens)
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers, self.kv_caches)):
# need this reshape back to int shape in conversational mode because jitted and unjitted calls share the same cache
if cache_k is not None and start_pos > 0:
cache_k = cache_k.reshape(cache_k.shape[0], start_pos, cache_k.shape[2], cache_k.shape[3])
cache_v = cache_v.reshape(cache_v.shape[0], start_pos, cache_v.shape[2], cache_v.shape[3])
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask)
self.kv_caches[i] = (cache_k, cache_v)
return self.postprocess(h, temperature)
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
logits = self.output(self.norm(h))
return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
return (self.forward_jit if tokens.shape[0:2] == (1,1) and getenv("JIT") else self.forward)(tokens, start_pos, temperature)
# **** files and arguments ****
MODEL_PARAMS = {
@@ -522,7 +501,7 @@ After you are done speaking, output [EOS]. You are not Chad.
print(f"Preparing KV cache for chatbot with personality {args.personality}...")
with Timing():
llama.model(Tensor([toks]), 0, args.temperature).realize() # NOTE: output logits are not used
llama.model(Tensor([toks]), Variable("start_pos", 0, MAX_CONTEXT).bind(0), args.temperature).realize() # NOTE: output logits are not used
start_pos = len(toks)
else:
# non chat bot mode
@@ -561,7 +540,7 @@ After you are done speaking, output [EOS]. You are not Chad.
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
probs = llama.model(Tensor([toks[start_pos:]]), Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), args.temperature).realize()
probs_np = probs.numpy()
tok = int(np.random.choice(len(probs_np), p=probs_np))

View File

@@ -9,11 +9,11 @@ EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU", "METAL"]
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
for fxn,args,var_vals in run.jit_cache:
assert not var_vals, "symbolic shape is not supported"
for ji in run.jit_cache:
fxn = ji.prg
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
cargs = []
for i,arg in enumerate(args):
for i,arg in enumerate(ji.rawbufs):
key = id(arg)
if key not in bufs:
if key in special_names:
@@ -43,7 +43,7 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
# hack to put the inputs back
for (j,i),idx in run.input_replace.items():
realized_input = args[idx[0]].lazydata.realized
run.jit_cache[j][1][i] = realized_input
run.jit_cache[j].rawbufs[i] = realized_input
special_names[id(realized_input)] = f'input{idx[0]}'
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)

View File

@@ -63,12 +63,13 @@ def compile(dat, output_fn):
# pull out inputs and put them in the jit cache
input_rawbuffers = {k:inputs[k].lazydata.realized for k in inputs.keys()}
for (j,i),(idx,_,_) in model_exec.input_replace.items(): model_exec.jit_cache[j][1][i] = input_rawbuffers[idx]
for (j,i),(idx,_,_) in model_exec.input_replace.items(): model_exec.jit_cache[j].rawbufs[i] = input_rawbuffers[idx]
# transform to CL.CACHE
used_ops = 0
cl_cache = []
for prg,args,_ in model_exec.jit_cache:
for ji in model_exec.jit_cache:
prg = ji.prg
# pass these to thneed
setattr(prg.clprg, 'op_estimate', prg.op_estimate)
setattr(prg.clprg, 'prg', prg.prg)
@@ -79,7 +80,7 @@ def compile(dat, output_fn):
global_size = prg.global_size + [1]*(3-len(prg.global_size))
local_size = prg.local_size + [1]*(3-len(prg.local_size))
cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x._buf for x in args]]))
cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x._buf for x in ji.rawbufs]]))
used_ops += prg.op_estimate
from extra.thneed import Thneed

View File

@@ -18,14 +18,14 @@ from tinygrad.lazy import PUSH_PERMUTES
from tinygrad.jit import CacheCollector
class CLCache:
def __init__(self, allowed=None, strict=False, preclear=True): self.allowed, self.strict, self.preclear = allowed, strict, preclear
def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None): self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
def __enter__(self):
if self.preclear:
gc.collect()
for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
x.realize()
GlobalCounters.reset()
CacheCollector.start()
CacheCollector.start(self.var_vals)
print("cache: entering")
def __exit__(self, type, value, traceback):
cache = CacheCollector.finish()
@@ -85,11 +85,12 @@ class TestInferenceMinKernels(unittest.TestCase):
def test_llama(self):
from examples.llama import Transformer
from tinygrad.shape.symbolic import Variable
args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
model = Transformer(**args_tiny)
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
with CLCache(85):
model(Tensor([[1,2,3,4]]), 0).realize()
with CLCache(98, var_vals={Variable("start_pos", 0, 1024): 0}):
model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 1024).bind(0)).realize()
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
class TestOptBinOp(unittest.TestCase):

View File

@@ -57,7 +57,7 @@ class TestRealWorld(unittest.TestCase):
@TinyJit
def test(t): return model(t, 0).realize()
# NOTE: only test one pass, not testing the dynamic shape autoregressive part
helper_test("test_llama", lambda: (Tensor([[1,]]),), test, 0.22 if CI else 13.5, 126 if CI else 486, all_jitted=True)
helper_test("test_llama", lambda: (Tensor([[1,]]),), test, 0.22 if CI else 13.5, 137 if CI else 521, all_jitted=True)
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and (Device.DEFAULT not in ["LLVM"] or not CI), "needs JIT, too long on CI LLVM")
def test_gpt2(self):
@@ -68,7 +68,7 @@ class TestRealWorld(unittest.TestCase):
derandomize_model(model)
@TinyJit
def test(t): return model(t, 0).realize()
helper_test("test_gpt2", lambda: (Tensor([[1,]]),), test, 0.21 if CI else 0.9, 129 if CI else 369, all_jitted=True)
helper_test("test_gpt2", lambda: (Tensor([[1,]]),), test, 0.21 if CI else 0.9, 140 if CI else 396, all_jitted=True)
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and (Device.DEFAULT not in ["LLVM", "CLANG"] or not CI), "needs JIT, too long on CI LLVM and CLANG")
def test_train_cifar(self):

View File

@@ -65,9 +65,9 @@ class TestLazyBuffer(unittest.TestCase):
de.realize()
cache = CacheCollector.finish()
assert len(cache) == 3
assert cache[0][0].name.startswith("r_") # Reduce should not be merged 2 times.
assert cache[1][0].name.startswith("E_")
assert cache[2][0].name.startswith("E_")
assert cache[0].prg.name.startswith("r_") # Reduce should not be merged 2 times.
assert cache[1].prg.name.startswith("E_")
assert cache[2].prg.name.startswith("E_")
if __name__ == "__main__":
unittest.main()

View File

@@ -17,7 +17,7 @@ class TestLinearizer(unittest.TestCase):
np_a, np_b = a.numpy(), b.numpy()
CacheCollector.start()
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))).realize()
rawbufs = CacheCollector.finish()[0][1]
rawbufs = CacheCollector.finish()[0].rawbufs
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.realized, b.lazydata.realized}
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)

View File

@@ -19,19 +19,6 @@ class TestSymbolicJit(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
def test_reshape_inside_plus1(self):
def f(a, jit=False, jit_ctx=None):
if jit: a = a.reshape(3, Variable("i", 1, 10).bind(a.shape[1]))
return (a+1).realize()
jf = TinyJit(f)
for i in range(1, 5):
vi = Variable("i", 1, 10)
a = Tensor.rand(3, i)
symbolic = jf(a, jit=True, jit_ctx={vi: i}).reshape(3, i).numpy()
expected = f(a).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
def test_add(self):
def f(a, b): return (a+b).realize()
jf = TinyJit(f)

View File

@@ -1,75 +1,82 @@
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
from collections import defaultdict
import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts
from tinygrad.ops import RawBuffer, Device
from tinygrad.ops import RawBuffer, Device, ASTRunner
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
from dataclasses import dataclass
JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]
@dataclass(frozen=True)
class JitItem:
prg: ASTRunner
rawbufs: List[Optional[RawBuffer]]
class TinyJit:
def __init__(self, fxn:Callable):
self.fxn: Callable = fxn
self.cnt: int = 0
self.jit_cache: List[Tuple[Any, List[Optional[RawBuffer]], Dict[Variable, int]]] = []
self.jit_cache: List[JitItem] = []
self.ret: Any = None
self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], ShapeTracker, DType]] = {} # (kernel_number, buffer_number) -> (input_name, expected_shapetracker, expected_type)
self.updatable_entries: Dict[int, List[int]] = defaultdict(list) # (kernel_number) -> list(argument id). These are buffers from input + variables.
# add support for instance methods
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT.split(":")[0] not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
# NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
# all inputs are realized
input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
# get rawbuffers
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {k:(cast(RawBuffer, v.lazydata.realized), v.lazydata.st) for k,v in input_tensors.items()}
assert len(input_rawbuffers) != 0, "no inputs to JIT"
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
# get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
if self.cnt >= 2:
try: var_vals: Dict[Variable, int] = kwargs["jit_ctx"]
except KeyError: var_vals = merge_dicts([arg.lazydata.st.var_vals for arg in args if arg.__class__ is Tensor])
if len(var_vals) > 1: var_vals = dict(sorted(var_vals.items(), key=lambda kv: kv[0].key))
# check validity and assign the inputs
for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"
# NOTE: if we pass jit_ctx instead of using reshape to update the var_vals, we cannot compare the shapetracker directly
if "jit_ctx" not in kwargs: assert input_rawbuffers[input_name][1].unbind() == expected_st, f"ShapeTracker mismatch in JIT, {input_rawbuffers[input_name][1].unbind()} != {expected_st}"
self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
for j in self.updatable_entries.keys():
for k in self.jit_cache[j][2].keys():
try: self.jit_cache[j][2][k] = var_vals[k]
except KeyError: pass
for prg, pargs, variables in self.jit_cache: prg(pargs, variables, jit=True)
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
assert input_rawbuffers[input_name][1].unbind() == expected_st, f"ShapeTracker mismatch in JIT, {input_rawbuffers[input_name][1].unbind()} != {expected_st}"
self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name][0]
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
elif self.cnt == 1:
CacheCollector.start()
CacheCollector.start(var_vals)
self.ret = self.fxn(*args, **kwargs)
self.jit_cache = CacheCollector.finish()
assert len(self.jit_cache) != 0, "didn't JIT anything!"
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
# get the inputs for replacement
for j_,cache in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]]
for i,a in enumerate(cache[1]):
for j,ji in enumerate(self.jit_cache):
for i,a in enumerate(ji.rawbufs):
if a in [v[0] for v in input_rawbuffers.values()]:
self.input_replace[(j_,i)] = [(k, v[1].unbind(), v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
self.updatable_entries[j_].append(i)
for i in range(len(cache[2])): self.updatable_entries[j_].append(len(cache[1])+i)
self.input_replace[(j,i)] = [(k, v[1].unbind(), v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
elif self.cnt == 0:
self.ret = self.fxn(*args, **kwargs)
# clear the inputs
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
self.cnt += 1
return self.ret
class _CacheCollector:
def __init__(self): self.cache: Optional[List[Tuple[Callable, List[Any], Dict[Any,Any]]]] = None
def start(self): self.cache = []
def __init__(self):
self.cache: Optional[List[JitItem]] = None
def start(self, var_vals:Optional[Dict[Variable, int]]=None):
self.cache = []
self.var_vals = var_vals if var_vals is not None else {}
def add(self, prg, rawbufs, var_vals):
if self.cache is None: return
self.cache.append((prg, rawbufs, var_vals))
def finish(self):
for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
self.cache.append(JitItem(prg, rawbufs))
def finish(self) -> List[JitItem]:
if self.cache is None: return []
ret = self.cache
self.cache = None
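
With jit_ctx removed, TinyJit now discovers symbolic variables from its inputs: bound Variables are read off each input tensor's ShapeTracker (or passed directly as arguments) and merged into one global var_vals. A hedged sketch of the resulting usage, following the pattern of the deleted reshape test above:

```python
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.shape.symbolic import Variable

@TinyJit
def f(a): return (a + 1).realize()

for i in range(1, 5):
  vi = Variable("i", 1, 10).bind(i)
  a = Tensor.rand(3, i).reshape(3, vi)  # the bound Variable rides along on a.lazydata.st
  out = f(a)                            # no jit_ctx: var_vals are collected from the input
```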

View File

@@ -188,6 +188,7 @@ class ASTRunner:
def __init__(self, name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
if DEBUG >= 4: print(prg)
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
self.vars:List[Variable] = []
def build(self, compiler, runtime):
self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
@@ -258,14 +259,13 @@ class Compiled:
# all the rawbuffers
rawbuffers = [output.realized] + [x.realized for x in inputs]
# extract real vars used in ast
from tinygrad.lazy import vars_from_ast
ast_vars = vars_from_ast(ast)
assert all(v.val is None for v in ast_vars), f"ast contains bound Variable {ast_vars}"
# compilation time
def get_program():
if DEBUG >= 3:
from tinygrad.graph import print_tree
print_tree(ast)
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.lazy import vars_from_ast
k = Linearizer(ast, self.linearizer_opts)
assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
if not NOOPT:
@@ -286,7 +286,11 @@ class Compiled:
k = timed[0][1]
else:
k.required_optimizations()
return self.to_program(k)
prg = self.to_program(k)
# extract real vars used in ast
prg.vars = vars_from_ast(ast)
assert all(v._val is None for v in prg.vars), f"ast contains bound Variable {prg.vars}"
return prg
if getenv("ENABLE_METHOD_CACHE", 1):
if ast not in self.method_cache: self.method_cache[ast] = get_program()
@@ -296,5 +300,5 @@ class Compiled:
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in ast_vars})
prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in prg.vars})
return output.realized

View File

@@ -1,7 +1,7 @@
from typing import List, cast, Dict, Callable
import numpy as np
from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, Device, BufferOps
from tinygrad.graph import log_schedule_item, print_tree
from tinygrad.graph import log_schedule_item
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import DEBUG, prod, all_int, getenv, IMAGE
@@ -18,7 +18,6 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
si = schedule.pop(0)
if not disable_logging: log_schedule_item(si)
assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"
if DEBUG >= 3: print_tree(si.ast)
if si.ast.op in LoadOps:
# confirm the LoadOps are contiguous and in order
for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"

View File

@@ -35,6 +35,12 @@ class Node:
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: raise RuntimeError(self.__class__.__name__)
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
@property
def val(self):
ret = self.substitute({x:NumNode(x.val) for x in self.vars()})
assert isinstance(ret, NumNode), f"val must be NumNode, it's {ret}"
return ret.b
@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")
@functools.cached_property
@@ -150,10 +156,14 @@ class Variable(Node):
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
self.expr, self.min, self.max = expr, nmin, nmax
self.val:Optional[int] = None
self._val: Optional[int] = None
@property
def val(self):
assert self._val is not None, f"Variable isn't bound, can't access val of {self}"
return self._val
def bind(self, val):
assert self.val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
self.val = val
assert self._val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
self._val = val
return self
def unbind(self) -> Tuple[Variable, int]:
assert self.val is not None, f"cannot unbind {self}"
@@ -261,8 +271,11 @@ class SumNode(RedNode):
if x.b%b == 0: fully_divided.append(x//b)
else:
rest.append(x)
if isinstance(x.b, int):
_gcd = gcd(_gcd, x.b)
if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
else:
_gcd = 1
else:
rest.append(x)
_gcd = 1
@@ -341,7 +354,7 @@ sint = Union[Node, int]
VariableOrNum = Union[Variable, NumNode]
render_python: Dict[Type, Callable] = {
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self.val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"),
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" else f"{self.expr}"),
NumNode: lambda self,ops,ctx: f"{self.b}",
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
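
Variable now keeps its binding in a private _val: reading .val on an unbound Variable asserts, and Node gains a .val property that substitutes bound Variables and folds to a NumNode. A small sketch of the resulting behavior (values illustrative):

```python
from tinygrad.shape.symbolic import Variable

v = Variable("start_pos", 0, 128)
# v.val                     # would assert: an unbound Variable has no value
v = v.bind(32)              # bind() stores the value in _val and returns self
assert v.val == 32
assert (v + 4).val == 36    # Node.val substitutes bound Variables and folds to a NumNode
unbound, val = v.unbind()   # -> (Variable("start_pos", 0, 128), 32)
assert val == 32
```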

View File

@@ -253,7 +253,7 @@ class Tensor:
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=tuple(x if x else (0,s) for x,s in zip(arg, self.shape))) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)
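
The shrink signature above now accepts None per axis to mean "keep the whole axis", which is what the KV-cache code relies on. A quick sketch:

```python
from tinygrad.tensor import Tensor

t = Tensor.arange(0, 12).reshape(3, 4)
assert t.shrink((None, (0, 2))).shape == (3, 2)     # None keeps the axis whole
assert t.shrink(((0, 3), (0, 2))).shape == (3, 2)   # equivalent explicit form
```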