diff --git a/examples/gpt2.py b/examples/gpt2.py
index 979cd49c45..f9f9845334 100644
--- a/examples/gpt2.py
+++ b/examples/gpt2.py
@@ -47,7 +47,7 @@ class Attention:
     self.cache_kv.assign(new_cache).realize()
 
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
-    return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1))
+    return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, self.dim))
 
 class FeedForward:
   def __init__(self, dim, hidden_dim):
@@ -70,6 +70,7 @@ class TransformerBlock:
 
 class Transformer:
   def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
+    self.vocab_size = vocab_size
     self.wte = Embedding(vocab_size, dim)
     self.wpe = Embedding(max_seq_len, dim)
     self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)]
@@ -95,14 +96,16 @@ class Transformer:
     for hi in self.h:
       h = hi(h, start_pos, mask)
 
-    logits = self.lm_head(self.ln_f(h))[:, -1, :]
+    logits = self.lm_head(self.ln_f(h)).flatten(start_dim=1)
+    # special case of empty prompt
+    if logits.shape[1] == 0: logits = Tensor.ones((logits.shape[0], self.vocab_size), dtype=logits.dtype, device=logits.device)
+
     if temperature < 1e-6:
       ret = logits.argmax(-1)
     else:
       ret = (logits / temperature).softmax().multinomial()
     return ret.flatten().realize()
 
-  # TODO: fix empty token
   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
     forward = (self.forward_jit if (isinstance(tokens, Variable) or tokens.shape[1] == 1) and getenv("JIT") else self.forward)
     return forward(tokens, start_pos, temperature)
diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py
index b16503d65e..29f914c344 100644
--- a/test/test_symbolic_ops.py
+++ b/test/test_symbolic_ops.py
@@ -55,6 +55,9 @@ class TestSymbolicOps(unittest.TestCase):
         # symbolic shape dropout is not supported
         self.test_attention(dropout_p=0.5)
 
+  def test_attention_pos_0_sz_0(self):
+    Attention(128, 8)(Tensor.ones(1, 0, 128), Variable("start_pos", 0, 128).bind(0), None)
+
   def test_attention_pos_0_sz_1(self):
     Attention(128, 8)(Tensor.ones(1, 1, 128), Variable("start_pos", 0, 128).bind(0), None)
 
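
The gist of the gpt2.py change is the empty-prompt special case: with a zero-length prompt the logits tensor has a zero-sized second dimension, so there is nothing to argmax or sample from, and the patch substitutes an all-ones (uniform) distribution over the vocabulary. Below is a minimal NumPy sketch of that control flow, not tinygrad code; `next_token` and its arguments are illustrative names.

```python
import numpy as np

def next_token(logits: np.ndarray, vocab_size: int, temperature: float = 0.0) -> np.ndarray:
  # logits has shape (batch, n); an empty prompt yields n == 0
  if logits.shape[1] == 0:
    # substitute a uniform (all-ones) distribution so sampling still produces a token
    logits = np.ones((logits.shape[0], vocab_size), dtype=np.float32)
  if temperature < 1e-6:
    return logits.argmax(-1)             # greedy decoding
  probs = np.exp(logits / temperature)
  probs /= probs.sum(-1, keepdims=True)  # softmax
  return np.array([np.random.choice(vocab_size, p=p) for p in probs])

# empty prompt: (1, 0) logits still yield one token per batch element
print(next_token(np.empty((1, 0), dtype=np.float32), vocab_size=50257))  # -> [0] under greedy decoding
```

The new test_attention_pos_0_sz_0 test exercises the same zero-length case one level down, calling Attention with a (1, 0, 128) input at start_pos=0 to make sure the symbolic-shape kernels tolerate an empty sequence.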