From f0d7ad8aaafdb1daae0f0939897a48515fa77027 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Tue, 9 Jan 2024 16:14:55 -0500
Subject: [PATCH] fix gpt2 attention with start_pos = 0 (#3061)

* fix gpt2 attention with start_pos size 1

test cases taken from ll_transformer branch

* fix interpreted
---
 examples/gpt2.py          | 10 +++++++---
 test/test_symbolic_ops.py |  7 +++++++
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/examples/gpt2.py b/examples/gpt2.py
index f81691cdcc..979cd49c45 100644
--- a/examples/gpt2.py
+++ b/examples/gpt2.py
@@ -22,7 +22,7 @@ class Attention:
     self.head_dim = dim // n_heads
 
   def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
-    if mask is not None:
+    if mask is not None or start_pos.val == 0:
       # no symbolic shape qkv when consuming prompts
       start_pos = start_pos.val
 
@@ -35,8 +35,12 @@ class Attention:
     if not hasattr(self, "cache_kv"):
       self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype)
 
-    keys = self.cache_kv[0].shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
-    values = self.cache_kv[1].shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
+    if start_pos > 0:
+      keys = self.cache_kv[0].shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
+      values = self.cache_kv[1].shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
+    else:
+      keys = xk
+      values = xv
 
     # update the cache
     new_cache = Tensor.stack([keys, values]).pad((None, None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()
diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py
index 38549d4c3a..b16503d65e 100644
--- a/test/test_symbolic_ops.py
+++ b/test/test_symbolic_ops.py
@@ -2,6 +2,7 @@ import unittest
 from tinygrad.shape.symbolic import Variable
 from tinygrad.helpers import getenv
 from tinygrad.tensor import Tensor
+from examples.gpt2 import Attention
 import numpy as np
 
 @unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
@@ -54,6 +55,12 @@ class TestSymbolicOps(unittest.TestCase):
     # symbolic shape dropout is not supported
     self.test_attention(dropout_p=0.5)
 
+  def test_attention_pos_0_sz_1(self):
+    Attention(128, 8)(Tensor.ones(1, 1, 128), Variable("start_pos", 0, 128).bind(0), None)
+
+  def test_attention_pos_0_sz_2(self):
+    Attention(128, 8)(Tensor.ones(1, 2, 128), Variable("start_pos", 0, 128).bind(0), None)
+
  def test_cat_dim0(self):
     def f(a, b): return a.cat(b, dim=0).realize()
     for i in range(1, 5):
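
Note: below is a minimal standalone sketch of the key/value selection logic this
patch introduces, to show why start_pos == 0 needs its own branch. The shapes
(MAX_CONTEXT=16, bsz=1, n_heads=2, head_dim=4, seqlen=3), the plain-int start_pos
loop, and the Tensor.ones inputs are illustrative choices, not taken from the
patch; only the shrink/cat/pad expressions mirror examples/gpt2.py.

    # sketch only: mirrors the patched branch of Attention.__call__ with toy shapes
    from tinygrad.tensor import Tensor

    MAX_CONTEXT, bsz, n_heads, head_dim, seqlen = 16, 1, 2, 4, 3
    cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, n_heads, head_dim)
    xk = Tensor.ones(bsz, seqlen, n_heads, head_dim)  # projections of the new tokens
    xv = Tensor.ones(bsz, seqlen, n_heads, head_dim)

    # start_pos is a plain int here; in the example it can be a bound Variable
    for start_pos in (0, 5):  # prompt step vs. a later decode step
      if start_pos > 0:
        # decode: prepend the cached prefix to the new keys/values
        keys = cache_kv[0].shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
        values = cache_kv[1].shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
      else:
        # prompt: the cache holds nothing yet, so skip the zero-width shrink + cat
        keys, values = xk, xv
      # write keys/values back into a MAX_CONTEXT-sized cache, padding the tail
      new_cache = Tensor.stack([keys, values]).pad(
        (None, None, (0, MAX_CONTEXT-start_pos-seqlen), None, None)).contiguous()
      print(start_pos, keys.shape, new_cache.shape)

Before this patch the start_pos == 0 path shrank the cache to a zero-width slice
and concatenated it with xk/xv; the patch sidesteps that case by using xk/xv
directly, which is presumably what lets the new bound-start_pos=0 tests pass.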