diff --git a/examples/llama.py b/examples/llama.py
index ff1473eab3..51c7c9328d 100755
--- a/examples/llama.py
+++ b/examples/llama.py
@@ -7,7 +7,7 @@ from pathlib import Path
 import functools, sys, argparse, json, os
 import numpy as np
 np.set_printoptions(linewidth=200)
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Dict
 from tinygrad.helpers import Timing, getenv, DEBUG, dtypes
 from tinygrad.ops import Device
@@ -67,7 +67,7 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

-  def __call__(self, x:Tensor, cache_k:Tensor, cache_v:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
+  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None) -> Tuple[Tensor, Tensor, Tensor]:
     bsz, seqlen, _ = x.shape
     xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
@@ -79,13 +79,13 @@ class Attention:
     if start_pos == 0:
       keys, values = xk, xv
     else:
-      assert cache_k.shape[0] > 0, "no cache"
+      assert cache_k is not None and cache_v is not None, "no cache"
       assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals)})"
       assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?"
       keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)

     cache_k, cache_v = keys, values
-    keys, values = repeat_kv(keys, self.n_rep).realize(), repeat_kv(values, self.n_rep).realize()
+    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     attn = Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask).transpose(1, 2).reshape(bsz, seqlen, -1)
     return self.wo(attn).realize(), cache_k.realize(), cache_v.realize()

@@ -110,62 +110,70 @@ class TransformerBlock:
     self.feed_forward = FeedForward(dim, 4*dim, multiple_of, linear, ffn_dim_multiplier)
     self.attention_norm = RMSNorm(dim, norm_eps)
     self.ffn_norm = RMSNorm(dim, norm_eps)
-    self.cache_k, self.cache_v = None, None
-    self.jitted_attention_norm = TinyJit(lambda x: self.attention_norm(x).realize())
-    self.jitted_attn = TinyJit(self.attention.__call__)
-    self.jitted_norm_output = TinyJit(self.norm_output)
-
-  def norm_output(self, x:Tensor, output:Tensor) -> Tensor:
-    h = x + output
-    return (h + self.feed_forward(self.ffn_norm(h))).realize()
-
-  def __call__(self, x:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]):
+  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None):
     bsz, seqlen, _ = x.shape
-    do_jit = getenv("JIT") and mask is None
-    if do_jit:
+    if getenv("JIT") and mask is None:
+      assert cache_k is not None and cache_v is not None, "no cache"
       pos = Variable("pos", 1, 1024)
-      self.cache_k = self.cache_k.reshape(self.cache_k.shape[0], pos, self.cache_k.shape[2], self.cache_k.shape[3])
-      self.cache_v = self.cache_v.reshape(self.cache_v.shape[0], pos, self.cache_v.shape[2], self.cache_v.shape[3])
-      output, cache_k, cache_v = self.jitted_attn(self.jitted_attention_norm(x), self.cache_k, self.cache_v, start_pos, freqs_cis, mask)
+      cache_k = cache_k.reshape(cache_k.shape[0], pos, cache_k.shape[2], cache_k.shape[3])
+      cache_v = cache_v.reshape(cache_v.shape[0], pos, cache_v.shape[2], cache_v.shape[3])
+      # need this because we don't reshape back to int shape in the jitted path and we don't have the correct var_vals in cache
+      cache_k.lazydata.st.var_vals[pos] = start_pos
+      cache_v.lazydata.st.var_vals[pos] = start_pos
+
+      # get only the part of freqs_cis that we are using.
+      freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (pos, pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))
+      freqs_cis.lazydata.st.var_vals[pos] = start_pos
     else:
-      output, cache_k, cache_v = self.attention(self.attention_norm(x), self.cache_k, self.cache_v, start_pos, freqs_cis, mask)
+      freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (start_pos, start_pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))

-    # save the cache. with symbolic shape, cast it back to int shape so we have int shape in cache
-    self.cache_k = cache_k.reshape(cache_k.shape[0], start_pos+seqlen, cache_k.shape[2], cache_k.shape[3]).realize()
-    self.cache_v = cache_v.reshape(cache_v.shape[0], start_pos+seqlen, cache_v.shape[2], cache_v.shape[3]).realize()
-
-    return self.jitted_norm_output(x, output) if do_jit else self.norm_output(x, output)
+    output, cache_k, cache_v = self.attention(self.attention_norm(x), cache_k, cache_v, start_pos, freqs_cis, mask, jit_ctx=jit_ctx)
+    h = x + output
+    return (h + self.feed_forward(self.ffn_norm(h))).realize(), cache_k.realize(), cache_v.realize()

 class Transformer:
   def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None):
     self.layers = [TransformerBlock(dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
+    self.kv_caches = [(None, None) for _ in range(n_layers)]
     self.norm = RMSNorm(dim, norm_eps)
     self.tok_embeddings = Embedding(vocab_size, dim)
     self.output = linear(dim, vocab_size, bias=False)
     self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2))
-    self.norm_output = lambda x: self.output(self.norm(x))
-    self.jitted_tok_embeddings = TinyJit(lambda x: self.tok_embeddings(x).realize())
-    self.jitted_norm_output = TinyJit(lambda x: self.norm_output(x).realize())
+    self.tok_embeddings_jitted = TinyJit(lambda x: self.tok_embeddings(x).realize())
+    self.postprocess_jitted = TinyJit(self.postprocess)
+    self.layers_jitted = [TinyJit(layer.__call__) for layer in self.layers]

-  def __call__(self, tokens:Tensor, start_pos:int):
+  def postprocess(self, x, temperature:Optional[float]):
+    logits = self.output(self.norm(x))
+    if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
+    return logits.realize()
+
+  def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
     _bsz, seqlen = tokens.shape
-    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
-    do_jit = getenv("JIT") and mask is None
-
-    # get only the part of freqs_cis that we are using.
-    if do_jit:
+    if seqlen == 1 and getenv("JIT"):
       pos = Variable("pos", 1, 1024)
-      assert seqlen == 1, "seqlen > 1 not supported for JIT"
       freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
       freqs_cis.lazydata.st.var_vals[pos] = start_pos
+      h = self.tok_embeddings_jitted(tokens)
+      for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers_jitted, self.kv_caches)):
+        h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=self.freqs_cis, mask=None, jit_ctx={pos: start_pos})
+        # TODO: move the kv cache into Attention, pre-allocate the cache and instead of cat, update the cache in-place
+        self.kv_caches[i] = (cache_k, cache_v)
+      return self.postprocess_jitted(h, temperature)
     else:
       freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (start_pos, start_pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
-
-    h = self.jitted_tok_embeddings(tokens) if do_jit else self.tok_embeddings(tokens)
-    h = h.sequential([functools.partial(layer, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask) for layer in self.layers])
-    return self.jitted_norm_output(h) if do_jit else self.norm_output(h)
+      mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize()
+      h = self.tok_embeddings(tokens)
+      for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers, self.kv_caches)):
+        # need this reshape back to int shape in conversational mode because jitted and unjitted calls share the same cache
+        if cache_k is not None and start_pos > 0:
+          cache_k = cache_k.reshape(cache_k.shape[0], start_pos, cache_k.shape[2], cache_k.shape[3])
+          cache_v = cache_v.reshape(cache_v.shape[0], start_pos, cache_v.shape[2], cache_v.shape[3])
+        h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=self.freqs_cis, mask=mask)
+        self.kv_caches[i] = (cache_k, cache_v)
+      return self.postprocess(h, temperature)

 # **** files and arguments ****
@@ -206,15 +214,6 @@ MODEL_PARAMS = {
 }

 # **** helper functions ****
-def sample(logits, temperature):
-  if temperature < 1e-6:
-    # so close to 0 we use argmax
-    return int(logits.argmax().numpy())
-  else:
-    probs = (logits / temperature).softmax()
-    probs = probs.numpy().flatten()
-    return int(np.random.choice(len(probs), p=probs))
-
 def concat_weights(models):
   def convert(name) -> Tensor:
     disk_tensors = [model[name] for model in models]
@@ -300,8 +299,9 @@ class LLaMa:
     toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt)
     start_pos = 0
     for i in range(max_length):
-      logits = self.model(Tensor([toks[start_pos:]]), start_pos).realize()[:, -1, :]
-      tok = sample(logits, temperature)
+      probs = self.model(Tensor([toks[start_pos:]]), start_pos, temperature).realize()
+      probs_np = probs.numpy()
+      tok = int(np.random.choice(len(probs_np), p=probs_np))
       start_pos = len(toks)
       toks.append(tok)
@@ -439,7 +439,7 @@ After you are done speaking, output [EOS]. You are not Chad.

     print(f"Preparing KV cache for chatbot with personality {args.personality}...")
     with Timing():
-      llama.model(Tensor([toks]), 0).realize() # NOTE: output logits are not used
+      llama.model(Tensor([toks]), 0, args.temperature).realize() # NOTE: output logits are not used
     start_pos = len(toks)
   else:
     # non chat bot mode
@@ -474,10 +474,13 @@ After you are done speaking, output [EOS]. You are not Chad.
     if args.timing: print("")
     st = GlobalCounters.time_sum_s
-    with Timing("ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU") if DEBUG else None, enabled=args.timing):
-      logits = llama.model(Tensor([toks[start_pos:]]), start_pos).realize()[:, -1, :]
+    with Timing("ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"+
+                f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
+                f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s") if DEBUG else None, enabled=args.timing):
+      probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
     with Timing("sync in ", enabled=args.timing):
-      tok = sample(logits, args.temperature)
+      probs_np = probs.numpy()
+      tok = int(np.random.choice(len(probs_np), p=probs_np))

     # use the kv cache
     start_pos = len(toks)
diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py
index bea43d3365..8bae172575 100644
--- a/test/test_symbolic_jit.py
+++ b/test/test_symbolic_jit.py
@@ -19,6 +19,19 @@ class TestSymbolicJit(unittest.TestCase):
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert len(jf.jit_cache) == 1

+  def test_reshape_inside_plus1(self):
+    vi = Variable("i", 1, 10)
+    def f(a, jit=False, jit_ctx=None):
+      if jit: a = a.reshape(3, vi)
+      return (a+1).realize()
+    jf = TinyJit(f)
+    for i in range(1, 5):
+      a = Tensor.rand(3, i)
+      symbolic = jf(a, jit=True, jit_ctx={vi: i}).reshape(3, i).numpy()
+      expected = f(a).numpy()
+      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
+    assert len(jf.jit_cache) == 1
+
   def test_add(self):
     def f(a, b): return (a+b).realize()
     jf = TinyJit(f)
diff --git a/tinygrad/jit.py b/tinygrad/jit.py
index 00888d8fb4..2f29bc4a19 100644
--- a/tinygrad/jit.py
+++ b/tinygrad/jit.py
@@ -24,16 +24,22 @@ class TinyJit:
   def __call__(self, *args, **kwargs) -> Any:
     if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
     # NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
-    input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
+    input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
     assert len(input_rawbuffers) != 0, "no inputs to JIT"
     assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
     if self.cnt >= 2:
-      var_vals = dict(sorted(merge_dicts([arg.lazydata.st.var_vals for arg in args if isinstance(arg, Tensor)]).items(), key=lambda kv: kv[0].key)) # type: ignore
+      try: var_vals: Dict[Variable, int] = kwargs["jit_ctx"]
+      except KeyError: var_vals = merge_dicts([arg.lazydata.st.var_vals for arg in args if arg.__class__ is Tensor])
+      if len(var_vals) > 1: var_vals = dict(sorted(var_vals.items(), key=lambda kv: kv[0].key))
       for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
-        assert input_rawbuffers[input_name][1].views == expected_st.views and input_rawbuffers[input_name][0].dtype == expected_type, f"ShapeTracker.views or type mismatch in JIT, <{input_rawbuffers[input_name][1].views}, {input_rawbuffers[input_name][0].dtype}> != <{expected_st.views}, {expected_type}>"
+        assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"
+        # NOTE: if we pass jit_ctx instead of using reshape to update the var_vals, we cannot compare the shapetracker directly
+        if "jit_ctx" not in kwargs: assert input_rawbuffers[input_name][1].views == expected_st.views, f"ShapeTracker.views mismatch in JIT, {input_rawbuffers[input_name][1].views} != {expected_st.views}"
         self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
       for prg, pargs, variables in self.jit_cache: # type: Callable, List[Optional[RawBuffer]], Dict[Variable, int]
-        for v in (var_vals.keys() & variables.keys()): variables[v] = var_vals[v]
+        for k in variables.keys():
+          try: variables[k] = var_vals[k]
+          except KeyError: pass
         prg(pargs, variables, jit=True)
       for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
     elif self.cnt == 1:
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index 0aba0e7958..ceb674c0fe 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -87,9 +87,14 @@ class CLProgram:
   def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size

   def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
-    if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if isinstance(x, CLBuffer) else np.int32 for x in bufs))
-    cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs]
-    e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=[x.event for x in bufs if isinstance(x, CLBuffer) and hasattr(x, "event")])
+    if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if x.__class__ is CLBuffer else np.int32 for x in bufs))
+    cl_bufs, wait_for = [], []
+    for x in bufs:
+      if x.__class__ is CLBuffer:
+        cl_bufs.append(x._buf)
+        if hasattr(x, "event"): wait_for.append(x.event)
+      else: cl_bufs.append(x)
+    e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=wait_for)
     if wait:
       e.wait()
       try:
diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index 75a28601d6..4701fd1c8c 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -139,7 +139,7 @@ class ShapeTracker:
   def shape(self) -> Tuple[int, ...]: return self.views[-1].shape # NOTE: real type is Tuple[Union[Node, int], ...] but mypy complains about prod(shape)

   @property
-  def key(self) -> Tuple[View, ...]: return tuple(self.views)
+  def key(self) -> Tuple[Tuple[View, ...], Tuple[Variable, ...]]: return tuple(self.views), tuple(sorted(self.var_vals.keys()))

   # this is the real size (ish)
   def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index fe2805465e..22a49f8925 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -283,10 +283,11 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]):
   return create_node(ret)

 def sym_infer(n:Union[Node,int], var_vals: Dict[Variable, int]) -> int:
-  if isinstance(n, (int, NumNode)): return int(n)
-  if isinstance(n, Variable): return var_vals[n]
-  if isinstance(n, MulNode): return sym_infer(n.a, var_vals) * sym_infer(n.b, var_vals)
-  if isinstance(n, SumNode): return sum(sym_infer(s, var_vals) for s in n.nodes)
+  if n.__class__ is int: return n # type: ignore
+  if n.__class__ is NumNode: return n.b # type: ignore
+  if n.__class__ is Variable: return var_vals[n] # type: ignore
+  if n.__class__ is MulNode: return sym_infer(n.a, var_vals) * sym_infer(n.b, var_vals) # type: ignore
+  if n.__class__ is SumNode: return sum([sym_infer(s, var_vals) for s in n.nodes]) # type: ignore
   raise NotImplementedError(n)

 @functools.lru_cache(maxsize=None)
 def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
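
A minimal usage sketch of the new jit_ctx keyword, closely modeled on the added test_reshape_inside_plus1 (a sketch, not part of the patch; it assumes a JIT-supported backend such as GPU). The caller reshapes the input to a symbolic shape so a single kernel is compiled, then passes the concrete value of each Variable through jit_ctx so TinyJit can fill in the kernel variables on replay instead of comparing ShapeTracker views:

# illustrative sketch (not part of the patch): drive TinyJit with a symbolic dimension
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.shape.symbolic import Variable

vi = Variable("i", 1, 10)  # symbolic dimension, bounded 1..10

def f(a, jit_ctx=None):
  # reshape to the symbolic shape (3, vi); the generated kernel is specialized on vi, not on a concrete i
  return (a.reshape(3, vi) + 1).realize()

jf = TinyJit(f)
for i in range(1, 5):
  a = Tensor.rand(3, i)
  # jit_ctx hands TinyJit the concrete value of vi for this call, so the cached
  # kernels are replayed with vi = i instead of being recompiled
  out = jf(a, jit_ctx={vi: i})
  print(out.reshape(3, i).numpy())  # reshape back to the concrete shape before reading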
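
For reference, a small worked example of what the rewritten sym_infer computes (a sketch, not from the patch): it walks a symbolic Node and substitutes concrete values from var_vals, which is how the symbolic cache length pos above is turned back into an integer.

# illustrative sketch (not part of the patch): evaluating a symbolic expression with sym_infer
from tinygrad.shape.symbolic import Variable, sym_infer

pos = Variable("pos", 1, 1024)    # the same kind of Variable the llama KV cache uses
expr = pos * 8 + 3                # a SumNode containing a MulNode and a NumNode
print(sym_infer(expr, {pos: 5}))  # -> 43, MulNode/SumNode are evaluated recursively
print(sym_infer(7, {}))           # -> 7, plain ints fall through the first branch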