make embedding and GPT-2 fast (#1631)

* make embedding fast

* jit more, variable shape support

* print mem bw
commit 643cbdfd50 (parent a7752ad65d)
Author: George Hotz
Date:   2023-08-22 15:14:38 -07:00
Committed by: GitHub

5 changed files with 52 additions and 16 deletions

--------------------------------------------------------------------------------

@@ -15,7 +15,7 @@ from tinygrad.nn import Embedding, Linear
 from tinygrad.jit import TinyJit
 from tinygrad.shape.symbolic import Variable
-from examples.llama import sample
+MAX_CONTEXT = 128

 class LayerNorm:
   def __init__(self, dim, eps=1e-5):
@@ -81,7 +81,7 @@ class TransformerBlock:
     if start_pos > 0 and mask is None and getenv("JIT"):
       seqlen = x.shape[1]
-      pos = Variable("pos", 1, 128)  # max context
+      pos = Variable("pos", 1, MAX_CONTEXT)
       self.cache_k = self.cache_k.reshape(self.cache_k.shape[0], pos, self.cache_k.shape[2], self.cache_k.shape[3])
       self.cache_v = self.cache_v.reshape(self.cache_v.shape[0], pos, self.cache_v.shape[2], self.cache_v.shape[3])
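Here `Variable` is tinygrad's symbolic integer: a named dimension with known integer bounds that can stand in for a concrete size, so the reshaped KV cache has length `pos` rather than a hard-coded sequence length. A minimal sketch of how such a symbolic dimension behaves (the name and bounds mirror the hunk above):

    from tinygrad.shape.symbolic import Variable

    pos = Variable("pos", 1, 128)        # a dimension that may take any value in [1, 128]
    print(pos.min, pos.max)              # 1 128
    print((pos + 1).min, (pos + 1).max)  # bounds propagate through arithmetic: 2 129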
@@ -104,18 +104,34 @@ class Transformer:
     self.ln_f = LayerNorm(dim, norm_eps)
     self.lm_head = linear(dim, vocab_size, bias=False)
-
-  def __call__(self, tokens:Tensor, start_pos:int):
-    _bsz, seqlen = tokens.shape
+    self.embed_jitted = TinyJit(self.embed)
+    self.postprocess_jitted = TinyJit(self.postprocess)
+
+  def embed(self, tokens, pos):
     tok_emb = self.wte(tokens)
-    pos = Tensor.arange(start_pos, start_pos + seqlen).reshape(1, -1)
     pos_emb = self.wpe(pos)
     h = tok_emb + pos_emb
+    return h.realize()
-    # get only the part we are using. making it contiguous avoids more kernel calls
-    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
-    h = h.sequential([functools.partial(layer, start_pos=start_pos, mask=mask) for layer in self.h])
-    h = self.ln_f(h)
-    return self.lm_head(h)
+
+  def postprocess(self, x, temperature:Optional[float]):
+    logits = self.lm_head(self.ln_f(x))
+    if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
+    return logits.realize()
+
+  def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]):
+    _bsz, seqlen = tokens.shape
+    if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
+    if seqlen == 1 and start_pos > 0 and getenv("JIT"):
+      start_pos_var = Variable("start_pos", 1, MAX_CONTEXT)
+      pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos_var, start_pos_var+seqlen)))
+      pos.lazydata.st.var_vals[start_pos_var] = start_pos
+      h = self.embed_jitted(tokens, pos).sequential([functools.partial(layer, start_pos=start_pos, mask=None) for layer in self.h])
+      return self.postprocess_jitted(h, temperature)
+    else:
+      pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos, start_pos+seqlen)))
+      mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize()
+      h = self.embed(tokens, pos).sequential([functools.partial(layer, start_pos=start_pos, mask=mask) for layer in self.h])
+      return self.postprocess(h, temperature)

 # **** files and arguments ****
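Why the split into `embed`/`postprocess` plus a symbolic `start_pos` pays off: TinyJit replays previously captured kernels, and compiled code is cached, in effect, per distinct shape. If `start_pos` were baked into the shapes, every decode step would trigger a fresh compile; with a `Variable`, every step hits the same cache entry and only `var_vals` changes. A toy illustration of that cache behavior in plain Python (`jit_run` and its keys are hypothetical stand-ins, not tinygrad APIs):

    compiled = {}
    def jit_run(shape_key):
      # stand-in for codegen: compile once per distinct shape key
      if shape_key not in compiled: compiled[shape_key] = f"kernel<{shape_key}>"
      return compiled[shape_key]

    for start_pos in range(1, 101): jit_run(("seqlen_1", start_pos))    # concrete shapes
    print(len(compiled))   # 100 compiles
    compiled.clear()
    for start_pos in range(1, 101): jit_run(("seqlen_1", "start_pos"))  # one symbolic shape
    print(len(compiled))   # 1 compile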
@@ -163,10 +179,13 @@ class GPT2:
       GlobalCounters.reset()
       if args.timing: print("")
       st = GlobalCounters.time_sum_s
-      with Timing(f"ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU, {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB") if DEBUG else None, enabled=timing):
-        logits = self.model(Tensor([toks[start_pos:]]), start_pos)[:, -1, :].realize()
+      with Timing(f"ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"+
+                  f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
+                  f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s") if DEBUG else None, enabled=timing):
+        probs = self.model(Tensor([toks[start_pos:]]), start_pos, temperature)
       with Timing("sync in ", enabled=timing):
-        tok = sample(logits, temperature)
+        probs_np = probs.numpy()
+        tok = int(np.random.choice(len(probs_np), p=probs_np))
       start_pos = len(toks)
       toks.append(tok)
       output = self.tokenizer.decode(toks)
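The new GB/s term in the timing line is just bytes moved divided by wall time: `GlobalCounters.global_mem` accumulates bytes, `time_sum_s - st` is the elapsed GPU time in seconds, so `global_mem*1e-9/(time_sum_s-st)` is memory bandwidth in GB/s. With illustrative numbers:

    global_mem = 2.5e9   # bytes moved by the kernels this step (example value)
    elapsed = 0.025      # seconds on GPU (example value)
    print(f"{global_mem*1e-9/elapsed:.2f} GB/s")   # 100.00 GB/s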

--------------------------------------------------------------------------------

@@ -0,0 +1,8 @@
+from tinygrad.tensor import Tensor
+from tinygrad.nn import Embedding
+
+if __name__ == "__main__":
+  vocab_size = 50257
+  dim = 128
+  test = Embedding(vocab_size, dim)
+  ret = test(Tensor([[1,2,3]])).numpy()
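This new script looks up three token ids in a (50257, 128) table, so `ret` should come back with shape (1, 3, 128). A sanity check one could append (an assumption about intent, not part of the commit):

    assert ret.shape == (1, 3, 128), ret.shape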

--------------------------------------------------------------------------------

@@ -213,6 +213,11 @@ class Linearizer:
     # print early
     if DEBUG >= 5: self.printbufs("early")

+  def has_variable_shape(self) -> bool:
+    for b in self.bufs:
+      if any(not isinstance(x, int) for x in b.st.shape): return True
+    return False
+
   def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
   def float4_axis(self, i): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]
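`has_variable_shape` scans every buffer's ShapeTracker shape and returns True if any dimension is not a plain int, i.e. it is a symbolic node like the `pos`/`start_pos` Variables above. A minimal sketch of the predicate it centralizes:

    from tinygrad.shape.symbolic import Variable

    concrete = (1, 12, 64)
    symbolic = (1, Variable("pos", 1, 128), 64)
    print(any(not isinstance(x, int) for x in concrete))  # False
    print(any(not isinstance(x, int) for x in symbolic))  # True -> skip optimization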

--------------------------------------------------------------------------------

@@ -72,7 +72,7 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg):
   if global_db is not None and skey in global_db:
     choice = global_db[skey]
-  elif any(not isinstance(x, int) for x in k.full_shape):
+  elif k.has_variable_shape():
     # don't optimize variable shapes
     choice = "BASELINE"
   else:
@@ -260,6 +260,10 @@ def hand_coded_optimizations(k:Linearizer):
   # no more opt if we are grouping
   if k.group_for_reduce: return

+  # no more opt if there's non ints in any shapes
+  # TODO: this is due to a bug. repro by commenting this one while running GPT-2 with the JIT
+  if k.has_variable_shape(): return
+
   # **** below this line need to be optional and benchmarked ****
   # potentially do more upcasts of non reduce axes based on a heuristic

--------------------------------------------------------------------------------

@@ -120,5 +120,5 @@ class Embedding:
     self.weight = Tensor.glorot_uniform(vocab_size, embed_size)

   def __call__(self, idx:Tensor) -> Tensor:
-    vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size).expand(*idx.shape, self.vocab_size)
-    return (vocab_counter == idx.unsqueeze(2).expand(*idx.shape, self.vocab_size)) @ self.weight
+    if not hasattr(self, 'vocab_counter'): self.vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size)
+    return (self.vocab_counter == idx.unsqueeze(2)).expand(*idx.shape, self.vocab_size) @ self.weight
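The embedding itself is a one-hot matmul: comparing the ids against an arange over the vocabulary builds a one-hot matrix, and multiplying by the weight gathers the matching rows. The change caches `vocab_counter` (built once per layer instead of on every call) and moves the `expand` after the `==` so broadcasting does the widening. The same math in numpy, as a sketch with toy sizes:

    import numpy as np

    vocab_size, dim = 5, 3
    weight = np.arange(vocab_size * dim, dtype=np.float32).reshape(vocab_size, dim)
    idx = np.array([[1, 2, 3]])                             # (batch=1, seqlen=3) token ids
    vocab_counter = np.arange(vocab_size).reshape(1, 1, vocab_size)
    onehot = (vocab_counter == idx[..., None]).astype(np.float32)  # (1, 3, vocab_size)
    out = onehot @ weight                                   # (1, 3, dim): gathered rows
    assert (out[0, 0] == weight[1]).all()                   # id 1 -> row 1 of the table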