simpler GPT2 (#1941)

* don't realize in gpt2

* simpler gpt2
George Hotz authored on 2023-09-29 04:41:09 -07:00, committed by GitHub
parent 81cb120b0f
commit 48c8d130ae

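The diff below drops the realize= and jit_ctx= plumbing: embed, postprocess, and TransformerBlock no longer force evaluation themselves, and run_all_layers / __call__ realize the kv caches and the final output once at the end. As a rough illustration of that idea (a sketch only, not code from this commit; the shapes and the three-step loop are made up), tinygrad tensors stay lazy until .realize() is called, so a whole forward pass can be built up first and scheduled in one go:

from tinygrad.tensor import Tensor

x = Tensor.randn(1, 8, 64)   # lazy: no kernels have run yet
w = Tensor.randn(64, 64)
h = x
for _ in range(3):           # stand-in for the per-layer loop in run_all_layers
  h = (h @ w).relu()         # each step only extends the lazy graph
out = h.realize()            # a single realize at the end schedules the whole thing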

@@ -34,7 +34,7 @@ class Attention:
     self.dim = dim
     self.head_dim = dim // n_heads
-  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None) -> Tensor:
+  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor]) -> Tensor:
     xqkv = self.c_attn(x)
     xq, xk, xv = [xqkv.slice([None, None, (i*self.dim, (i+1)*self.dim)]) for i in range(3)]
     xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], self.n_heads, self.head_dim) for x in (xq, xk, xv)]
@@ -70,7 +70,7 @@ class TransformerBlock:
     self.ln_1 = LayerNorm(dim, norm_eps)
     self.ln_2 = LayerNorm(dim, norm_eps)
-  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor], realize=True, jit_ctx:Optional[Dict[Variable,int]]=None):
+  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor]):
     if start_pos > 0 and mask is None and getenv("JIT"):
       start_pos_var = Variable("start_pos", 1, MAX_CONTEXT)
       cache_k = cache_k.reshape(cache_k.shape[0], start_pos_var, cache_k.shape[2], cache_k.shape[3])
@@ -79,11 +79,9 @@ class TransformerBlock:
       cache_k.lazydata.var_vals[start_pos_var] = start_pos
       cache_v.lazydata.var_vals[start_pos_var] = start_pos
-    output, cache_k, cache_v = self.attn(self.ln_1(x), cache_k, cache_v, start_pos, mask, jit_ctx=jit_ctx)
+    output, cache_k, cache_v = self.attn(self.ln_1(x), cache_k, cache_v, start_pos, mask)
     h = x + output
     h = (h + self.mlp(self.ln_2(h)))
-    if realize:
-      return h.realize(), cache_k.realize(), cache_v.realize()
     return h, cache_k, cache_v
 class Transformer:
@@ -100,29 +98,28 @@ class Transformer:
     self.postprocess_jitted = TinyJit(self.postprocess)
     self.h_jitted = [TinyJit(h.__call__) for h in self.h]
-  def embed(self, tokens:Tensor, pos:Tensor, realize=True):
+  def embed(self, tokens:Tensor, pos:Tensor):
     tok_emb = self.wte(tokens)
     pos_emb = self.wpe(pos)
     h = tok_emb + pos_emb
     if getenv("FP16"): h = h.half()
-    if not realize:
-      return h
-    return h.realize()
+    return h
   def postprocess(self, x, temperature:Optional[float]):
     logits = self.lm_head(self.ln_f(x))
-    if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
-    return logits.realize()
+    if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten()
+    return logits
   @TinyJit
-  def run_all_layers(self, tokens:Tensor, pos:Tensor, start_pos:int, temperature:float, jit_ctx:Optional[Dict[Variable,int]]=None, **kv_cache):
-    h = self.embed(tokens, pos, realize=False)
+  def run_all_layers(self, tokens:Tensor, pos:Tensor, start_pos:int, temperature:float, **kv_cache):
+    h = self.embed(tokens, pos)
     for i, hi in enumerate(self.h):
-      h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'] = hi(h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'], start_pos=start_pos, mask=None, realize=False, jit_ctx=jit_ctx)
-    for v in kv_cache.values(): v.realize()
-    return self.postprocess(h, temperature), kv_cache
+      h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'] = hi(h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'], start_pos=start_pos, mask=None)
+    # don't realize until here
+    for v in kv_cache.values(): v.realize()
+    return self.postprocess(h, temperature).realize(), kv_cache
   def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
     _bsz, seqlen = tokens.shape
@@ -130,7 +127,7 @@ class Transformer:
       start_pos_var = Variable("start_pos", 1, MAX_CONTEXT)
       pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos_var, start_pos_var + seqlen)))
       pos.lazydata.var_vals[start_pos_var] = start_pos
-      logit_or_softmax, self.kv_caches = self.run_all_layers(tokens, pos, start_pos=start_pos, temperature=temperature, jit_ctx={start_pos_var: start_pos}, **self.kv_caches)
+      logit_or_softmax, self.kv_caches = self.run_all_layers(tokens, pos, start_pos=start_pos, temperature=temperature, **self.kv_caches)
       return logit_or_softmax
     else:
       if start_pos == 0:
@@ -143,7 +140,8 @@ class Transformer:
       h = embed(tokens, pos)
       for i, hi in enumerate(hs):
         h, self.kv_caches[f'cache_k{i}'], self.kv_caches[f'cache_v{i}'] = hi(h, self.kv_caches[f'cache_k{i}'], self.kv_caches[f'cache_v{i}'], start_pos=start_pos, mask=mask)
-      return postprocess(h, temperature)
+      for v in self.kv_caches.values(): v.realize()
+      return postprocess(h, temperature).realize()
 # **** files and arguments ****
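Deferring realize() only changes when kernels are scheduled, not what they compute, so the model's output is unchanged; a tiny sanity check in that spirit (again just a sketch with made-up tensors, not part of the commit) could look like:

import numpy as np
from tinygrad.tensor import Tensor

a = Tensor.ones(4, 4)
eager = (a + 1).realize() * 2   # realize in the middle, like the old per-block code
lazy = (a + 1) * 2              # stay lazy until the end, like the new code
assert np.allclose(eager.numpy(), lazy.numpy())  # .numpy() realizes as needed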