mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
skip test_softmax_fusion tests if RANGEIFY==1 (#12310)
This commit is contained in:
@@ -2,7 +2,7 @@ import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, GlobalCounters, Context, Device
|
||||
from tinygrad.dtype import DTypeLike, dtypes
|
||||
from tinygrad.helpers import DEBUG, get_single_element
|
||||
from tinygrad.helpers import DEBUG, get_single_element, RANGEIFY
|
||||
from tinygrad.engine.realize import lower_schedule_item
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
@@ -39,6 +39,7 @@ class TestFuse(unittest.TestCase):
|
||||
np_multi = fxn(*args, **kwargs).numpy()
|
||||
np.testing.assert_allclose(np_single, np_multi, atol=atol)
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
def test_fuse_norm(self):
|
||||
a = Tensor.rand(50,50).realize()
|
||||
self._test_fuse(lambda a: a / a.mean(axis=1), a)
|
||||
@@ -47,6 +48,7 @@ class TestFuse(unittest.TestCase):
|
||||
a = Tensor.rand(50,50).realize()
|
||||
self._test_fuse(lambda a: a.argmax(axis=-1), a)
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
def test_fuse_softmax(self):
|
||||
a = Tensor.rand(50,50).realize()
|
||||
self._test_fuse(lambda a: a.softmax(axis=-1), a)
|
||||
@@ -57,6 +59,7 @@ class TestFuse(unittest.TestCase):
|
||||
self._test_fuse(lambda a,b: ((a@b).relu()+a).contiguous().softmax(axis=-1), a,b, allow_multiple=True)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
def test_fuse_softmax_dtype(self):
|
||||
a = Tensor.rand(50,50).realize()
|
||||
self._test_fuse(lambda a: a.softmax(axis=-1, dtype='half'), a, atol=3e-4)
|
||||
@@ -64,6 +67,7 @@ class TestFuse(unittest.TestCase):
|
||||
def test_fuse_arange_eye(self):
|
||||
self._test_fuse(lambda: Tensor.arange(10).reshape(10,1).expand(10,10) == Tensor.arange(10).reshape(1,10).expand(10,10))
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
def test_double_gemm(self):
|
||||
N = 32
|
||||
with Context(TRACK_MATCH_STATS=0, DEBUG=0):
|
||||
@@ -86,6 +90,7 @@ class TestFuse(unittest.TestCase):
|
||||
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
|
||||
self._test_fuse(embedding, a, atol=1e-5)
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
def test_attention_kernel_count(self):
|
||||
wq = Tensor.empty(32, 32)
|
||||
wk = Tensor.empty(32, 32)
|
||||
@@ -98,6 +103,7 @@ class TestFuse(unittest.TestCase):
|
||||
s = attn.schedule()
|
||||
self.assertEqual(len(s), 4) # 3 matmul and 1 attention
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
def test_flash_attention(self):
|
||||
BS = 4
|
||||
HEADS = 2
|
||||
@@ -165,6 +171,7 @@ class TestSoftmaxFusion(unittest.TestCase):
|
||||
|
||||
np.testing.assert_allclose(sout.numpy(), out.numpy(), atol=3e-7)
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
def test_auto_softmax(self):
|
||||
print("*** softmax ***")
|
||||
with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
|
||||
|
||||
Reference in New Issue
Block a user