diff --git a/test/test_softmax_fusion.py b/test/test_softmax_fusion.py index 660c567e0e..54a801f72c 100644 --- a/test/test_softmax_fusion.py +++ b/test/test_softmax_fusion.py @@ -2,7 +2,7 @@ import unittest import numpy as np from tinygrad import Tensor, GlobalCounters, Context, Device from tinygrad.dtype import DTypeLike, dtypes -from tinygrad.helpers import DEBUG, get_single_element +from tinygrad.helpers import DEBUG, get_single_element, RANGEIFY from tinygrad.engine.realize import lower_schedule_item from tinygrad.device import is_dtype_supported @@ -39,6 +39,7 @@ class TestFuse(unittest.TestCase): np_multi = fxn(*args, **kwargs).numpy() np.testing.assert_allclose(np_single, np_multi, atol=atol) + @unittest.skipIf(01") def test_fuse_norm(self): a = Tensor.rand(50,50).realize() self._test_fuse(lambda a: a / a.mean(axis=1), a) @@ -47,6 +48,7 @@ class TestFuse(unittest.TestCase): a = Tensor.rand(50,50).realize() self._test_fuse(lambda a: a.argmax(axis=-1), a) + @unittest.skipIf(01") def test_fuse_softmax(self): a = Tensor.rand(50,50).realize() self._test_fuse(lambda a: a.softmax(axis=-1), a) @@ -57,6 +59,7 @@ class TestFuse(unittest.TestCase): self._test_fuse(lambda a,b: ((a@b).relu()+a).contiguous().softmax(axis=-1), a,b, allow_multiple=True) @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}") + @unittest.skipIf(01") def test_fuse_softmax_dtype(self): a = Tensor.rand(50,50).realize() self._test_fuse(lambda a: a.softmax(axis=-1, dtype='half'), a, atol=3e-4) @@ -64,6 +67,7 @@ class TestFuse(unittest.TestCase): def test_fuse_arange_eye(self): self._test_fuse(lambda: Tensor.arange(10).reshape(10,1).expand(10,10) == Tensor.arange(10).reshape(1,10).expand(10,10)) + @unittest.skipIf(01") def test_double_gemm(self): N = 32 with Context(TRACK_MATCH_STATS=0, DEBUG=0): @@ -86,6 +90,7 @@ class TestFuse(unittest.TestCase): return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype) self._test_fuse(embedding, a, atol=1e-5) + @unittest.skipIf(01") def test_attention_kernel_count(self): wq = Tensor.empty(32, 32) wk = Tensor.empty(32, 32) @@ -98,6 +103,7 @@ class TestFuse(unittest.TestCase): s = attn.schedule() self.assertEqual(len(s), 4) # 3 matmul and 1 attention + @unittest.skipIf(01") def test_flash_attention(self): BS = 4 HEADS = 2 @@ -165,6 +171,7 @@ class TestSoftmaxFusion(unittest.TestCase): np.testing.assert_allclose(sout.numpy(), out.numpy(), atol=3e-7) + @unittest.skipIf(01") def test_auto_softmax(self): print("*** softmax ***") with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):