test_attention_simple_view (#10092)

* test_attention_simple_view

* correct comment
This commit is contained in:
George Hotz
2025-04-28 20:01:22 -04:00
committed by GitHub
parent bda116d773
commit a2d0684fc1
2 changed files with 10 additions and 0 deletions

View File

@@ -1,5 +1,6 @@
import unittest
from tinygrad import Variable
from tinygrad.shape.shapetracker import View
from tinygrad.helpers import Context, GlobalCounters
from tinygrad.tensor import Tensor
from examples.gpt2 import Attention
@@ -61,6 +62,14 @@ class TestSymbolicOps(unittest.TestCase):
self.test_attention(imin=4, imax=5, use_symbolic=False)
self.test_attention(imin=4, imax=5, use_symbolic=True)
# until this works, symbolic single kernel softmax won't
@unittest.expectedFailure
def test_attention_simple_view(self):
    """Check that two symbolic Views over a Variable-sized dim can be merged.

    Marked expectedFailure: per the comment above, symbolic single-kernel
    softmax is blocked until this View merge succeeds (i.e. v1+v2 stops
    returning None).
    """
    sz = Variable("i", 2, 10)
    # 5-D view: strides make the last two dims a (sz, sz) square per head.
    view_a = View.create((2, 4, 1, sz, sz), (sz * 4, sz, 0, 0, 1))
    # 6-D view over the same data with an extra symbolic dim.
    view_b = View.create((2, 4, 1, sz, sz, sz), (sz * sz * 4, sz * sz, 0, 0, sz, 1))
    merged = view_a + view_b
    self.assertIsNotNone(merged)
def test_attention_training(self):
with Tensor.train():
self.test_attention(dropout_p=0.0)