fix gpt2 attention with start_pos = 0 (#3061)

* fix gpt2 attention with start_pos = 0 and sequence length 1

test cases taken from the ll_transformer branch

* fix interpreted
Author: chenyu
Date: 2024-01-09 16:14:55 -05:00
Committed by: GitHub
Parent: 39b91131bc
Commit: f0d7ad8aaa

2 changed files with 14 additions and 3 deletions
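
For context on why start_pos = 0 is a corner case: in a KV-cache attention layer the position decides both whether anything is cached and how much of the cache to slice, and in the tests below it arrives as a symbolic Variable bound to 0 rather than a plain int. The helper below is a hypothetical, simplified sketch of that branching pattern, not the actual change to examples/gpt2.py (which is not shown in this excerpt):

# Hypothetical sketch only, not the code modified by this commit.
# cache_k: (batch, max_context, n_heads, head_dim); xk: (batch, seqlen, n_heads, head_dim)
from tinygrad.tensor import Tensor

def gather_keys(cache_k: Tensor, xk: Tensor, start_pos: int) -> Tensor:
  if start_pos > 0:
    # decode step: take the cached keys up to start_pos and append the new ones
    return cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
  # first step (start_pos == 0): nothing cached yet, use only the new keys
  return xk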


@@ -2,6 +2,7 @@ import unittest
 from tinygrad.shape.symbolic import Variable
 from tinygrad.helpers import getenv
 from tinygrad.tensor import Tensor
+from examples.gpt2 import Attention
 import numpy as np
 
 @unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
@@ -54,6 +55,12 @@ class TestSymbolicOps(unittest.TestCase):
     # symbolic shape dropout is not supported
     self.test_attention(dropout_p=0.5)
 
+  def test_attention_pos_0_sz_1(self):
+    Attention(128, 8)(Tensor.ones(1, 1, 128), Variable("start_pos", 0, 128).bind(0), None)
+
+  def test_attention_pos_0_sz_2(self):
+    Attention(128, 8)(Tensor.ones(1, 2, 128), Variable("start_pos", 0, 128).bind(0), None)
+
   def test_cat_dim0(self):
     def f(a, b): return a.cat(b, dim=0).realize()
     for i in range(1, 5):
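
For reference, the new tests can be reproduced as a standalone snippet. The sketch below mirrors test_attention_pos_0_sz_1 and test_attention_pos_0_sz_2 exactly, and assumes it is run from the tinygrad repository root (so examples.gpt2 is importable) on a backend that supports symbolic shapes (the test file skips ARM64 and PTX):

# Standalone sketch of what the new tests exercise: the GPT-2 Attention layer
# called with a symbolic start_pos bound to 0 (first step, empty KV cache),
# for sequence lengths 1 and 2. A fresh Attention is built per call, as in the tests.
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
from examples.gpt2 import Attention

out1 = Attention(128, 8)(Tensor.ones(1, 1, 128), Variable("start_pos", 0, 128).bind(0), None)  # seqlen 1
out2 = Attention(128, 8)(Tensor.ones(1, 2, 128), Variable("start_pos", 0, 128).bind(0), None)  # seqlen 2
print(out1.shape, out2.shape)  # attention preserves shape: (1, 1, 128) and (1, 2, 128)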