@@ -13,6 +13,7 @@ from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
 from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch, colored, dtypes

 MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
+HALF = getenv("HALF")

 class Attention:
   def __init__(self, dim, n_heads):
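The HALF switch above is read from the environment rather than from a CLI flag. A minimal sketch of that pattern, assuming tinygrad's getenv helper behaves roughly like a typed os.getenv wrapper (the helper name is real, the re-implementation below is only illustrative):

import os

def getenv_sketch(key, default=0):
  # rough stand-in for tinygrad.helpers.getenv: read the env var, cast to the default's type
  return type(default)(os.getenv(key, default))

MAX_CONTEXT = getenv_sketch("MAX_CONTEXT", 128)   # 128 unless MAX_CONTEXT is set in the environment
HALF = getenv_sketch("HALF")                      # 0 unless run as e.g. HALF=1 python3 gpt2.py
print(MAX_CONTEXT, HALF)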
@@ -34,6 +35,9 @@ class Attention:
     # create kv cache
     if not hasattr(self, "cache_k"):
       self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
+      if HALF:
+        self.cache_k = self.cache_k.half()
+        self.cache_v = self.cache_v.half()

     keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
     values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
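The kv cache is preallocated to MAX_CONTEXT positions; each step keeps the first start_pos cached entries with shrink and appends the new keys/values with cat. A self-contained sketch of that update with made-up small shapes (the Tensor methods are the ones used in the diff; the shapes are hypothetical):

from tinygrad.tensor import Tensor

bsz, MAX_CONTEXT, n_heads, head_dim = 1, 8, 2, 4
start_pos, seqlen = 3, 1

cache_k = Tensor.zeros(bsz, MAX_CONTEXT, n_heads, head_dim)   # preallocated cache
xk = Tensor.ones(bsz, seqlen, n_heads, head_dim)              # new key(s) for this step

# keep the first start_pos cached positions, then append the new keys along the sequence dim
keys = cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
print(keys.shape)   # (1, start_pos+seqlen, n_heads, head_dim) -> (1, 4, 2, 4)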
@@ -77,14 +81,21 @@ class Transformer:
     if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
     _bsz, seqlen = tokens.shape

+    # NOTE: cannot convert token indices into half due to precision
     tok_emb = self.wte(tokens)
     pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos+seqlen))))
     h = tok_emb + pos_emb

     mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None
+
+    if HALF:
+      h = h.half()
+      if mask is not None: mask = mask.half()
+
     for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask)

     logits = self.lm_head(self.ln_f(h))
+    # NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead
     return (logits[:, -1, :] / (temperature+1e-10)).softmax().realize()

   # TODO: fix empty token
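The two NOTE comments above are about float16 range: with temperature=0 the logits get divided by roughly 1e-10, which blows past float16's ~65504 maximum, and the subsequent softmax degenerates; greedy decoding should use argmax instead. A rough illustration with toy logits (exact outputs depend on the backend):

from tinygrad.tensor import Tensor

logits = Tensor([[1.0, 2.0, 4.0]]).half()
temperature = 0.0

scaled = logits / (temperature + 1e-10)   # blows up in float16
print(scaled.numpy())                     # infs on a typical fp16 backend
print(scaled.softmax().numpy())           # NaNs / garbage rather than a near-one-hot distribution

print(logits.argmax(axis=-1).numpy())     # [2] -- the greedy pick, no division needed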
@@ -156,7 +167,6 @@ if __name__ == "__main__":
   parser.add_argument('--seed', type=int, help="Set the random seed")
   parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
   parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")
-  parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
   parser.add_argument('--noshow', action='store_true', help="Don't show the output")
   args = parser.parse_args()

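With the fp16 cast now keyed off HALF, the --fp16 flag goes away and the remaining flags parse unchanged. A tiny, self-contained argparse sketch (a subset of the script's flags, fed a hypothetical argv):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")

args = parser.parse_args(['--batch_size', '2', '--benchmark', '8'])   # hypothetical argv
print(args.batch_size, args.benchmark)   # 2 8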
@@ -167,7 +177,7 @@ if __name__ == "__main__":
   print(f"using {args.model_size}")
   gpt2 = GPT2.build(args.model_size)

-  if args.fp16:
+  if HALF:
     for l in get_state_dict(gpt2).values():
       l.assign(l.cast(dtypes.float16).realize())

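When HALF is set, the built model's parameters are cast to float16 in place by walking get_state_dict. A minimal sketch of the same loop on a made-up two-layer module (Tiny and its fields are hypothetical; get_state_dict, assign, cast and dtypes.float16 are the pieces used in the diff). Enabling it at run time is then a matter of running the script with HALF=1 in the environment instead of the old --fp16 flag.

from tinygrad.nn import Linear
from tinygrad.nn.state import get_state_dict
from tinygrad.helpers import dtypes

class Tiny:  # hypothetical stand-in for the real GPT2 wrapper
  def __init__(self):
    self.l1 = Linear(8, 8)
    self.l2 = Linear(8, 8)

model = Tiny()
for l in get_state_dict(model).values():
  l.assign(l.cast(dtypes.float16).realize())   # cast each parameter to float16 in place

print({name: t.dtype for name, t in get_state_dict(model).items()})   # all float16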