gpt2 half for kvcache and output logits (#2630)

* gpt2 more half

* half is fine after softmax
This commit is contained in:
chenyu
2023-12-05 16:54:56 -05:00
committed by GitHub
parent 0be5d16950
commit a63f48d3db

View File

@@ -28,6 +28,7 @@ class Attention:
# no symbolic shape qkv when consuming prompts
start_pos = start_pos.val
if HALF: x = x.half()
xqkv = self.c_attn(x)
xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim) for i in range(3)]
bsz, seqlen, n_heads, head_dim = xq.shape
@@ -47,7 +48,7 @@ class Attention:
self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1))
return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).cast(dtypes.float32).transpose(1, 2).reshape(bsz, seqlen, -1))
class FeedForward:
def __init__(self, dim, hidden_dim):
@@ -89,15 +90,15 @@ class Transformer:
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None
if HALF:
# NOTE: converting this to half breaks GPT-2
#h = h.half()
h = h.half()
if mask is not None: mask = mask.half()
for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask)
logits = self.lm_head(self.ln_f(h))
# NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead
return (logits[:, -1, :] / (temperature+1e-10)).softmax().realize()
ret = (logits[:, -1, :] / (temperature+1e-10)).softmax()
return ret.half().realize() if HALF else ret.realize()
# TODO: fix empty token
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
@@ -180,7 +181,7 @@ if __name__ == "__main__":
if HALF:
for l in get_state_dict(gpt2).values():
l.assign(l.cast(dtypes.float16).realize())
l.assign(l.half().realize())
if args.benchmark != -1:
gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()