llama JIT python runtime speedup (#1633)

* no JIT call in TransformerBlock

* idea

* move 2 reshapes to jitted function

  shrink inside jitted too, 6.3ms
  remove back reshapes, 5.5ms
  isinstance -> __class__ 4.99ms

* think

  revert ops_gpu.py
  revert symbolic.py too
  PYOPENCL_COMPILER_OUTPUT=1

* cleanup

* fix cache shape for conversational model

only reshape if start_pos > 0

* small cleanup

* include var_vals.keys() to st.key

* add comments

* llama small update

* everything jitted again, similar structure to gpt2

* fix typing

* add TODO for in place update cache
Author: chenyu
Date: 2023-08-30 07:51:05 -07:00 (committed by GitHub)
Parent: 1682e9a38a
Commit: ac183568be
6 changed files with 94 additions and 66 deletions
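
The mechanism everything below leans on is TinyJit replay over a symbolic dimension whose concrete value is handed in through a new jit_ctx argument, so the kv cache can grow by one position per step without retracing. A minimal sketch of the pattern, closely modeled on the test this commit adds (names like plus1 are illustrative; imports assume the tinygrad layout at this revision):

    from tinygrad.tensor import Tensor
    from tinygrad.jit import TinyJit
    from tinygrad.shape.symbolic import Variable

    vi = Variable("i", 1, 10)          # symbolic dimension, like "pos" in llama.py

    @TinyJit
    def plus1(a, jit_ctx=None):
      a = a.reshape(3, vi)             # reshape to the symbolic shape inside the jitted function
      return (a + 1).realize()

    for i in range(1, 5):              # one captured kernel serves every i once the JIT is warm
      out = plus1(Tensor.rand(3, i), jit_ctx={vi: i})
      print(out.reshape(3, i).numpy())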


@@ -7,7 +7,7 @@ from pathlib import Path
import functools, sys, argparse, json, os
import numpy as np
np.set_printoptions(linewidth=200)
from typing import Optional, Tuple
from typing import Optional, Tuple, Dict
from tinygrad.helpers import Timing, getenv, DEBUG, dtypes
from tinygrad.ops import Device
@@ -67,7 +67,7 @@ class Attention:
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
def __call__(self, x:Tensor, cache_k:Tensor, cache_v:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None) -> Tuple[Tensor, Tensor, Tensor]:
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
@@ -79,13 +79,13 @@ class Attention:
if start_pos == 0:
keys, values = xk, xv
else:
assert cache_k.shape[0] > 0, "no cache"
assert cache_k is not None and cache_v is not None, "no cache"
assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals)})"
assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?"
keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)
cache_k, cache_v = keys, values
keys, values = repeat_kv(keys, self.n_rep).realize(), repeat_kv(values, self.n_rep).realize()
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
attn = Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask).transpose(1, 2).reshape(bsz, seqlen, -1)
return self.wo(attn).realize(), cache_k.realize(), cache_v.realize()
@@ -110,62 +110,70 @@ class TransformerBlock:
self.feed_forward = FeedForward(dim, 4*dim, multiple_of, linear, ffn_dim_multiplier)
self.attention_norm = RMSNorm(dim, norm_eps)
self.ffn_norm = RMSNorm(dim, norm_eps)
self.cache_k, self.cache_v = None, None
self.jitted_attention_norm = TinyJit(lambda x: self.attention_norm(x).realize())
self.jitted_attn = TinyJit(self.attention.__call__)
self.jitted_norm_output = TinyJit(self.norm_output)
def norm_output(self, x:Tensor, output:Tensor) -> Tensor:
h = x + output
return (h + self.feed_forward(self.ffn_norm(h))).realize()
def __call__(self, x:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]):
def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None):
bsz, seqlen, _ = x.shape
do_jit = getenv("JIT") and mask is None
if do_jit:
if getenv("JIT") and mask is None:
assert cache_k is not None and cache_v is not None, "no cache"
pos = Variable("pos", 1, 1024)
self.cache_k = self.cache_k.reshape(self.cache_k.shape[0], pos, self.cache_k.shape[2], self.cache_k.shape[3])
self.cache_v = self.cache_v.reshape(self.cache_v.shape[0], pos, self.cache_v.shape[2], self.cache_v.shape[3])
output, cache_k, cache_v = self.jitted_attn(self.jitted_attention_norm(x), self.cache_k, self.cache_v, start_pos, freqs_cis, mask)
cache_k = cache_k.reshape(cache_k.shape[0], pos, cache_k.shape[2], cache_k.shape[3])
cache_v = cache_v.reshape(cache_v.shape[0], pos, cache_v.shape[2], cache_v.shape[3])
# need this because we don't reshape back to int shape in the jitted path and we don't have the correct var_vars in cache
cache_k.lazydata.st.var_vals[pos] = start_pos
cache_v.lazydata.st.var_vals[pos] = start_pos
# get only the part of freqs_cis that we are using.
freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (pos, pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))
freqs_cis.lazydata.st.var_vals[pos] = start_pos
else:
output, cache_k, cache_v = self.attention(self.attention_norm(x), self.cache_k, self.cache_v, start_pos, freqs_cis, mask)
freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (start_pos, start_pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))
# save the cache. with symbolic shape, cast it back to int shape so we have int shape in cache
self.cache_k = cache_k.reshape(cache_k.shape[0], start_pos+seqlen, cache_k.shape[2], cache_k.shape[3]).realize()
self.cache_v = cache_v.reshape(cache_v.shape[0], start_pos+seqlen, cache_v.shape[2], cache_v.shape[3]).realize()
return self.jitted_norm_output(x, output) if do_jit else self.norm_output(x, output)
output, cache_k, cache_v = self.attention(self.attention_norm(x), cache_k, cache_v, start_pos, freqs_cis, mask, jit_ctx=jit_ctx)
h = x + output
return (h + self.feed_forward(self.ffn_norm(h))).realize(), cache_k.realize(), cache_v.realize()
class Transformer:
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None):
self.layers = [TransformerBlock(dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
self.kv_caches = [(None, None) for _ in range(n_layers)]
self.norm = RMSNorm(dim, norm_eps)
self.tok_embeddings = Embedding(vocab_size, dim)
self.output = linear(dim, vocab_size, bias=False)
self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2))
self.norm_output = lambda x: self.output(self.norm(x))
self.jitted_tok_embeddings = TinyJit(lambda x: self.tok_embeddings(x).realize())
self.jitted_norm_output = TinyJit(lambda x: self.norm_output(x).realize())
self.tok_embeddings_jitted = TinyJit(lambda x: self.tok_embeddings(x).realize())
self.postprocess_jitted = TinyJit(self.postprocess)
self.layers_jitted = [TinyJit(layer.__call__) for layer in self.layers]
def __call__(self, tokens:Tensor, start_pos:int):
def postprocess(self, x, temperature:Optional[float]):
logits = self.output(self.norm(x))
if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
return logits.realize()
def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
_bsz, seqlen = tokens.shape
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
do_jit = getenv("JIT") and mask is None
# get only the part of freqs_cis that we are using.
if do_jit:
if seqlen == 1 and getenv("JIT"):
pos = Variable("pos", 1, 1024)
assert seqlen == 1, "seqlen > 1 not supported for JIT"
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
freqs_cis.lazydata.st.var_vals[pos] = start_pos
h = self.tok_embeddings_jitted(tokens)
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers_jitted, self.kv_caches)):
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=self.freqs_cis, mask=None, jit_ctx={pos: start_pos})
# TODO: move the kv cache into Attention, pre-allocate the cache and instead of cat, update the cache in-place
self.kv_caches[i] = (cache_k, cache_v)
return self.postprocess_jitted(h, temperature)
else:
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (start_pos, start_pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
h = self.jitted_tok_embeddings(tokens) if do_jit else self.tok_embeddings(tokens)
h = h.sequential([functools.partial(layer, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask) for layer in self.layers])
return self.jitted_norm_output(h) if do_jit else self.norm_output(h)
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize()
h = self.tok_embeddings(tokens)
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers, self.kv_caches)):
# need this reshape back to int shape in conversational mode because jitted and unjitted calls share the same cache
if cache_k is not None and start_pos > 0:
cache_k = cache_k.reshape(cache_k.shape[0], start_pos, cache_k.shape[2], cache_k.shape[3])
cache_v = cache_v.reshape(cache_v.shape[0], start_pos, cache_v.shape[2], cache_v.shape[3])
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=self.freqs_cis, mask=mask)
self.kv_caches[i] = (cache_k, cache_v)
return self.postprocess(h, temperature)
# **** files and arguments ****
@@ -206,15 +214,6 @@ MODEL_PARAMS = {
}
# **** helper functions ****
def sample(logits, temperature):
if temperature < 1e-6:
# so close to 0 we use argmax
return int(logits.argmax().numpy())
else:
probs = (logits / temperature).softmax()
probs = probs.numpy().flatten()
return int(np.random.choice(len(probs), p=probs))
def concat_weights(models):
def convert(name) -> Tensor:
disk_tensors = [model[name] for model in models]
@@ -300,8 +299,9 @@ class LLaMa:
toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt)
start_pos = 0
for i in range(max_length):
logits = self.model(Tensor([toks[start_pos:]]), start_pos).realize()[:, -1, :]
tok = sample(logits, temperature)
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
probs_np = probs.numpy()
tok = int(np.random.choice(len(probs_np), p=probs_np))
start_pos = len(toks)
toks.append(tok)
@@ -439,7 +439,7 @@ After you are done speaking, output [EOS]. You are not Chad.
print(f"Preparing KV cache for chatbot with personality {args.personality}...")
with Timing():
llama.model(Tensor([toks]), 0).realize() # NOTE: output logits are not used
llama.model(Tensor([toks]), 0, args.temperature).realize() # NOTE: output logits are not used
start_pos = len(toks)
else:
# non chat bot mode
@@ -474,10 +474,13 @@ After you are done speaking, output [EOS]. You are not Chad.
if args.timing: print("")
st = GlobalCounters.time_sum_s
with Timing("ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU") if DEBUG else None, enabled=args.timing):
logits = llama.model(Tensor([toks[start_pos:]]), start_pos).realize()[:, -1, :]
with Timing("ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"+
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s") if DEBUG else None, enabled=args.timing):
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
with Timing("sync in ", enabled=args.timing):
tok = sample(logits, args.temperature)
probs_np = probs.numpy()
tok = int(np.random.choice(len(probs_np), p=probs_np))
# use the kv cache
start_pos = len(toks)
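
Sampling moves too: the sample() helper is gone, temperature scaling and the softmax now run inside the model in postprocess, and the host just draws from the returned distribution with np.random.choice. One nuance: the explicit argmax shortcut for temperature < 1e-6 disappears, and a zero temperature now divides the logits by 1e-10, which saturates the softmax to essentially the same result. A tiny numeric check of that claim (plain NumPy, not repo code):

    import numpy as np

    logits = np.array([2.0, 1.0, 0.5])
    t = 0.0
    scaled = logits / (t + 1e-10)                      # the temperature scaling postprocess applies
    probs = np.exp(scaled - scaled.max())              # softmax, max-subtracted for stability
    probs /= probs.sum()
    print(probs)                                       # ~[1., 0., 0.]: effectively argmax
    print(int(np.random.choice(len(probs), p=probs)))  # 0, matching the old greedy path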


@@ -19,6 +19,19 @@ class TestSymbolicJit(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
def test_reshape_inside_plus1(self):
vi = Variable("i", 1, 10)
def f(a, jit=False, jit_ctx=None):
if jit: a = a.reshape(3, vi)
return (a+1).realize()
jf = TinyJit(f)
for i in range(1, 5):
a = Tensor.rand(3, i)
symbolic = jf(a, jit=True, jit_ctx={vi: i}).reshape(3, i).numpy()
expected = f(a).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
def test_add(self):
def f(a, b): return (a+b).realize()
jf = TinyJit(f)


@@ -24,16 +24,22 @@ class TinyJit:
def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
# NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
assert len(input_rawbuffers) != 0, "no inputs to JIT"
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
if self.cnt >= 2:
var_vals = dict(sorted(merge_dicts([arg.lazydata.st.var_vals for arg in args if isinstance(arg, Tensor)]).items(), key=lambda kv: kv[0].key)) # type: ignore
try: var_vals: Dict[Variable, int] = kwargs["jit_ctx"]
except KeyError: var_vals = merge_dicts([arg.lazydata.st.var_vals for arg in args if arg.__class__ is Tensor])
if len(var_vals) > 1: var_vals = dict(sorted(var_vals.items(), key=lambda kv: kv[0].key))
for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
assert input_rawbuffers[input_name][1].views == expected_st.views and input_rawbuffers[input_name][0].dtype == expected_type, f"ShapeTracker.views or type mismatch in JIT, <{input_rawbuffers[input_name][1].views}, {input_rawbuffers[input_name][0].dtype}> != <{expected_st.views}, {expected_type}>"
assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"
# NOTE: if we pass jit_ctx instead of using reshape to update the var_vals, we cannot compare the shapetracker directly
if "jit_ctx" not in kwargs: assert input_rawbuffers[input_name][1].views == expected_st.views, f"ShapeTracker.views mismatch in JIT, {input_rawbuffers[input_name][1].views} != {expected_st.views}"
self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
for prg, pargs, variables in self.jit_cache: # type: Callable, List[Optional[RawBuffer]], Dict[Variable, int]
for v in (var_vals.keys() & variables.keys()): variables[v] = var_vals[v]
for k in variables.keys():
try: variables[k] = var_vals[k]
except KeyError: pass
prg(pargs, variables, jit=True)
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
elif self.cnt == 1:
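
Two things change in the replay path above. First, the concrete values for symbolic Variables can come from an explicit jit_ctx keyword instead of being scraped off the input tensors' ShapeTrackers, and when jit_ctx is used the ShapeTracker.views comparison is skipped: with the reshape-to-symbolic now happening inside the jitted function, the tensors handed to the JIT change shape from step to step, so the views captured at trace time can no longer be compared directly (the NOTE in the diff says as much). Second, the per-kernel variable refresh is rewritten; in plain Python terms it still means "overwrite every captured value the caller knows, leave the rest alone":

    # plain-Python restatement of the refresh loop (illustration, not repo code)
    captured = {"pos": 7}            # values baked into one cached kernel at capture time
    supplied = {"pos": 8, "i": 4}    # var_vals / jit_ctx for this call
    for k in captured:
      if k in supplied:              # the diff spells this as try/except KeyError
        captured[k] = supplied[k]
    print(captured)                  # {'pos': 8}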


@@ -87,9 +87,14 @@ class CLProgram:
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if isinstance(x, CLBuffer) else np.int32 for x in bufs))
cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs]
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=[x.event for x in bufs if isinstance(x, CLBuffer) and hasattr(x, "event")])
if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if x.__class__ is CLBuffer else np.int32 for x in bufs))
cl_bufs, wait_for = [], []
for x in bufs:
if x.__class__ is CLBuffer:
cl_bufs.append(x._buf)
if hasattr(x, "event"): wait_for.append(x.event)
else: cl_bufs.append(x)
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=wait_for)
if wait:
e.wait()
try:


@@ -139,7 +139,7 @@ class ShapeTracker:
def shape(self) -> Tuple[int, ...]: return self.views[-1].shape # NOTE: real type is Tuple[Union[Node, int], ...] but mypy complains about prod(shape)
@property
def key(self) -> Tuple[View, ...]: return tuple(self.views)
def key(self) -> Tuple[Tuple[View, ...], Tuple[Variable, ...]]: return tuple(self.views), tuple(sorted(self.var_vals.keys()))
# this is the real size (ish)
def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
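
ShapeTracker.key now folds in which Variables are bound. With jit_ctx in play, a tracker can carry Variables in var_vals that are not visible in its views (as the llama.py code above does with freqs_cis and the caches), so two trackers with identical views but different bound Variables should not be treated as interchangeable wherever the key is used for caching or comparison. A quick way to see the new behavior, assuming this revision's internals:

    from tinygrad.tensor import Tensor
    from tinygrad.shape.symbolic import Variable

    pos = Variable("pos", 1, 1024)
    a = Tensor.rand(4, 8)
    b = Tensor.rand(4, 8)
    b.lazydata.st.var_vals[pos] = 3                 # same views as a, but a Variable is now bound
    print(a.lazydata.st.key == b.lazydata.st.key)   # False with this change (was True before it)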


@@ -283,10 +283,11 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]):
return create_node(ret)
def sym_infer(n:Union[Node,int], var_vals: Dict[Variable, int]) -> int:
if isinstance(n, (int, NumNode)): return int(n)
if isinstance(n, Variable): return var_vals[n]
if isinstance(n, MulNode): return sym_infer(n.a, var_vals) * sym_infer(n.b, var_vals)
if isinstance(n, SumNode): return sum(sym_infer(s, var_vals) for s in n.nodes)
if n.__class__ is int: return n # type: ignore
if n.__class__ is NumNode: return n.b # type: ignore
if n.__class__ is Variable: return var_vals[n] # type: ignore
if n.__class__ is MulNode: return sym_infer(n.a, var_vals) * sym_infer(n.b, var_vals) # type: ignore
if n.__class__ is SumNode: return sum([sym_infer(s, var_vals) for s in n.nodes]) # type: ignore
raise NotImplementedError(n)
@functools.lru_cache(maxsize=None)
def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
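
The sym_infer rewrite is the clearest instance of the isinstance -> __class__ swap called out in the commit message: an exact-type check skips isinstance's subclass handling, which adds up in code that runs per generated token. A rough, self-contained illustration of the trade-off (not repo code; timings vary by CPython version):

    import timeit

    x = 1.0
    print(timeit.timeit(lambda: isinstance(x, float), number=1_000_000))
    print(timeit.timeit(lambda: x.__class__ is float, number=1_000_000))
    # the exact-type check is often a bit faster, at the cost of no longer matching
    # subclasses -- which is fine in sym_infer, since it dispatches on concrete node types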