mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
test_attention_simple_view (#10092)
* test_attention_simple_view * correct comment
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import unittest
|
||||
from tinygrad import Variable
|
||||
from tinygrad.shape.shapetracker import View
|
||||
from tinygrad.helpers import Context, GlobalCounters
|
||||
from tinygrad.tensor import Tensor
|
||||
from examples.gpt2 import Attention
|
||||
@@ -61,6 +62,14 @@ class TestSymbolicOps(unittest.TestCase):
|
||||
self.test_attention(imin=4, imax=5, use_symbolic=False)
|
||||
self.test_attention(imin=4, imax=5, use_symbolic=True)
|
||||
|
||||
# until this works, symbolic single kernel softmax won't
|
||||
@unittest.expectedFailure
|
||||
def test_attention_simple_view(self):
|
||||
i = Variable("i", 2, 10)
|
||||
v1 = View.create((2,4,1,i,i), ((i*4),i,0,0,1))
|
||||
v2 = View.create((2,4,1,i,i,i), (((i*i)*4),(i*i),0,0,i,1))
|
||||
self.assertIsNotNone(v1+v2)
|
||||
|
||||
def test_attention_training(self):
|
||||
with Tensor.train():
|
||||
self.test_attention(dropout_p=0.0)
|
||||
|
||||
Reference in New Issue
Block a user