skip test_softmax_fusion tests if RANGEIFY==1 (#12310)

This commit is contained in:
Sieds Lykles
2025-09-27 05:57:40 +02:00
committed by GitHub
parent b98f1881ef
commit 29f0886395

View File

@@ -2,7 +2,7 @@ import unittest
import numpy as np
from tinygrad import Tensor, GlobalCounters, Context, Device
from tinygrad.dtype import DTypeLike, dtypes
from tinygrad.helpers import DEBUG, get_single_element
from tinygrad.helpers import DEBUG, get_single_element, RANGEIFY
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.device import is_dtype_supported
@@ -39,6 +39,7 @@ class TestFuse(unittest.TestCase):
np_multi = fxn(*args, **kwargs).numpy()
np.testing.assert_allclose(np_single, np_multi, atol=atol)
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
def test_fuse_norm(self):
a = Tensor.rand(50,50).realize()
self._test_fuse(lambda a: a / a.mean(axis=1), a)
@@ -47,6 +48,7 @@ class TestFuse(unittest.TestCase):
a = Tensor.rand(50,50).realize()
self._test_fuse(lambda a: a.argmax(axis=-1), a)
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
def test_fuse_softmax(self):
a = Tensor.rand(50,50).realize()
self._test_fuse(lambda a: a.softmax(axis=-1), a)
@@ -57,6 +59,7 @@ class TestFuse(unittest.TestCase):
self._test_fuse(lambda a,b: ((a@b).relu()+a).contiguous().softmax(axis=-1), a,b, allow_multiple=True)
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
def test_fuse_softmax_dtype(self):
a = Tensor.rand(50,50).realize()
self._test_fuse(lambda a: a.softmax(axis=-1, dtype='half'), a, atol=3e-4)
@@ -64,6 +67,7 @@ class TestFuse(unittest.TestCase):
def test_fuse_arange_eye(self):
self._test_fuse(lambda: Tensor.arange(10).reshape(10,1).expand(10,10) == Tensor.arange(10).reshape(1,10).expand(10,10))
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
def test_double_gemm(self):
N = 32
with Context(TRACK_MATCH_STATS=0, DEBUG=0):
@@ -86,6 +90,7 @@ class TestFuse(unittest.TestCase):
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
self._test_fuse(embedding, a, atol=1e-5)
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
def test_attention_kernel_count(self):
wq = Tensor.empty(32, 32)
wk = Tensor.empty(32, 32)
@@ -98,6 +103,7 @@ class TestFuse(unittest.TestCase):
s = attn.schedule()
self.assertEqual(len(s), 4) # 3 matmul and 1 attention
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
def test_flash_attention(self):
BS = 4
HEADS = 2
@@ -165,6 +171,7 @@ class TestSoftmaxFusion(unittest.TestCase):
np.testing.assert_allclose(sout.numpy(), out.numpy(), atol=3e-7)
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
def test_auto_softmax(self):
print("*** softmax ***")
with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):