fix gpt2 attention with start_pos = 0 (#3061)
* fix gpt2 attention with start_pos = 0, size 1; test cases taken from the ll_transformer branch
* fix interpreted
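For context, the new tests below exercise the GPT-2 attention layer with a symbolic start_pos bound to 0, the case this fix targets. A minimal sketch of the same call outside the test harness, assuming the Attention(dim, n_heads) constructor and attn(x, start_pos, mask) call signature shown in the diff:

from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
from examples.gpt2 import Attention

attn = Attention(128, 8)                            # dim=128, 8 heads, as in the tests
start_pos = Variable("start_pos", 0, 128).bind(0)   # symbolic position bound to 0 (the previously failing case)
x = Tensor.ones(1, 1, 128)                          # batch=1, seqlen=1, dim=128
out = attn(x, start_pos, None)                      # mask=None
print(out.realize().shape)                          # expected: (1, 1, 128)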
@@ -2,6 +2,7 @@ import unittest
 from tinygrad.shape.symbolic import Variable
 from tinygrad.helpers import getenv
 from tinygrad.tensor import Tensor
+from examples.gpt2 import Attention
 import numpy as np

 @unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
@@ -54,6 +55,12 @@ class TestSymbolicOps(unittest.TestCase):
     # symbolic shape dropout is not supported
     self.test_attention(dropout_p=0.5)

+  def test_attention_pos_0_sz_1(self):
+    Attention(128, 8)(Tensor.ones(1, 1, 128), Variable("start_pos", 0, 128).bind(0), None)
+
+  def test_attention_pos_0_sz_2(self):
+    Attention(128, 8)(Tensor.ones(1, 2, 128), Variable("start_pos", 0, 128).bind(0), None)
+
   def test_cat_dim0(self):
     def f(a, b): return a.cat(b, dim=0).realize()
     for i in range(1, 5):
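The tests rely on tinygrad's symbolic Variable carrying explicit min/max bounds and being bound to a concrete value at call time. A small illustration of that mechanism, assuming the bounds are exposed as .min/.max as in the symbolic module at this point in the repo:

from tinygrad.shape.symbolic import Variable

v = Variable("start_pos", 0, 128)   # symbolic value with bounds [0, 128]
bound = v.bind(0)                   # bind to the concrete value 0 for this call, as in the tests
print(v.min, v.max)                 # 0 128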