mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Revert "fix gpt2 with empty prompt" (#3101)
This commit is contained in:
@@ -47,7 +47,7 @@ class Attention:
|
||||
self.cache_kv.assign(new_cache).realize()
|
||||
|
||||
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
|
||||
return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, self.dim))
|
||||
return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1))
|
||||
|
||||
class FeedForward:
|
||||
def __init__(self, dim, hidden_dim):
|
||||
@@ -70,7 +70,6 @@ class TransformerBlock:
|
||||
|
||||
class Transformer:
|
||||
def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
|
||||
self.vocab_size = vocab_size
|
||||
self.wte = Embedding(vocab_size, dim)
|
||||
self.wpe = Embedding(max_seq_len, dim)
|
||||
self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)]
|
||||
@@ -96,16 +95,14 @@ class Transformer:
|
||||
|
||||
for hi in self.h: h = hi(h, start_pos, mask)
|
||||
|
||||
logits = self.lm_head(self.ln_f(h)).flatten(start_dim=1)
|
||||
# special case of empty prompt
|
||||
if logits.shape[1] == 0: logits = Tensor.ones((logits.shape[0], self.vocab_size), dtype=logits.dtype, device=logits.device)
|
||||
|
||||
logits = self.lm_head(self.ln_f(h))[:, -1, :]
|
||||
if temperature < 1e-6:
|
||||
ret = logits.argmax(-1)
|
||||
else:
|
||||
ret = (logits / temperature).softmax().multinomial()
|
||||
return ret.flatten().realize()
|
||||
|
||||
# TODO: fix empty token
|
||||
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
|
||||
forward = (self.forward_jit if (isinstance(tokens, Variable) or tokens.shape[1] == 1) and getenv("JIT") else self.forward)
|
||||
return forward(tokens, start_pos, temperature)
|
||||
|
||||
@@ -55,9 +55,6 @@ class TestSymbolicOps(unittest.TestCase):
|
||||
# symbolic shape dropout is not supported
|
||||
self.test_attention(dropout_p=0.5)
|
||||
|
||||
def test_attention_pos_0_sz_0(self):
|
||||
Attention(128, 8)(Tensor.ones(1, 0, 128), Variable("start_pos", 0, 128).bind(0), None)
|
||||
|
||||
def test_attention_pos_0_sz_1(self):
|
||||
Attention(128, 8)(Tensor.ones(1, 1, 128), Variable("start_pos", 0, 128).bind(0), None)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user