test_attention_simple_view (#10092)
* test_attention_simple_view

* correct comment
@@ -1,5 +1,6 @@
 import unittest
 from tinygrad import Variable
+from tinygrad.shape.shapetracker import View
 from tinygrad.helpers import Context, GlobalCounters
 from tinygrad.tensor import Tensor
 from examples.gpt2 import Attention
@@ -61,6 +62,14 @@ class TestSymbolicOps(unittest.TestCase):
     self.test_attention(imin=4, imax=5, use_symbolic=False)
     self.test_attention(imin=4, imax=5, use_symbolic=True)

+  # until this works, symbolic single kernel softmax won't
+  @unittest.expectedFailure
+  def test_attention_simple_view(self):
+    i = Variable("i", 2, 10)
+    v1 = View.create((2,4,1,i,i), ((i*4),i,0,0,1))
+    v2 = View.create((2,4,1,i,i,i), (((i*i)*4),(i*i),0,0,i,1))
+    self.assertIsNotNone(v1+v2)
+
   def test_attention_training(self):
     with Tensor.train():
       self.test_attention(dropout_p=0.0)
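For context: View.create(shape, strides) builds a single movement view over a buffer, and v1+v2 calls View.__add__, which attempts to combine the two views into one and, as the assertIsNotNone check implies, returns None when no single View can express the composition. Because the test is marked @unittest.expectedFailure, the merge currently returns None for this symbolic pair. A minimal standalone sketch of the same check, using only the Variable/View API already shown in the diff (the inline comments are interpretive):

from tinygrad import Variable
from tinygrad.shape.shapetracker import View

i = Variable("i", 2, 10)  # symbolic size bound to the range [2, 10]
# 5-D view: strides make the last two i-sized axes one contiguous i*i block
v1 = View.create((2,4,1,i,i), ((i*4),i,0,0,1))
# 6-D view over the same data with one extra size-i axis
v2 = View.create((2,4,1,i,i,i), (((i*i)*4),(i*i),0,0,i,1))
print(v1+v2)  # currently None: the two symbolic views do not merge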
@@ -41,6 +41,7 @@ symbolic_simple = PatternMatcher([
   (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
   # ** zero folding **
   (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
+  (UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0
   (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
    lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
   # x*0 -> 0 or 0*x -> 0
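Each symbolic_simple entry pairs a UPat pattern with a rewrite lambda, so self-referential expressions fold to constants (x%x becomes 0, x<x becomes False) before they reach view merging or codegen. A small illustrative sketch of the effect, assuming UOp.simplify() applies the symbolic rewrite rules to a Variable expression (that method usage is an assumption, not part of this diff):

from tinygrad import Variable

x = Variable("x", 1, 10)    # symbolic integer in [1, 10]
print((x % x).simplify())   # zero folding: x%x -> 0
print((x * 0).simplify())   # zero folding: x*0 -> 0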