Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
gpt2 half for kvcache and output logits (#2630)
* gpt2 more half
* half is fine after softmax
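The second commit message line is the key observation: after softmax every value lies in [0, 1], where float16 has plenty of precision, so the output probabilities can be returned in half without meaningfully changing the distribution. A minimal numpy sketch of the idea (illustration only, not part of the commit):

import numpy as np

logits = np.array([3.2, 1.1, -0.7], dtype=np.float32)
probs = np.exp(logits - logits.max())
probs /= probs.sum()                 # softmax in float32: values sum to 1, all in [0, 1]
print(probs)
print(probs.astype(np.float16))      # nearly identical in half, so returning half is fine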
@@ -28,6 +28,7 @@ class Attention:
       # no symbolic shape qkv when consuming prompts
       start_pos = start_pos.val
 
+    if HALF: x = x.half()
     xqkv = self.c_attn(x)
     xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim) for i in range(3)]
     bsz, seqlen, n_heads, head_dim = xq.shape
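For context on the shrink/reshape line above: c_attn produces a fused (bsz, seqlen, 3*dim) tensor, and the three shrinks slice out the q, k, v blocks before splitting dim into heads. A toy-sized sketch with made-up shapes, using the recent `from tinygrad import ...` path:

from tinygrad import Tensor

bsz, seqlen, dim, n_heads = 1, 3, 4, 2
xqkv = Tensor.rand(bsz, seqlen, 3*dim)  # stand-in for self.c_attn(x)
# shrink takes an optional (start, end) range per axis; None keeps the axis whole
xq, xk, xv = [xqkv.shrink((None, None, (i*dim, (i+1)*dim))).reshape(bsz, seqlen, n_heads, dim//n_heads) for i in range(3)]
print(xq.shape)  # (1, 3, 2, 2) = (bsz, seqlen, n_heads, head_dim)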
@@ -47,7 +48,7 @@ class Attention:
     self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
 
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
-    return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1))
+    return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).cast(dtypes.float32).transpose(1, 2).reshape(bsz, seqlen, -1))
 
 class FeedForward:
   def __init__(self, dim, hidden_dim):
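The changed return line keeps the whole attention computation, including its internal softmax, in float16, and only casts back to float32 afterwards, before the output projection. A hedged sketch of the pattern with made-up shapes (recent tinygrad import paths; dtype behavior may differ across versions):

from tinygrad import Tensor, dtypes

q = Tensor.rand(1, 2, 3, 8).half()  # (bsz, n_heads, seqlen, head_dim)
k = Tensor.rand(1, 2, 3, 8).half()
v = Tensor.rand(1, 2, 3, 8).half()

out = q.scaled_dot_product_attention(k, v)  # runs in float16
out = out.cast(dtypes.float32)              # cast back, as in the return above
print(out.dtype)                            # float32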
@@ -89,15 +90,15 @@ class Transformer:
     mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None
 
     if HALF:
-      # NOTE: converting this to half breaks GPT-2
-      #h = h.half()
+      h = h.half()
       if mask is not None: mask = mask.half()
 
     for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask)
 
     logits = self.lm_head(self.ln_f(h))
     # NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead
-    return (logits[:, -1, :] / (temperature+1e-10)).softmax().realize()
+    ret = (logits[:, -1, :] / (temperature+1e-10)).softmax()
+    return ret.half().realize() if HALF else ret.realize()
 
   # TODO: fix empty token
   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
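On the NOTE kept in the hunk above: the 1e-10 epsilon that makes temperature=0 safe in float32 underflows to 0.0 in float16, so the division blows up to inf and the softmax turns into nan; greedy decoding in half should take the argmax of the logits instead. A numpy sketch of the failure mode (illustration only):

import numpy as np

logits = np.array([2.0, 1.0, 0.5], dtype=np.float16)
print(np.float16(1e-10))           # 0.0: the epsilon underflows in half
print(logits / np.float16(1e-10))  # [inf inf inf], and softmax of this is nan
print(int(logits.argmax()))        # 0: the argmax fallback stays exact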
@@ -180,7 +181,7 @@ if __name__ == "__main__":
 
   if HALF:
     for l in get_state_dict(gpt2).values():
-      l.assign(l.cast(dtypes.float16).realize())
+      l.assign(l.half().realize())
 
   if args.benchmark != -1:
     gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
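The last hunk is purely cosmetic: Tensor.half() is shorthand for cast(dtypes.float16), so looping it over get_state_dict(gpt2) still converts every weight and roughly halves weight memory. A quick equivalence check (recent tinygrad import paths):

from tinygrad import Tensor, dtypes

w = Tensor.rand(4, 4)
print(w.half().dtype == w.cast(dtypes.float16).dtype)  # True: the two spellings are equivalent
print(w.half().dtype)                                  # float16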