fp16 in gpt2 attention (#2491)

* fp16 in gpt2 attention

* HALF
Author: chenyu
Date: 2023-11-28 19:27:03 -05:00 (committed by GitHub)
Parent: 847f0a02b1
Commit: a739c6646e


@@ -13,6 +13,7 @@ from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
 from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch, colored, dtypes
 MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
+HALF = getenv("HALF")
 
 class Attention:
   def __init__(self, dim, n_heads):
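(Note, not part of the diff: HALF is an environment toggle read through tinygrad's getenv helper, which returns 0 when the variable is unset, so the fp16 path stays off by default. A minimal sketch of the pattern, assuming the same import path the script uses:)

from tinygrad.helpers import getenv

HALF = getenv("HALF")   # e.g. HALF=1 python examples/gpt2.py enables the fp16 path; unset -> 0
if HALF:
  print("fp16 path enabled")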
@@ -34,6 +35,9 @@ class Attention:
     # create kv cache
     if not hasattr(self, "cache_k"):
       self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
+      if HALF:
+        self.cache_k = self.cache_k.half()
+        self.cache_v = self.cache_v.half()
 
     keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
     values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
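(Aside, illustration only and not from the PR: the HALF branch only changes the dtype of the freshly created cache tensors. The shapes come from the surrounding code; the concrete sizes below are made up.)

from tinygrad.tensor import Tensor

cache_k = Tensor.zeros(1, 128, 12, 64)   # (bsz, MAX_CONTEXT, n_heads, head_dim), hypothetical GPT-2 sizes
print(cache_k.dtype)                     # float32, the default
print(cache_k.half().dtype)              # dtypes.float16, what the half() cast above produces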
@@ -77,14 +81,21 @@ class Transformer:
     if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
     _bsz, seqlen = tokens.shape
+    # NOTE: cannot convert token indices into half due to precision
     tok_emb = self.wte(tokens)
     pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos+seqlen))))
     h = tok_emb + pos_emb
     mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None
+    if HALF:
+      h = h.half()
+      if mask is not None: mask = mask.half()
 
     for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask)
     logits = self.lm_head(self.ln_f(h))
 
+    # NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead
     return (logits[:, -1, :] / (temperature+1e-10)).softmax().realize()
 
 # TODO: fix empty token
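(Aside, my illustration rather than part of the diff: both NOTE comments are fp16 range/precision issues. Token indices run up to GPT-2's vocabulary size of 50257, while float16 represents integers exactly only up to 2048, so the indices have to stay in their original dtype. And at temperature=0 the 1e-10 guard underflows to zero in fp16, turning the division into inf, which is why greedy decoding should use argmax instead. A quick numpy sketch of the second failure mode:)

import numpy as np

logits = np.array([1.5, 3.25, -0.75], dtype=np.float16)
print(np.float16(1e-10))            # 0.0 -- 1e-10 underflows in fp16
print(logits / np.float16(1e-10))   # [inf inf -inf], so the softmax is meaningless
print(int(np.argmax(logits)))       # 1 -- argmax picks the greedy token without dividing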
@@ -156,7 +167,6 @@ if __name__ == "__main__":
   parser.add_argument('--seed', type=int, help="Set the random seed")
   parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
   parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")
-  parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
   parser.add_argument('--noshow', action='store_true', help="Don't show the output")
   args = parser.parse_args()
@@ -167,7 +177,7 @@ if __name__ == "__main__":
   print(f"using {args.model_size}")
   gpt2 = GPT2.build(args.model_size)
-  if args.fp16:
+  if HALF:
     for l in get_state_dict(gpt2).values():
       l.assign(l.cast(dtypes.float16).realize())
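(Aside: with the --fp16 flag removed above, casting the weights is now gated on the same HALF environment variable as the activations and kv cache. Below is a hypothetical, self-contained sketch of that casting loop; TinyModel and its sizes are made up, and the imports follow the ones the script itself uses.)

from tinygrad.nn import Linear
from tinygrad.nn.state import get_state_dict
from tinygrad.helpers import dtypes

class TinyModel:
  def __init__(self): self.lm_head = Linear(8, 8, bias=False)

model = TinyModel()
for l in get_state_dict(model).values():        # walks every Tensor in the model
  l.assign(l.cast(dtypes.float16).realize())    # cast to fp16 and write back in place
print([t.dtype for t in get_state_dict(model).values()])   # all dtypes.float16 afterwards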