diff --git a/examples/gpt2.py b/examples/gpt2.py
index d15cc441c6..7306d5d7a8 100644
--- a/examples/gpt2.py
+++ b/examples/gpt2.py
@@ -28,6 +28,7 @@ class Attention:
       # no symbolic shape qkv when consuming prompts
       start_pos = start_pos.val
 
+    if HALF: x = x.half()
     xqkv = self.c_attn(x)
     xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim) for i in range(3)]
     bsz, seqlen, n_heads, head_dim = xq.shape
@@ -47,7 +48,7 @@ class Attention:
       self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
 
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
-    return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1))
+    return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).cast(dtypes.float32).transpose(1, 2).reshape(bsz, seqlen, -1))
 
 class FeedForward:
   def __init__(self, dim, hidden_dim):
@@ -89,15 +90,15 @@ class Transformer:
     mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None
 
     if HALF:
-      # NOTE: converting this to half breaks GPT-2
-      #h = h.half()
+      h = h.half()
       if mask is not None: mask = mask.half()
 
     for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask)
 
     logits = self.lm_head(self.ln_f(h))
     # NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead
-    return (logits[:, -1, :] / (temperature+1e-10)).softmax().realize()
+    ret = (logits[:, -1, :] / (temperature+1e-10)).softmax()
+    return ret.half().realize() if HALF else ret.realize()
 
   # TODO: fix empty token
   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
@@ -180,7 +181,7 @@ if __name__ == "__main__":
 
   if HALF:
     for l in get_state_dict(gpt2).values():
-      l.assign(l.cast(dtypes.float16).realize())
+      l.assign(l.half().realize())
 
   if args.benchmark != -1:
     gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
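
On the NOTE about temperature=0 with HALF: the 1e-10 guard in temperature+1e-10 is far below float16's smallest subnormal (about 6e-8), so in half precision it underflows to zero and the division becomes a divide-by-zero. Below is a minimal sketch of the argmax path the comment suggests; greedy_token is a hypothetical helper, not part of this patch, assuming only tinygrad's Tensor.argmax and .numpy():

  from tinygrad.tensor import Tensor

  def greedy_token(logits: Tensor) -> int:
    # logits: (bsz, seqlen, vocab) from lm_head; take the last position and
    # pick the max index directly -- no temperature division, so no float16
    # underflow. Assumes bsz == 1 for the scalar return.
    return int(logits[:, -1, :].argmax(axis=-1).numpy()[0])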