llama take int and convert to Variable internally (#2284)

chenyu
2023-11-12 17:11:37 -05:00
committed by GitHub
parent 123ea051e6
commit a72b370066
3 changed files with 18 additions and 18 deletions
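
In short: callers no longer build and bind the symbolic Variable for start_pos themselves; they pass a plain int, and the model constructs Variable("start_pos", 1, MAX_CONTEXT).bind(start_pos) internally on the JIT path. A minimal before/after sketch of the calling convention, using only names that appear in the diff below:

    # before: the caller binds a symbolic Variable for the current position
    probs = llama.model(Tensor([toks[start_pos:]]),
                        Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos),
                        args.temperature).realize()

    # after: the caller passes a plain int; binding happens inside Transformer.__call__
    probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()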

View File

@@ -4,10 +4,10 @@
 #typeguard.importhook.install_import_hook('tinygrad')
 from pathlib import Path
-import functools, sys, argparse, json, os
+import sys, argparse, json
 import numpy as np
 np.set_printoptions(linewidth=200)
-from typing import Optional, Tuple, Dict
+from typing import Optional, Tuple, Union
 from tinygrad.helpers import Timing, getenv, DEBUG, dtypes, CI
 from tinygrad.ops import Device
@@ -16,7 +16,7 @@ from tinygrad.nn import Embedding, Linear
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
 from tinygrad.helpers import GlobalCounters
 from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
-from tinygrad.shape.symbolic import Variable, sym_infer
+from tinygrad.shape.symbolic import Variable
 MAX_CONTEXT = 1024
 JIT = getenv("JIT", 0 if CI else int(Device.DEFAULT in JIT_SUPPORTED_DEVICE))
@@ -70,11 +70,7 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
-  def __call__(self, x:Tensor, start_pos:Variable, freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
-    if mask is not None:
-      # no symbolic shape qkv when consuming prompts
-      start_pos = start_pos.val
+  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
     xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
     xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
@@ -121,7 +117,7 @@ class TransformerBlock:
     self.attention_norm = RMSNorm(dim, norm_eps)
     self.ffn_norm = RMSNorm(dim, norm_eps)
-  def __call__(self, x:Tensor, start_pos:Variable, freqs_cis:Tensor, mask:Optional[Tensor]):
+  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
     return (h + self.feed_forward(self.ffn_norm(h))).realize()
@@ -134,10 +130,10 @@ class Transformer:
     self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_seq_len * 2, rope_theta)
     self.forward_jit = TinyJit(self.forward)
-  def forward(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
+  def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float=0.0):
     _bsz, seqlen = tokens.shape
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
-    mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos.val+1).realize() if seqlen > 1 else None
+    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
     h = self.tok_embeddings(tokens)
     for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
@@ -145,7 +141,11 @@ class Transformer:
     return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
-    return (self.forward_jit if tokens.shape[0:2] == (1,1) and getenv("JIT") else self.forward)(tokens, start_pos, temperature)
+    # TODO: better way to handle the first call v.s. the rest?
+    if tokens.shape[0:2] == (1,1) and JIT:
+      assert start_pos > 0
+      return self.forward_jit(tokens, Variable("start_pos", 1, MAX_CONTEXT).bind(start_pos), temperature)
+    return self.forward(tokens, start_pos, temperature)

 # **** files and arguments ****
 MODEL_PARAMS = {
@@ -308,7 +308,7 @@ class LLaMa:
     toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt)
     start_pos = 0
     for i in range(max_length):
-      probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
+      probs = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).realize()
       probs_np = probs.numpy()
       tok = int(np.random.choice(len(probs_np), p=probs_np))
       start_pos = len(toks)
@@ -501,7 +501,7 @@ After you are done speaking, output [EOS]. You are not Chad.
   print(f"Preparing KV cache for chatbot with personality {args.personality}...")
   with Timing():
-    llama.model(Tensor([toks]), Variable("start_pos", 0, MAX_CONTEXT).bind(0), args.temperature).realize() # NOTE: output logits are not used
+    llama.model(Tensor([toks]), 0, args.temperature).realize() # NOTE: outputs are not used
   start_pos = len(toks)
 else:
   # non chat bot mode
@@ -540,7 +540,7 @@ After you are done speaking, output [EOS]. You are not Chad.
   with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
               f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
               (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
-    probs = llama.model(Tensor([toks[start_pos:]]), Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), args.temperature).realize()
+    probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
     probs_np = probs.numpy()
     tok = int(np.random.choice(len(probs_np), p=probs_np))
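
The heart of the change is the new dispatch in Transformer.__call__ (the @@ -145,7 +141,11 @@ hunk above), reconstructed here as a sketch; the explanatory comments are my reading of the control flow, not part of the commit:

    def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
      # TODO: better way to handle the first call v.s. the rest?
      # Single-token decode steps (tokens shape (1,1)) take the jitted path: the
      # int start_pos is wrapped in a bound Variable so one compiled program
      # covers every position in [1, MAX_CONTEXT].
      if tokens.shape[0:2] == (1,1) and JIT:
        assert start_pos > 0  # the first call (prompt, start_pos == 0) must run unjitted
        return self.forward_jit(tokens, Variable("start_pos", 1, MAX_CONTEXT).bind(start_pos), temperature)
      # Prompt consumption and non-JIT runs use the plain int directly.
      return self.forward(tokens, start_pos, temperature)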

View File

@@ -89,8 +89,8 @@ class TestInferenceMinKernels(unittest.TestCase):
     args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
     model = Transformer(**args_tiny)
     for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
-    with CLCache(98, var_vals={Variable("start_pos", 0, 1024): 0}):
-      model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 1024).bind(0)).realize()
+    with CLCache(98):
+      model(Tensor([[1,2,3,4]]), 0).realize()

 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOptBinOp(unittest.TestCase):

View File

@@ -38,7 +38,7 @@ class TestLLaMASpeed(unittest.TestCase):
       if empty_method_cache: Device[Device.DEFAULT].method_cache.clear()
       tms = [time.perf_counter()]
       for i in range(10):
-        model(Tensor([[2]]), i).realize()
+        model(Tensor([[1,2,3,4]]), i).realize()
         tms.append(time.perf_counter())
       timings = [(tms[i+1]-tms[i])*1000 for i in range(len(tms)-1)]
       print(f"{st:15s} mean runtime: {sum(timings)/len(timings):7.2f}ms, runs: ", ", ".join(f'{x:7.2f}' for x in timings))