Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
examples of new GPT2 and JIT change (#2261)
* var_vals are global
* working with global ish
* better
* fix export model
* fix tests
* better kv cache
* does it run?
* use where for kvmask
* fix excessive var_vals
* fix import
* how does multigpu use this?
* llama kinda work
* faster and simpler
* cleanup
* fix conversation mode
* test cleanups
* fix one more test
* test cleanup

Co-authored-by: George Hotz <geohot@gmail.com>
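Before the file-by-file diff, here is a rough end-to-end sketch of the calling convention this commit introduces: generation binds the integer start_pos to a symbolic Variable before calling the model, so the jitted single-token decode step is captured once and replayed at every position. The sampling loop mirrors the gpt2.py code below; attribute names such as gpt2.model and gpt2.tokenizer are assumptions about the GPT2 wrapper, and weights download on first use.

import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable
from examples.gpt2 import GPT2, MAX_CONTEXT

gpt2 = GPT2.build("gpt2")                         # run with JIT=1 to exercise the jitted decode path
toks = gpt2.tokenizer.encode("Hello, world")      # gpt2.tokenizer is an assumed attribute name
start_pos = 0
for _ in range(20):
  # min is 0 only for the very first (prompt) call, exactly as in gpt2.py below
  spos = Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos)
  probs = gpt2.model(Tensor([toks[start_pos:]]), spos, 0.8)
  probs_np = probs.numpy()
  start_pos = len(toks)
  toks.append(int(np.random.choice(len(probs_np), p=probs_np)))
print(gpt2.tokenizer.decode(toks))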
@@ -300,10 +300,9 @@ cache_saved = CacheCollector.finish() # disable the cache
# there's one ASTRunner in the cache
assert len(cache_saved) == 1
prg, bufs, _ = cache_saved[0]

# print the C Program :)
print(prg.prg)
print(cache_saved[0].prg.prg)

# after some formatting (the compiler doesn't care)
# NOTE: the 2 and 3 are constant folded
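For reference, a small sketch of inspecting the collector output with the new JitItem fields (prg and rawbufs) instead of unpacking a tuple; the kernel here follows the 2+3 constant-add example the doc snippet above refers to, and is illustrative.

from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector

CacheCollector.start()
out = (Tensor([2]) + Tensor([3])).realize()
cache_saved = CacheCollector.finish()
print(cache_saved[0].prg.name)      # kernel name
print(cache_saved[0].prg.prg)       # generated source, as printed above
print(len(cache_saved[0].rawbufs))  # output buffer plus any inputs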
examples/gpt2.py (168 lines changed)
@@ -1,137 +1,96 @@
#!/usr/bin/env python3
# pip3 install tiktoken

import argparse
import numpy as np
from tqdm import trange
np.set_printoptions(linewidth=200)
from typing import Optional, Dict

from tinygrad.helpers import Timing, getenv, dtypes, DEBUG
from tinygrad.helpers import GlobalCounters
import numpy as np
from tinygrad.ops import Device
from typing import Optional
from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding, Linear
from tinygrad.jit import TinyJit
from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.shape.symbolic import Variable
from tinygrad.jit import TinyJit
import tiktoken
from tinygrad.nn.state import torch_load, load_state_dict
from extra.utils import fetch_as_file
from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv

MAX_CONTEXT = 128

class LayerNorm:
def __init__(self, dim, eps=1e-5):
self.eps = eps
self.weight = Tensor.ones(dim)
self.bias = Tensor.zeros(dim)

def __call__(self, x:Tensor):
return (x.layernorm(eps=self.eps)) * self.weight + self.bias

class Attention:
def __init__(self, dim, n_heads, linear=Linear):
self.c_attn = linear(dim, 3*dim, bias=True)
self.c_proj = linear(dim, dim, bias=True)
def __init__(self, dim, n_heads):
self.c_attn = Linear(dim, 3*dim, bias=True)
self.c_proj = Linear(dim, dim, bias=True)
self.n_heads = n_heads
self.dim = dim
self.head_dim = dim // n_heads

def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor]) -> Tensor:
def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
if mask is not None:
# no symbolic shape qkv when consuming prompts
start_pos = start_pos.val

xqkv = self.c_attn(x)
xq, xk, xv = [xqkv.slice([None, None, (i*self.dim, (i+1)*self.dim)]) for i in range(3)]
xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], self.n_heads, self.head_dim) for x in (xq, xk, xv)]
xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim) for i in range(3)]
bsz, seqlen, n_heads, head_dim = xq.shape

bsz, seqlen, _, _ = xq.shape
# kv caching!
if start_pos == 0:
keys, values = xk, xv
else:
assert cache_k, "no cache"
#assert start_pos == cache_k.shape[1] and start_pos == cache_v.shape[1], "cache is wrong shape"
assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?"
keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)
# create kv cache
if not hasattr(self, "cache_k"):
self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)

keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)

# update the cache
self.cache_k.assign(keys.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
self.cache_v.assign(values.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()

# save the cache
cache_k, cache_v = keys.realize(), values.realize()
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
output = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
return self.c_proj(output), cache_k, cache_v
return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1))
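The new Attention keeps a fixed MAX_CONTEXT-sized cache tensor and updates it in place with a shrink/cat/pad/assign sequence, so the decode kernels keep the same shapes and the JIT can replay them. A standalone sketch of that update on a plain tensor (sizes and names here are illustrative, not from the diff):

from tinygrad.tensor import Tensor

MAX_CONTEXT, bsz, n_heads, head_dim = 8, 1, 2, 4
cache = Tensor.zeros(bsz, MAX_CONTEXT, n_heads, head_dim)

def append_to_cache(cache, new, start_pos):
  seqlen = new.shape[1]
  # valid prefix of the cache plus the new entries
  full = cache.shrink((None, (0, start_pos), None, None)).cat(new, dim=1)
  # pad back to MAX_CONTEXT and write it into the same buffer
  cache.assign(full.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
  return full

full = append_to_cache(cache, Tensor.ones(bsz, 1, n_heads, head_dim), 2)  # one decode step with 2 tokens already cached
full = append_to_cache(cache, Tensor.ones(bsz, 1, n_heads, head_dim), 3)
print(full.shape)  # (1, 4, 2, 4)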

class FeedForward:
def __init__(self, dim, hidden_dim, linear=Linear):
self.c_fc = linear(dim, hidden_dim, bias=True)
self.c_proj = linear(hidden_dim, dim, bias=True)
def __init__(self, dim, hidden_dim):
self.c_fc = Linear(dim, hidden_dim, bias=True)
self.c_proj = Linear(hidden_dim, dim, bias=True)

def __call__(self, x:Tensor) -> Tensor:
return self.c_proj(self.c_fc(x).gelu())

class TransformerBlock:
def __init__(self, dim, n_heads, norm_eps, linear=Linear):
self.attn = Attention(dim, n_heads, linear)
self.mlp = FeedForward(dim, 4*dim, linear)
def __init__(self, dim, n_heads, norm_eps):
self.attn = Attention(dim, n_heads)
self.mlp = FeedForward(dim, 4*dim)
self.ln_1 = LayerNorm(dim, norm_eps)
self.ln_2 = LayerNorm(dim, norm_eps)

def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor]):
output, cache_k, cache_v = self.attn(self.ln_1(x), cache_k, cache_v, start_pos, mask)
h = x + output
h = (h + self.mlp(self.ln_2(h)))
return h, cache_k, cache_v
def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]):
h = x + self.attn(self.ln_1(x), start_pos, mask)
return (h + self.mlp(self.ln_2(h)))

class Transformer:
def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_seq_len=1024):
def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
self.wte = Embedding(vocab_size, dim)
self.wpe = Embedding(max_seq_len, dim)
self.h = [TransformerBlock(dim, n_heads, norm_eps, linear) for _ in range(n_layers)]
self.kv_caches = None
self.n_layers = n_layers
self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)]
self.ln_f = LayerNorm(dim, norm_eps)
self.lm_head = linear(dim, vocab_size, bias=False)
self.lm_head = Linear(dim, vocab_size, bias=False)
self.forward_jit = TinyJit(self.forward)

def embed(self, tokens:Tensor, pos:Tensor):
tok_emb = self.wte(tokens)
pos_emb = self.wpe(pos)
h = tok_emb + pos_emb
if getenv("FP16"): h = h.half()
return h

def postprocess(self, x, temperature:Optional[float]):
logits = self.lm_head(self.ln_f(x))
if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten()
return logits

@TinyJit
def run_all_layers(self, tokens:Tensor, pos:Tensor, start_pos:int, temperature:float, **kv_cache):
h = self.embed(tokens, pos)

for i, hi in enumerate(self.h):
h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'] = hi(h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'], start_pos=start_pos, mask=None)

# don't realize until here
for v in kv_cache.values(): v.realize()
return self.postprocess(h, temperature).realize(), kv_cache

def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
_bsz, seqlen = tokens.shape
if seqlen == 1 and start_pos > 0 and getenv("JIT"):
start_pos_var = Variable("start_pos", 1, MAX_CONTEXT).bind(start_pos)
pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos_var, start_pos_var + seqlen)))
for k,v in self.kv_caches.items():
self.kv_caches[k] = v.reshape(v.shape[0], start_pos_var, v.shape[2], v.shape[3])
logit_or_softmax, self.kv_caches = self.run_all_layers(tokens, pos, start_pos=start_pos, temperature=temperature, **self.kv_caches)
return logit_or_softmax
else:
if start_pos == 0:
def forward(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
self.kv_caches = {**{f'cache_k{i}':None for i in range(self.n_layers)}, **{f'cache_v{i}':None for i in range(self.n_layers)}}
_bsz, seqlen = tokens.shape

pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos, start_pos+seqlen)))
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float16 if getenv('FP16') else dtypes.float32).triu(start_pos+1).realize()
h = self.embed(tokens, pos)
for i, hi in enumerate(self.h):
h, self.kv_caches[f'cache_k{i}'], self.kv_caches[f'cache_v{i}'] = hi(h, self.kv_caches[f'cache_k{i}'], self.kv_caches[f'cache_v{i}'], start_pos=start_pos, mask=mask)
for v in self.kv_caches.values(): v.realize()
return self.postprocess(h, temperature).realize()
tok_emb = self.wte(tokens)
pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos+seqlen))))
h = tok_emb + pos_emb

# **** files and arguments ****
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None
for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask)

logits = self.lm_head(self.ln_f(h))
return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()

# TODO: fix empty token
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
return (self.forward_jit if tokens.shape[0:2] == (1,1) and getenv("JIT") else self.forward)(tokens, start_pos, temperature)
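__call__ above routes single-token (1,1) inputs through the TinyJit-wrapped forward and everything else (the prompt) through the plain one. A minimal sketch of how that plays out, using a tiny randomly initialized config in the spirit of the unit tests rather than real weights; it assumes the Transformer and MAX_CONTEXT definitions above and JIT=1 in the environment:

from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable

model = Transformer(dim=64, n_heads=4, n_layers=2, norm_eps=1e-5, vocab_size=1000)
prompt = Tensor([[1, 2, 3]])   # seqlen > 1: runs the unjitted forward and fills the kv caches
model(prompt, Variable("start_pos", 0, MAX_CONTEXT).bind(0), 0.8)
nxt = Tensor([[4]])            # (1, 1): dispatched to forward_jit with a bound symbolic start_pos
model(nxt, Variable("start_pos", 1, MAX_CONTEXT).bind(3), 0.8)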

VOCAB_SIZE = 50257
MODEL_PARAMS = {
@@ -141,19 +100,13 @@ MODEL_PARAMS = {
'gpt2-xl': dict(n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 1558M params
}

def get_url(model_size): return f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'

class GPT2:
@staticmethod
def build(model_size="gpt2"):
import tiktoken
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from extra.utils import fetch_as_file
tokenizer = tiktoken.get_encoding("gpt2")

params = MODEL_PARAMS[model_size]
model = Transformer(**params)
weights = torch_load(fetch_as_file(get_url(model_size)))
model = Transformer(**MODEL_PARAMS[model_size])
weights = torch_load(fetch_as_file(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
# special treatment for the Conv1D weights we need to transpose
transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
for k in weights.keys():
@@ -162,9 +115,6 @@ class GPT2:
# lm head and wte are tied
weights['lm_head.weight'] = Tensor(weights['wte.weight'].numpy())

if getenv("FP16"):
for k, v in weights.items():
weights[k] = v.cpu().half().realize()
load_state_dict(model, weights)
return GPT2(model, tokenizer)

@@ -183,7 +133,7 @@ class GPT2:
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
probs = self.model(Tensor([toks[start_pos:]]), start_pos, temperature)
probs = self.model(Tensor([toks[start_pos:]]), Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature)
probs_np = probs.numpy()
tok = int(np.random.choice(len(probs_np), p=probs_np))
start_pos = len(toks)

examples/llama.py

@@ -18,10 +18,11 @@ from tinygrad.helpers import GlobalCounters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.shape.symbolic import Variable, sym_infer

MAX_CONTEXT = 1024
JIT = getenv("JIT", 0 if CI else int(Device.DEFAULT in JIT_SUPPORTED_DEVICE))

# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2)
@@ -69,27 +70,34 @@ class Attention:
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None) -> Tuple[Tensor, Tensor, Tensor]:
bsz, seqlen, _ = x.shape
def __call__(self, x:Tensor, start_pos:Variable, freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
if mask is not None:
# no symbolic shape qkv when consuming prompts
start_pos = start_pos.val

xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
bsz, seqlen, n_heads, head_dim = xq.shape

# kv caching!
if start_pos == 0:
keys, values = xk, xv
else:
assert cache_k is not None and cache_v is not None, "no cache"
assert start_pos == (cache_k.shape[1].val if isinstance(cache_k.shape[1], Variable) else cache_k.shape[1]) == (cache_v.shape[1].val if isinstance(cache_v.shape[1], Variable) else cache_v.shape[1]), f"cache has wrong shape, {start_pos=}, {cache_k.shape[1]=}, {cache_v.shape[1]=}"
assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?"
keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)
# create kv cache
if not hasattr(self, "cache_k"):
self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)

cache_k, cache_v = keys, values
keys, values = repeat_kv(keys, self.n_rep).realize(), repeat_kv(values, self.n_rep).realize()
attn = Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask).transpose(1, 2).reshape(bsz, seqlen, -1)
return self.wo(attn).realize(), cache_k.realize(), cache_v.realize()
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)

# update the cache
self.cache_k.assign(keys.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
self.cache_v.assign(values.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()

keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)

xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
return self.wo(attn)

class FeedForward:
def __init__(self, dim, hidden_dim, multiple_of, linear=Linear, ffn_dim_multiplier=None):
@@ -113,60 +121,31 @@ class TransformerBlock:
self.attention_norm = RMSNorm(dim, norm_eps)
self.ffn_norm = RMSNorm(dim, norm_eps)

def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None):
bsz, seqlen, _ = x.shape
if JIT and mask is None:
assert cache_k is not None and cache_v is not None, "no cache"
pos = Variable("pos", 1, 1024).bind(start_pos)
cache_k = cache_k.reshape(cache_k.shape[0], pos, cache_k.shape[2], cache_k.shape[3])
cache_v = cache_v.reshape(cache_v.shape[0], pos, cache_v.shape[2], cache_v.shape[3])

output, cache_k, cache_v = self.attention(self.attention_norm(x), cache_k, cache_v, start_pos, freqs_cis, mask, jit_ctx=jit_ctx)
h = x + output
return (h + self.feed_forward(self.ffn_norm(h))).realize(), cache_k.realize(), cache_v.realize()
def __call__(self, x:Tensor, start_pos:Variable, freqs_cis:Tensor, mask:Optional[Tensor]):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
return (h + self.feed_forward(self.ffn_norm(h))).realize()

class Transformer:
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None, rope_theta=10000):
self.layers = [TransformerBlock(dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
self.kv_caches = [(None, None) for _ in range(n_layers)]
self.norm = RMSNorm(dim, norm_eps)
self.tok_embeddings = Embedding(vocab_size, dim)
self.output = linear(dim, vocab_size, bias=False)
self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_seq_len * 2, rope_theta)
self.norm_output = lambda x: self.output(self.norm(x))
self.forward_jit = TinyJit(self.forward)

self.tok_embeddings_jitted = TinyJit(lambda x: self.tok_embeddings(x).realize())
self.postprocess_jitted = TinyJit(self.postprocess)
self.layers_jitted = [TinyJit(layer.__call__) for layer in self.layers]

def postprocess(self, x, temperature:Optional[float]):
logits = self.output(self.norm(x))
if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
return logits.realize()

def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
def forward(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
_bsz, seqlen = tokens.shape
if seqlen == 1 and start_pos > 0 and JIT:
pos = Variable("pos", 1, 1024).bind(start_pos)
# get only the part of freqs_cis that we are using.
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
h = self.tok_embeddings_jitted(tokens)
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers_jitted, self.kv_caches)):
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=None, jit_ctx={pos.unbind()[0]: start_pos})
self.kv_caches[i] = (cache_k, cache_v)
return self.postprocess_jitted(h, temperature)
else:
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (start_pos, start_pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize()
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos.val+1).realize() if seqlen > 1 else None

h = self.tok_embeddings(tokens)
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers, self.kv_caches)):
# need this reshape back to int shape in conversational mode because jitted and unjitted calls share the same cache
if cache_k is not None and start_pos > 0:
cache_k = cache_k.reshape(cache_k.shape[0], start_pos, cache_k.shape[2], cache_k.shape[3])
cache_v = cache_v.reshape(cache_v.shape[0], start_pos, cache_v.shape[2], cache_v.shape[3])
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask)
self.kv_caches[i] = (cache_k, cache_v)
return self.postprocess(h, temperature)
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
logits = self.output(self.norm(h))
return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()

def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
return (self.forward_jit if tokens.shape[0:2] == (1,1) and getenv("JIT") else self.forward)(tokens, start_pos, temperature)

# **** files and arguments ****
MODEL_PARAMS = {
@@ -522,7 +501,7 @@ After you are done speaking, output [EOS]. You are not Chad.

print(f"Preparing KV cache for chatbot with personality {args.personality}...")
with Timing():
llama.model(Tensor([toks]), 0, args.temperature).realize() # NOTE: output logits are not used
llama.model(Tensor([toks]), Variable("start_pos", 0, MAX_CONTEXT).bind(0), args.temperature).realize() # NOTE: output logits are not used
start_pos = len(toks)
else:
# non chat bot mode
@@ -561,7 +540,7 @@ After you are done speaking, output [EOS]. You are not Chad.
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
probs = llama.model(Tensor([toks[start_pos:]]), Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), args.temperature).realize()
probs_np = probs.numpy()
tok = int(np.random.choice(len(probs_np), p=probs_np))
@@ -9,11 +9,11 @@ EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU", "METAL"]

def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
for fxn,args,var_vals in run.jit_cache:
assert not var_vals, "symbolic shape is not supported"
for ji in run.jit_cache:
fxn = ji.prg
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
cargs = []
for i,arg in enumerate(args):
for i,arg in enumerate(ji.rawbufs):
key = id(arg)
if key not in bufs:
if key in special_names:
@@ -43,7 +43,7 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
# hack to put the inputs back
for (j,i),idx in run.input_replace.items():
realized_input = args[idx[0]].lazydata.realized
run.jit_cache[j][1][i] = realized_input
run.jit_cache[j].rawbufs[i] = realized_input
special_names[id(realized_input)] = f'input{idx[0]}'

# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)

@@ -63,12 +63,13 @@ def compile(dat, output_fn):

# pull out inputs and put them in the jit cache
input_rawbuffers = {k:inputs[k].lazydata.realized for k in inputs.keys()}
for (j,i),(idx,_,_) in model_exec.input_replace.items(): model_exec.jit_cache[j][1][i] = input_rawbuffers[idx]
for (j,i),(idx,_,_) in model_exec.input_replace.items(): model_exec.jit_cache[j].rawbufs[i] = input_rawbuffers[idx]

# transform to CL.CACHE
used_ops = 0
cl_cache = []
for prg,args,_ in model_exec.jit_cache:
for ji in model_exec.jit_cache:
prg = ji.prg
# pass these to thneed
setattr(prg.clprg, 'op_estimate', prg.op_estimate)
setattr(prg.clprg, 'prg', prg.prg)
@@ -79,7 +80,7 @@ def compile(dat, output_fn):

global_size = prg.global_size + [1]*(3-len(prg.global_size))
local_size = prg.local_size + [1]*(3-len(prg.local_size))
cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x._buf for x in args]]))
cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x._buf for x in ji.rawbufs]]))
used_ops += prg.op_estimate

from extra.thneed import Thneed
test/external/external_test_opt.py (vendored, 9 lines changed)
@@ -18,14 +18,14 @@ from tinygrad.lazy import PUSH_PERMUTES
from tinygrad.jit import CacheCollector

class CLCache:
def __init__(self, allowed=None, strict=False, preclear=True): self.allowed, self.strict, self.preclear = allowed, strict, preclear
def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None): self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
def __enter__(self):
if self.preclear:
gc.collect()
for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
x.realize()
GlobalCounters.reset()
CacheCollector.start()
CacheCollector.start(self.var_vals)
print("cache: entering")
def __exit__(self, type, value, traceback):
cache = CacheCollector.finish()
@@ -85,11 +85,12 @@ class TestInferenceMinKernels(unittest.TestCase):

def test_llama(self):
from examples.llama import Transformer
from tinygrad.shape.symbolic import Variable
args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
model = Transformer(**args_tiny)
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
with CLCache(85):
model(Tensor([[1,2,3,4]]), 0).realize()
with CLCache(98, var_vals={Variable("start_pos", 0, 1024): 0}):
model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 1024).bind(0)).realize()

@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
class TestOptBinOp(unittest.TestCase):

@@ -57,7 +57,7 @@ class TestRealWorld(unittest.TestCase):
@TinyJit
def test(t): return model(t, 0).realize()
# NOTE: only test one pass, not testing the dynamic shape autoregressive part
helper_test("test_llama", lambda: (Tensor([[1,]]),), test, 0.22 if CI else 13.5, 126 if CI else 486, all_jitted=True)
helper_test("test_llama", lambda: (Tensor([[1,]]),), test, 0.22 if CI else 13.5, 137 if CI else 521, all_jitted=True)

@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and (Device.DEFAULT not in ["LLVM"] or not CI), "needs JIT, too long on CI LLVM")
def test_gpt2(self):
@@ -68,7 +68,7 @@ class TestRealWorld(unittest.TestCase):
derandomize_model(model)
@TinyJit
def test(t): return model(t, 0).realize()
helper_test("test_gpt2", lambda: (Tensor([[1,]]),), test, 0.21 if CI else 0.9, 129 if CI else 369, all_jitted=True)
helper_test("test_gpt2", lambda: (Tensor([[1,]]),), test, 0.21 if CI else 0.9, 140 if CI else 396, all_jitted=True)

@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and (Device.DEFAULT not in ["LLVM", "CLANG"] or not CI), "needs JIT, too long on CI LLVM and CLANG")
def test_train_cifar(self):
@@ -65,9 +65,9 @@ class TestLazyBuffer(unittest.TestCase):
de.realize()
cache = CacheCollector.finish()
assert len(cache) == 3
assert cache[0][0].name.startswith("r_") # Reduce should not merged 2 times.
assert cache[1][0].name.startswith("E_")
assert cache[2][0].name.startswith("E_")
assert cache[0].prg.name.startswith("r_") # Reduce should not merged 2 times.
assert cache[1].prg.name.startswith("E_")
assert cache[2].prg.name.startswith("E_")

if __name__ == "__main__":
unittest.main()

@@ -17,7 +17,7 @@ class TestLinearizer(unittest.TestCase):
np_a, np_b = a.numpy(), b.numpy()
CacheCollector.start()
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))).realize()
rawbufs = CacheCollector.finish()[0][1]
rawbufs = CacheCollector.finish()[0].rawbufs
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.realized, b.lazydata.realized}
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)

@@ -19,19 +19,6 @@ class TestSymbolicJit(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1

def test_reshape_inside_plus1(self):
def f(a, jit=False, jit_ctx=None):
if jit: a = a.reshape(3, Variable("i", 1, 10).bind(a.shape[1]))
return (a+1).realize()
jf = TinyJit(f)
for i in range(1, 5):
vi = Variable("i", 1, 10)
a = Tensor.rand(3, i)
symbolic = jf(a, jit=True, jit_ctx={vi: i}).reshape(3, i).numpy()
expected = f(a).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1

def test_add(self):
def f(a, b): return (a+b).realize()
jf = TinyJit(f)
@@ -1,75 +1,82 @@
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
from collections import defaultdict
import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts
from tinygrad.ops import RawBuffer, Device
from tinygrad.ops import RawBuffer, Device, ASTRunner
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
from dataclasses import dataclass

JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]

@dataclass(frozen=True)
class JitItem:
prg: ASTRunner
rawbufs: List[Optional[RawBuffer]]
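jit_cache entries change from (prg, rawbufs, var_vals) tuples to these JitItem records; the per-kernel var_vals dict disappears because variable values are now global to the call. A small sketch of walking a captured cache through the new fields (the jitted function is illustrative and needs a JIT-supported device):

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def step(x): return (x * 2).realize()

for _ in range(2): step(Tensor.rand(4))  # the second call captures the kernels
for ji in step.jit_cache:
  print(ji.prg.name, len(ji.rawbufs))    # ASTRunner and its buffer list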

class TinyJit:
def __init__(self, fxn:Callable):
self.fxn: Callable = fxn
self.cnt: int = 0
self.jit_cache: List[Tuple[Any, List[Optional[RawBuffer]], Dict[Variable, int]]] = []
self.jit_cache: List[JitItem] = []
self.ret: Any = None
self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], ShapeTracker, DType]] = {} # (kernel_number, buffer_number) -> (input_name, expected_shapetracker, expected_type)
self.updatable_entries: Dict[int, List[int]] = defaultdict(list) # (kernel_number) -> list(argument id). These are buffers from input + variables.

# add support for instance methods
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)

def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT.split(":")[0] not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
# NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}

# all inputs are realized
input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}

# get rawbuffers
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {k:(cast(RawBuffer, v.lazydata.realized), v.lazydata.st) for k,v in input_tensors.items()}
assert len(input_rawbuffers) != 0, "no inputs to JIT"
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"

# get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])

if self.cnt >= 2:
try: var_vals: Dict[Variable, int] = kwargs["jit_ctx"]
except KeyError: var_vals = merge_dicts([arg.lazydata.st.var_vals for arg in args if arg.__class__ is Tensor])
if len(var_vals) > 1: var_vals = dict(sorted(var_vals.items(), key=lambda kv: kv[0].key))
# check validity and assign the inputs
for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"
# NOTE: if we pass jit_ctx instead of using reshape to update the var_vals, we cannot compare the shapetracker directly
if "jit_ctx" not in kwargs: assert input_rawbuffers[input_name][1].unbind() == expected_st, f"ShapeTracker mismatch in JIT, {input_rawbuffers[input_name][1].unbind()} != {expected_st}"
self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
for j in self.updatable_entries.keys():
for k in self.jit_cache[j][2].keys():
try: self.jit_cache[j][2][k] = var_vals[k]
except KeyError: pass
for prg, pargs, variables in self.jit_cache: prg(pargs, variables, jit=True)
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
assert input_rawbuffers[input_name][1].unbind() == expected_st, f"ShapeTracker mismatch in JIT, {input_rawbuffers[input_name][1].unbind()} != {expected_st}"
self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name][0]
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
elif self.cnt == 1:
CacheCollector.start()
CacheCollector.start(var_vals)
self.ret = self.fxn(*args, **kwargs)
self.jit_cache = CacheCollector.finish()
assert len(self.jit_cache) != 0, "didn't JIT anything!"
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")

# get the inputs for replacement
for j_,cache in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]]
for i,a in enumerate(cache[1]):
for j,ji in enumerate(self.jit_cache):
for i,a in enumerate(ji.rawbufs):
if a in [v[0] for v in input_rawbuffers.values()]:
self.input_replace[(j_,i)] = [(k, v[1].unbind(), v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
self.updatable_entries[j_].append(i)
for i in range(len(cache[2])): self.updatable_entries[j_].append(len(cache[1])+i)
self.input_replace[(j,i)] = [(k, v[1].unbind(), v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
elif self.cnt == 0:
self.ret = self.fxn(*args, **kwargs)

# clear the inputs
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
self.cnt += 1
return self.ret

class _CacheCollector:
def __init__(self): self.cache: Optional[List[Tuple[Callable, List[Any], Dict[Any,Any]]]] = None
def start(self): self.cache = []
def __init__(self):
self.cache: Optional[List[JitItem]] = None
def start(self, var_vals:Optional[Dict[Variable, int]]=None):
self.cache = []
self.var_vals = var_vals if var_vals is not None else {}
def add(self, prg, rawbufs, var_vals):
if self.cache is None: return
self.cache.append((prg, rawbufs, var_vals))
def finish(self):
for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
self.cache.append(JitItem(prg, rawbufs))
def finish(self) -> List[JitItem]:
if self.cache is None: return []
ret = self.cache
self.cache = None
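With var_vals global, a bound Variable can be passed straight through a jitted function next to the tensors, and the collector and replay machinery pick its value up from the arguments. A hedged sketch of that pattern (the function, names, and sizes are illustrative):

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.shape.symbolic import Variable

@TinyJit
def window(x, start_pos):
  # symbolic shrink: the kernel is captured once, only the bound value changes per call
  return x.shrink((None, (start_pos, start_pos + 1))).realize()

big = Tensor.rand(1, 16)
for p in range(1, 5):
  out = window(big, Variable("start_pos", 1, 15).bind(p))  # out now holds x[:, p:p+1]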

@@ -188,6 +188,7 @@ class ASTRunner:
def __init__(self, name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
if DEBUG >= 4: print(prg)
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
self.vars:List[Variable] = []

def build(self, compiler, runtime):
self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
@@ -258,14 +259,13 @@ class Compiled:
# all the rawbuffers
rawbuffers = [output.realized] + [x.realized for x in inputs]

# extract real vars used in ast
from tinygrad.lazy import vars_from_ast
ast_vars = vars_from_ast(ast)
assert all(v.val is None for v in ast_vars), f"ast contains bound Variable {ast_vars}"

# compilation time
def get_program():
if DEBUG >= 3:
from tinygrad.graph import print_tree
print_tree(ast)
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.lazy import vars_from_ast
k = Linearizer(ast, self.linearizer_opts)
assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
if not NOOPT:
@@ -286,7 +286,11 @@ class Compiled:
k = timed[0][1]
else:
k.required_optimizations()
return self.to_program(k)
prg = self.to_program(k)
# extract real vars used in ast
prg.vars = vars_from_ast(ast)
assert all(v._val is None for v in prg.vars), f"ast contains bound Variable {prg.vars}"
return prg

if getenv("ENABLE_METHOD_CACHE", 1):
if ast not in self.method_cache: self.method_cache[ast] = get_program()
@@ -296,5 +300,5 @@ class Compiled:

if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)

prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in ast_vars})
prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in prg.vars})
return output.realized
@@ -1,7 +1,7 @@
from typing import List, cast, Dict, Callable
import numpy as np
from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, Device, BufferOps
from tinygrad.graph import log_schedule_item, print_tree
from tinygrad.graph import log_schedule_item
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import DEBUG, prod, all_int, getenv, IMAGE

@@ -18,7 +18,6 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
si = schedule.pop(0)
if not disable_logging: log_schedule_item(si)
assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"
if DEBUG >= 3: print_tree(si.ast)
if si.ast.op in LoadOps:
# confirm the LoadOps are contiguous and in order
for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
@@ -35,6 +35,12 @@ class Node:
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: raise RuntimeError(self.__class__.__name__)
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None

@property
def val(self):
ret = self.substitute({x:NumNode(x.val) for x in self.vars()})
assert isinstance(ret, NumNode), f"val must be NumNode, it's {ret}"
return ret.b

@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")
@functools.cached_property
@@ -150,10 +156,14 @@ class Variable(Node):

def __init__(self, expr:Optional[str], nmin:int, nmax:int):
self.expr, self.min, self.max = expr, nmin, nmax
self.val:Optional[int] = None
self._val: Optional[int] = None
@property
def val(self):
assert self._val is not None, f"Variable isn't bound, can't access val of {self}"
return self._val
def bind(self, val):
assert self.val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
self.val = val
assert self._val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
self._val = val
return self
def unbind(self) -> Tuple[Variable, int]:
assert self.val is not None, f"cannot unbind {self}"
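val is now a guarded view over _val, so touching an unbound Variable fails loudly instead of silently reading None. A short sketch of the bind/val/unbind contract:

from tinygrad.shape.symbolic import Variable

v = Variable("start_pos", 0, 128)
# v.val here would raise AssertionError: the variable isn't bound yet
bound = v.bind(42)          # bind returns the same, now bound, variable
print(bound.val)            # 42
var, val = bound.unbind()   # fresh unbound variable plus the value, as the JIT uses them
print(var.expr, val)        # start_pos 42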

@@ -261,8 +271,11 @@ class SumNode(RedNode):
if x.b%b == 0: fully_divided.append(x//b)
else:
rest.append(x)
if isinstance(x.b, int):
_gcd = gcd(_gcd, x.b)
if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
else:
_gcd = 1
else:
rest.append(x)
_gcd = 1
@@ -341,7 +354,7 @@ sint = Union[Node, int]
VariableOrNum = Union[Variable, NumNode]

render_python: Dict[Type, Callable] = {
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self.val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"),
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" else f"{self.expr}"),
NumNode: lambda self,ops,ctx: f"{self.b}",
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
@@ -253,7 +253,7 @@ class Tensor:
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=tuple(x if x else (0,s) for x,s in zip(arg, self.shape))) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)
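shrink now also accepts None for a dimension to leave it untouched, which is what the kv-cache and allpos slicing above rely on. A tiny sketch:

from tinygrad.tensor import Tensor

t = Tensor.arange(0, 12).reshape(3, 4)
print(t.shrink((None, (0, 2))).shape)  # (3, 2): first dim kept whole
print(t.shrink(((1, 3), None)).shape)  # (2, 4)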