From 48c8d130aee97f22c428838545d1a18fd32ee460 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Fri, 29 Sep 2023 04:41:09 -0700
Subject: [PATCH] simpler GPT2 (#1941)

* don't realize in gpt2

* simpler gpt2
---
 examples/gpt2.py | 34 ++++++++++++++++------------------
 1 file changed, 16 insertions(+), 18 deletions(-)

diff --git a/examples/gpt2.py b/examples/gpt2.py
index 088e0ce37b..ff96e754cd 100644
--- a/examples/gpt2.py
+++ b/examples/gpt2.py
@@ -34,7 +34,7 @@ class Attention:
     self.dim = dim
     self.head_dim = dim // n_heads
 
-  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None) -> Tensor:
+  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor]) -> Tensor:
     xqkv = self.c_attn(x)
     xq, xk, xv = [xqkv.slice([None, None, (i*self.dim, (i+1)*self.dim)]) for i in range(3)]
     xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], self.n_heads, self.head_dim) for x in (xq, xk, xv)]
@@ -70,7 +70,7 @@ class TransformerBlock:
     self.ln_1 = LayerNorm(dim, norm_eps)
     self.ln_2 = LayerNorm(dim, norm_eps)
 
-  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor], realize=True, jit_ctx:Optional[Dict[Variable,int]]=None):
+  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor]):
     if start_pos > 0 and mask is None and getenv("JIT"):
       start_pos_var = Variable("start_pos", 1, MAX_CONTEXT)
       cache_k = cache_k.reshape(cache_k.shape[0], start_pos_var, cache_k.shape[2], cache_k.shape[3])
@@ -79,11 +79,9 @@ class TransformerBlock:
       cache_k.lazydata.var_vals[start_pos_var] = start_pos
       cache_v.lazydata.var_vals[start_pos_var] = start_pos
 
-    output, cache_k, cache_v = self.attn(self.ln_1(x), cache_k, cache_v, start_pos, mask, jit_ctx=jit_ctx)
+    output, cache_k, cache_v = self.attn(self.ln_1(x), cache_k, cache_v, start_pos, mask)
     h = x + output
     h = (h + self.mlp(self.ln_2(h)))
-    if realize:
-      return h.realize(), cache_k.realize(), cache_v.realize()
     return h, cache_k, cache_v
 
 class Transformer:
@@ -100,29 +98,28 @@ class Transformer:
       self.postprocess_jitted = TinyJit(self.postprocess)
       self.h_jitted = [TinyJit(h.__call__) for h in self.h]
 
-  def embed(self, tokens:Tensor, pos:Tensor, realize=True):
+  def embed(self, tokens:Tensor, pos:Tensor):
     tok_emb = self.wte(tokens)
     pos_emb = self.wpe(pos)
     h = tok_emb + pos_emb
     if getenv("FP16"): h = h.half()
-    if not realize:
-      return h
-    return h.realize()
+    return h
 
   def postprocess(self, x, temperature:Optional[float]):
     logits = self.lm_head(self.ln_f(x))
-    if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
-    return logits.realize()
+    if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten()
+    return logits
 
   @TinyJit
-  def run_all_layers(self, tokens:Tensor, pos:Tensor, start_pos:int, temperature:float, jit_ctx:Optional[Dict[Variable,int]]=None, **kv_cache):
-    h = self.embed(tokens, pos, realize=False)
+  def run_all_layers(self, tokens:Tensor, pos:Tensor, start_pos:int, temperature:float, **kv_cache):
+    h = self.embed(tokens, pos)
 
     for i, hi in enumerate(self.h):
-      h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'] = hi(h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'], start_pos=start_pos, mask=None, realize=False, jit_ctx=jit_ctx)
-    for v in kv_cache.values(): v.realize()
+      h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'] = hi(h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'], start_pos=start_pos, mask=None)
 
-    return self.postprocess(h, temperature), kv_cache
+    # don't realize until here
+    for v in kv_cache.values(): v.realize()
+    return self.postprocess(h, temperature).realize(), kv_cache
 
   def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
     _bsz, seqlen = tokens.shape
@@ -130,7 +127,7 @@ class Transformer:
       start_pos_var = Variable("start_pos", 1, MAX_CONTEXT)
       pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos_var, start_pos_var + seqlen)))
       pos.lazydata.var_vals[start_pos_var] = start_pos
-      logit_or_softmax, self.kv_caches = self.run_all_layers(tokens, pos, start_pos=start_pos, temperature=temperature, jit_ctx={start_pos_var: start_pos}, **self.kv_caches)
+      logit_or_softmax, self.kv_caches = self.run_all_layers(tokens, pos, start_pos=start_pos, temperature=temperature, **self.kv_caches)
       return logit_or_softmax
     else:
       if start_pos == 0:
@@ -143,7 +140,8 @@ class Transformer:
       h = embed(tokens, pos)
       for i, hi in enumerate(hs):
         h, self.kv_caches[f'cache_k{i}'], self.kv_caches[f'cache_v{i}'] = hi(h, self.kv_caches[f'cache_k{i}'], self.kv_caches[f'cache_v{i}'], start_pos=start_pos, mask=mask)
-      return postprocess(h, temperature)
+      for v in self.kv_caches.values(): v.realize()
+      return postprocess(h, temperature).realize()
 
 # **** files and arguments ****
 