Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
llama JIT python runtime speedup (#1633)
* no JIT call in TransformerBlock
* idea
* move 2 reshapes to jitted function; shrink inside jitted too, 6.3ms; remove back reshapes, 5.5ms; isinstance -> __class__, 4.99ms
* think; revert ops_gpu.py; revert symbolic.py too; PYOPENCL_COMPILER_OUTPUT=1
* cleanup
* fix cache shape for conversational model; only reshape if start_pos > 0
* small cleanup
* include var_vals.keys() to st.key
* add comments
* llama small update
* everything jitted again, similar structure to gpt2
* fix typing
* add TODO for in place update cache
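For context, this is the pattern the change builds on: wrap a function in TinyJit, reshape a tensor to a symbolic Variable dimension inside the jitted function, and bind the variable per call through jit_ctx, so a single cached kernel serves every sequence position. A minimal sketch, adapted from the test added in test/test_jit.py below; the import paths assume the tinygrad layout at the time of this commit.

```python
# minimal sketch of the symbolic-shape JIT pattern (adapted from the test added below)
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit                 # path assumed for this era of tinygrad
from tinygrad.shape.symbolic import Variable     # path assumed for this era of tinygrad

vi = Variable("i", 1, 10)                        # symbolic dimension, 1 <= i <= 10

def f(a, jit=False, jit_ctx=None):
  if jit: a = a.reshape(3, vi)                   # symbolic shape only on the jitted path
  return (a + 1).realize()

jf = TinyJit(f)
for i in range(1, 5):
  a = Tensor.rand(3, i)
  out = jf(a, jit=True, jit_ctx={vi: i}).reshape(3, i).numpy()
  np.testing.assert_allclose(out, f(a).numpy(), atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1                    # one cached kernel covers every value of i
```

The llama.py changes below apply the same idea per transformer layer: the KV cache is reshaped to a `pos` Variable before the jitted call, and `jit_ctx={pos: start_pos}` substitutes the real position into the cached kernels.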
@@ -7,7 +7,7 @@ from pathlib import Path
import functools, sys, argparse, json, os
import numpy as np
np.set_printoptions(linewidth=200)
from typing import Optional, Tuple
from typing import Optional, Tuple, Dict

from tinygrad.helpers import Timing, getenv, DEBUG, dtypes
from tinygrad.ops import Device
@@ -67,7 +67,7 @@ class Attention:
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

def __call__(self, x:Tensor, cache_k:Tensor, cache_v:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None) -> Tuple[Tensor, Tensor, Tensor]:
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
@@ -79,13 +79,13 @@ class Attention:
if start_pos == 0:
keys, values = xk, xv
else:
assert cache_k.shape[0] > 0, "no cache"
assert cache_k is not None and cache_v is not None, "no cache"
assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals)})"
assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?"
keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)

cache_k, cache_v = keys, values
keys, values = repeat_kv(keys, self.n_rep).realize(), repeat_kv(values, self.n_rep).realize()
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
attn = Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask).transpose(1, 2).reshape(bsz, seqlen, -1)
return self.wo(attn).realize(), cache_k.realize(), cache_v.realize()
@@ -110,62 +110,70 @@ class TransformerBlock:
self.feed_forward = FeedForward(dim, 4*dim, multiple_of, linear, ffn_dim_multiplier)
self.attention_norm = RMSNorm(dim, norm_eps)
self.ffn_norm = RMSNorm(dim, norm_eps)
self.cache_k, self.cache_v = None, None

self.jitted_attention_norm = TinyJit(lambda x: self.attention_norm(x).realize())
self.jitted_attn = TinyJit(self.attention.__call__)
self.jitted_norm_output = TinyJit(self.norm_output)

def norm_output(self, x:Tensor, output:Tensor) -> Tensor:
h = x + output
return (h + self.feed_forward(self.ffn_norm(h))).realize()

def __call__(self, x:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]):
def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None):
bsz, seqlen, _ = x.shape
do_jit = getenv("JIT") and mask is None
if do_jit:
if getenv("JIT") and mask is None:
assert cache_k is not None and cache_v is not None, "no cache"
pos = Variable("pos", 1, 1024)
self.cache_k = self.cache_k.reshape(self.cache_k.shape[0], pos, self.cache_k.shape[2], self.cache_k.shape[3])
self.cache_v = self.cache_v.reshape(self.cache_v.shape[0], pos, self.cache_v.shape[2], self.cache_v.shape[3])
output, cache_k, cache_v = self.jitted_attn(self.jitted_attention_norm(x), self.cache_k, self.cache_v, start_pos, freqs_cis, mask)
cache_k = cache_k.reshape(cache_k.shape[0], pos, cache_k.shape[2], cache_k.shape[3])
cache_v = cache_v.reshape(cache_v.shape[0], pos, cache_v.shape[2], cache_v.shape[3])
# need this because we don't reshape back to int shape in the jitted path and we don't have the correct var_vars in cache
cache_k.lazydata.st.var_vals[pos] = start_pos
cache_v.lazydata.st.var_vals[pos] = start_pos

# get only the part of freqs_cis that we are using.
freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (pos, pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))
freqs_cis.lazydata.st.var_vals[pos] = start_pos
else:
output, cache_k, cache_v = self.attention(self.attention_norm(x), self.cache_k, self.cache_v, start_pos, freqs_cis, mask)
freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (start_pos, start_pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))

# save the cache. with symbolic shape, cast it back to int shape so we have int shape in cache
self.cache_k = cache_k.reshape(cache_k.shape[0], start_pos+seqlen, cache_k.shape[2], cache_k.shape[3]).realize()
self.cache_v = cache_v.reshape(cache_v.shape[0], start_pos+seqlen, cache_v.shape[2], cache_v.shape[3]).realize()

return self.jitted_norm_output(x, output) if do_jit else self.norm_output(x, output)
output, cache_k, cache_v = self.attention(self.attention_norm(x), cache_k, cache_v, start_pos, freqs_cis, mask, jit_ctx=jit_ctx)
h = x + output
return (h + self.feed_forward(self.ffn_norm(h))).realize(), cache_k.realize(), cache_v.realize()
class Transformer:
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None):
self.layers = [TransformerBlock(dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
self.kv_caches = [(None, None) for _ in range(n_layers)]
self.norm = RMSNorm(dim, norm_eps)
self.tok_embeddings = Embedding(vocab_size, dim)
self.output = linear(dim, vocab_size, bias=False)
self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2))
self.norm_output = lambda x: self.output(self.norm(x))

self.jitted_tok_embeddings = TinyJit(lambda x: self.tok_embeddings(x).realize())
self.jitted_norm_output = TinyJit(lambda x: self.norm_output(x).realize())
self.tok_embeddings_jitted = TinyJit(lambda x: self.tok_embeddings(x).realize())
self.postprocess_jitted = TinyJit(self.postprocess)
self.layers_jitted = [TinyJit(layer.__call__) for layer in self.layers]

def __call__(self, tokens:Tensor, start_pos:int):
def postprocess(self, x, temperature:Optional[float]):
logits = self.output(self.norm(x))
if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
return logits.realize()

def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
_bsz, seqlen = tokens.shape
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
do_jit = getenv("JIT") and mask is None

# get only the part of freqs_cis that we are using.
if do_jit:
if seqlen == 1 and getenv("JIT"):
pos = Variable("pos", 1, 1024)
assert seqlen == 1, "seqlen > 1 not supported for JIT"
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
freqs_cis.lazydata.st.var_vals[pos] = start_pos
h = self.tok_embeddings_jitted(tokens)
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers_jitted, self.kv_caches)):
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=self.freqs_cis, mask=None, jit_ctx={pos: start_pos})
# TODO: move the kv cache into Attention, pre-allocate the cache and instead of cat, update the cache in-place
self.kv_caches[i] = (cache_k, cache_v)
return self.postprocess_jitted(h, temperature)
else:
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (start_pos, start_pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))

h = self.jitted_tok_embeddings(tokens) if do_jit else self.tok_embeddings(tokens)
h = h.sequential([functools.partial(layer, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask) for layer in self.layers])
return self.jitted_norm_output(h) if do_jit else self.norm_output(h)
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize()
h = self.tok_embeddings(tokens)
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers, self.kv_caches)):
# need this reshape back to int shape in conversational mode because jitted and unjitted calls share the same cache
if cache_k is not None and start_pos > 0:
cache_k = cache_k.reshape(cache_k.shape[0], start_pos, cache_k.shape[2], cache_k.shape[3])
cache_v = cache_v.reshape(cache_v.shape[0], start_pos, cache_v.shape[2], cache_v.shape[3])
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=self.freqs_cis, mask=mask)
self.kv_caches[i] = (cache_k, cache_v)
return self.postprocess(h, temperature)

# **** files and arguments ****
@@ -206,15 +214,6 @@ MODEL_PARAMS = {
}

# **** helper functions ****
def sample(logits, temperature):
if temperature < 1e-6:
# so close to 0 we use argmax
return int(logits.argmax().numpy())
else:
probs = (logits / temperature).softmax()
probs = probs.numpy().flatten()
return int(np.random.choice(len(probs), p=probs))

def concat_weights(models):
def convert(name) -> Tensor:
disk_tensors = [model[name] for model in models]
@@ -300,8 +299,9 @@ class LLaMa:
toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt)
start_pos = 0
for i in range(max_length):
logits = self.model(Tensor([toks[start_pos:]]), start_pos).realize()[:, -1, :]
tok = sample(logits, temperature)
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
probs_np = probs.numpy()
tok = int(np.random.choice(len(probs_np), p=probs_np))
start_pos = len(toks)
toks.append(tok)
@@ -439,7 +439,7 @@ After you are done speaking, output [EOS]. You are not Chad.

print(f"Preparing KV cache for chatbot with personality {args.personality}...")
with Timing():
llama.model(Tensor([toks]), 0).realize() # NOTE: output logits are not used
llama.model(Tensor([toks]), 0, args.temperature).realize() # NOTE: output logits are not used
start_pos = len(toks)
else:
# non chat bot mode
@@ -474,10 +474,13 @@ After you are done speaking, output [EOS]. You are not Chad.

if args.timing: print("")
st = GlobalCounters.time_sum_s
with Timing("ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU") if DEBUG else None, enabled=args.timing):
logits = llama.model(Tensor([toks[start_pos:]]), start_pos).realize()[:, -1, :]
with Timing("ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"+
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s") if DEBUG else None, enabled=args.timing):
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
with Timing("sync in ", enabled=args.timing):
tok = sample(logits, args.temperature)
probs_np = probs.numpy()
tok = int(np.random.choice(len(probs_np), p=probs_np))

# use the kv cache
start_pos = len(toks)
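Taken together, the llama.py changes above move temperature scaling and softmax into the jitted postprocess, so the Python-side generation loop only copies a probability vector off the device and samples it with numpy. A condensed sketch of the resulting loop, assembled from the lines above (llama, toks, start_pos and args are set up elsewhere in the file):

```python
# condensed illustration of the new calling convention, not a verbatim excerpt
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
probs_np = probs.numpy()                                 # already softmax'd and flattened over the vocab
tok = int(np.random.choice(len(probs_np), p=probs_np))   # sample the next token on the host
start_pos = len(toks)                                    # the KV cache already covers everything before this
toks.append(tok)
```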
@@ -19,6 +19,19 @@ class TestSymbolicJit(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1

def test_reshape_inside_plus1(self):
vi = Variable("i", 1, 10)
def f(a, jit=False, jit_ctx=None):
if jit: a = a.reshape(3, vi)
return (a+1).realize()
jf = TinyJit(f)
for i in range(1, 5):
a = Tensor.rand(3, i)
symbolic = jf(a, jit=True, jit_ctx={vi: i}).reshape(3, i).numpy()
expected = f(a).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1

def test_add(self):
def f(a, b): return (a+b).realize()
jf = TinyJit(f)
@@ -24,16 +24,22 @@ class TinyJit:
def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
# NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
assert len(input_rawbuffers) != 0, "no inputs to JIT"
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
if self.cnt >= 2:
var_vals = dict(sorted(merge_dicts([arg.lazydata.st.var_vals for arg in args if isinstance(arg, Tensor)]).items(), key=lambda kv: kv[0].key)) # type: ignore
try: var_vals: Dict[Variable, int] = kwargs["jit_ctx"]
except KeyError: var_vals = merge_dicts([arg.lazydata.st.var_vals for arg in args if arg.__class__ is Tensor])
if len(var_vals) > 1: var_vals = dict(sorted(var_vals.items(), key=lambda kv: kv[0].key))
for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
assert input_rawbuffers[input_name][1].views == expected_st.views and input_rawbuffers[input_name][0].dtype == expected_type, f"ShapeTracker.views or type mismatch in JIT, <{input_rawbuffers[input_name][1].views}, {input_rawbuffers[input_name][0].dtype}> != <{expected_st.views}, {expected_type}>"
assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"
# NOTE: if we pass jit_ctx instead of using reshape to update the var_vals, we cannot compare the shapetracker directly
if "jit_ctx" not in kwargs: assert input_rawbuffers[input_name][1].views == expected_st.views, f"ShapeTracker.views mismatch in JIT, {input_rawbuffers[input_name][1].views} != {expected_st.views}"
self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
for prg, pargs, variables in self.jit_cache: # type: Callable, List[Optional[RawBuffer]], Dict[Variable, int]
for v in (var_vals.keys() & variables.keys()): variables[v] = var_vals[v]
for k in variables.keys():
try: variables[k] = var_vals[k]
except KeyError: pass
prg(pargs, variables, jit=True)
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
elif self.cnt == 1:
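The `isinstance -> __class__ 4.99ms` item in the commit message is a pure-Python micro-optimization used in the jit.py change above and the ops_gpu.py change below: an exact `x.__class__ is T` identity check skips isinstance's subclass handling on a path that runs once per generated token. A quick, self-contained way to see the difference (illustrative timing only, not a number from this commit):

```python
import timeit

class Tensorish: pass            # hypothetical stand-in; not tinygrad's Tensor
x = Tensorish()

print(timeit.timeit(lambda: isinstance(x, Tensorish), number=1_000_000))
print(timeit.timeit(lambda: x.__class__ is Tensorish, number=1_000_000))
# the identity check is typically a bit faster, at the cost of rejecting
# subclasses, which is acceptable here because Tensor is not subclassed
```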
@@ -87,9 +87,14 @@ class CLProgram:
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size

def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if isinstance(x, CLBuffer) else np.int32 for x in bufs))
cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs]
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=[x.event for x in bufs if isinstance(x, CLBuffer) and hasattr(x, "event")])
if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if x.__class__ is CLBuffer else np.int32 for x in bufs))
cl_bufs, wait_for = [], []
for x in bufs:
if x.__class__ is CLBuffer:
cl_bufs.append(x._buf)
if hasattr(x, "event"): wait_for.append(x.event)
else: cl_bufs.append(x)
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=wait_for)
if wait:
e.wait()
try:
@@ -139,7 +139,7 @@ class ShapeTracker:
def shape(self) -> Tuple[int, ...]: return self.views[-1].shape # NOTE: real type is Tuple[Union[Node, int], ...] but mypy complains about prod(shape)

@property
def key(self) -> Tuple[View, ...]: return tuple(self.views)
def key(self) -> Tuple[Tuple[View, ...], Tuple[Variable, ...]]: return tuple(self.views), tuple(sorted(self.var_vals.keys()))

# this is the real size (ish)
def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
@@ -283,10 +283,11 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]):
return create_node(ret)

def sym_infer(n:Union[Node,int], var_vals: Dict[Variable, int]) -> int:
if isinstance(n, (int, NumNode)): return int(n)
if isinstance(n, Variable): return var_vals[n]
if isinstance(n, MulNode): return sym_infer(n.a, var_vals) * sym_infer(n.b, var_vals)
if isinstance(n, SumNode): return sum(sym_infer(s, var_vals) for s in n.nodes)
if n.__class__ is int: return n # type: ignore
if n.__class__ is NumNode: return n.b # type: ignore
if n.__class__ is Variable: return var_vals[n] # type: ignore
if n.__class__ is MulNode: return sym_infer(n.a, var_vals) * sym_infer(n.b, var_vals) # type: ignore
if n.__class__ is SumNode: return sum([sym_infer(s, var_vals) for s in n.nodes]) # type: ignore
raise NotImplementedError(n)
@functools.lru_cache(maxsize=None)
def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
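For reference, sym_infer simply evaluates a symbolic shape expression against concrete variable bindings, which is how the asserts in Attention.__call__ above turn a symbolic cache shape back into an integer. A small usage sketch; the tinygrad.shape.symbolic import path is assumed for this era of the codebase:

```python
from tinygrad.shape.symbolic import Variable, sym_infer  # import path assumed

pos = Variable("pos", 1, 1024)
expr = pos * 2 + 1                    # operator overloading builds a MulNode/SumNode tree
assert sym_infer(expr, {pos: 5}) == 11
assert sym_infer(7, {}) == 7          # plain ints pass straight through
```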