test_attention_simple_view (#10092)

* test_attention_simple_view

* correct comment
This commit is contained in:
George Hotz
2025-04-28 20:01:22 -04:00
committed by GitHub
parent bda116d773
commit a2d0684fc1
2 changed files with 10 additions and 0 deletions

View File

@@ -1,5 +1,6 @@
import unittest
from tinygrad import Variable
from tinygrad.shape.shapetracker import View
from tinygrad.helpers import Context, GlobalCounters
from tinygrad.tensor import Tensor
from examples.gpt2 import Attention
@@ -61,6 +62,14 @@ class TestSymbolicOps(unittest.TestCase):
self.test_attention(imin=4, imax=5, use_symbolic=False)
self.test_attention(imin=4, imax=5, use_symbolic=True)
# until this works, symbolic single kernel softmax won't
@unittest.expectedFailure
def test_attention_simple_view(self):
i = Variable("i", 2, 10)
v1 = View.create((2,4,1,i,i), ((i*4),i,0,0,1))
v2 = View.create((2,4,1,i,i,i), (((i*i)*4),(i*i),0,0,i,1))
self.assertIsNotNone(v1+v2)
def test_attention_training(self):
with Tensor.train():
self.test_attention(dropout_p=0.0)

View File

@@ -41,6 +41,7 @@ symbolic_simple = PatternMatcher([
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
# ** zero folding **
(UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
(UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0
(UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
# x*0 -> 0 or 0*x -> 0