mirror of https://github.com/tinygrad/tinygrad.git
@@ -34,7 +34,7 @@ class Attention:
     self.dim = dim
     self.head_dim = dim // n_heads
 
-  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None) -> Tensor:
+  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor]) -> Tensor:
     xqkv = self.c_attn(x)
     xq, xk, xv = [xqkv.slice([None, None, (i*self.dim, (i+1)*self.dim)]) for i in range(3)]
     xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], self.n_heads, self.head_dim) for x in (xq, xk, xv)]
@@ -70,7 +70,7 @@ class TransformerBlock:
     self.ln_1 = LayerNorm(dim, norm_eps)
     self.ln_2 = LayerNorm(dim, norm_eps)
 
-  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor], realize=True, jit_ctx:Optional[Dict[Variable,int]]=None):
+  def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, mask:Optional[Tensor]):
     if start_pos > 0 and mask is None and getenv("JIT"):
       start_pos_var = Variable("start_pos", 1, MAX_CONTEXT)
       cache_k = cache_k.reshape(cache_k.shape[0], start_pos_var, cache_k.shape[2], cache_k.shape[3])
@@ -79,11 +79,9 @@ class TransformerBlock:
       cache_k.lazydata.var_vals[start_pos_var] = start_pos
       cache_v.lazydata.var_vals[start_pos_var] = start_pos
 
-    output, cache_k, cache_v = self.attn(self.ln_1(x), cache_k, cache_v, start_pos, mask, jit_ctx=jit_ctx)
+    output, cache_k, cache_v = self.attn(self.ln_1(x), cache_k, cache_v, start_pos, mask)
     h = x + output
     h = (h + self.mlp(self.ln_2(h)))
-    if realize:
-      return h.realize(), cache_k.realize(), cache_v.realize()
     return h, cache_k, cache_v
 
 class Transformer:
@@ -100,29 +98,28 @@ class Transformer:
     self.postprocess_jitted = TinyJit(self.postprocess)
     self.h_jitted = [TinyJit(h.__call__) for h in self.h]
 
-  def embed(self, tokens:Tensor, pos:Tensor, realize=True):
+  def embed(self, tokens:Tensor, pos:Tensor):
     tok_emb = self.wte(tokens)
     pos_emb = self.wpe(pos)
     h = tok_emb + pos_emb
     if getenv("FP16"): h = h.half()
-    if not realize:
-      return h
-    return h.realize()
+    return h
 
   def postprocess(self, x, temperature:Optional[float]):
     logits = self.lm_head(self.ln_f(x))
-    if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
-    return logits.realize()
+    if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten()
+    return logits
 
   @TinyJit
-  def run_all_layers(self, tokens:Tensor, pos:Tensor, start_pos:int, temperature:float, jit_ctx:Optional[Dict[Variable,int]]=None, **kv_cache):
-    h = self.embed(tokens, pos, realize=False)
+  def run_all_layers(self, tokens:Tensor, pos:Tensor, start_pos:int, temperature:float, **kv_cache):
+    h = self.embed(tokens, pos)
 
     for i, hi in enumerate(self.h):
-      h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'] = hi(h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'], start_pos=start_pos, mask=None, realize=False, jit_ctx=jit_ctx)
-    for v in kv_cache.values(): v.realize()
+      h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'] = hi(h, kv_cache[f'cache_k{i}'], kv_cache[f'cache_v{i}'], start_pos=start_pos, mask=None)
 
-    return self.postprocess(h, temperature), kv_cache
+    # don't realize until here
+    for v in kv_cache.values(): v.realize()
+    return self.postprocess(h, temperature).realize(), kv_cache
 
   def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]=None):
     _bsz, seqlen = tokens.shape
@@ -130,7 +127,7 @@ class Transformer:
       start_pos_var = Variable("start_pos", 1, MAX_CONTEXT)
       pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos_var, start_pos_var + seqlen)))
       pos.lazydata.var_vals[start_pos_var] = start_pos
-      logit_or_softmax, self.kv_caches = self.run_all_layers(tokens, pos, start_pos=start_pos, temperature=temperature, jit_ctx={start_pos_var: start_pos}, **self.kv_caches)
+      logit_or_softmax, self.kv_caches = self.run_all_layers(tokens, pos, start_pos=start_pos, temperature=temperature, **self.kv_caches)
       return logit_or_softmax
     else:
       if start_pos == 0:
@@ -143,7 +140,8 @@ class Transformer:
       h = embed(tokens, pos)
       for i, hi in enumerate(hs):
         h, self.kv_caches[f'cache_k{i}'], self.kv_caches[f'cache_v{i}'] = hi(h, self.kv_caches[f'cache_k{i}'], self.kv_caches[f'cache_v{i}'], start_pos=start_pos, mask=mask)
-      return postprocess(h, temperature)
+      for v in self.kv_caches.values(): v.realize()
+      return postprocess(h, temperature).realize()
 
 # **** files and arguments ****
 
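
The refactor above converges on one pattern: drop the jit_ctx and realize= plumbing, keep everything lazy inside the @TinyJit-decorated run_all_layers, and call .realize() exactly once per output at the very end, so the JIT records a single fixed set of kernels on its capture run. Below is a minimal sketch of that pattern, not code from the diff: the names TinyBlock, forward, and cache are made up for illustration, the import path matches this era of tinygrad (newer releases expose `from tinygrad import TinyJit`), and a JIT-capable backend such as GPU or METAL is assumed.

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit  # import path of this era; may differ in newer versions

class TinyBlock:
  def __init__(self):
    self.w = Tensor.randn(8, 8).realize()

  @TinyJit
  def forward(self, x:Tensor, **cache):
    h = (x @ self.w).relu()
    cache = {k: (v + 1) for k, v in cache.items()}  # stand-in for the kv-cache update
    # don't realize until here: one realize per output, at the end of the jitted call
    for v in cache.values(): v.realize()
    return h.realize(), cache

if __name__ == "__main__":
  blk, cache = TinyBlock(), {"k": Tensor.randn(8, 8).realize()}
  for _ in range(3):  # TinyJit replays the captured kernels from the third call on
    out, cache = blk.forward(Tensor.randn(4, 8).realize(), **cache)
  print(out.numpy().shape)

Realizing inside the jitted call also appears to be why jit_ctx can go away: the value of the start_pos Variable now travels with the tensors themselves via lazydata.var_vals (see the unchanged context lines in the hunks above) instead of being threaded as an extra argument through every layer.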