skip test_jit_batch_split if JIT >= 2 (#7561)

* skip test_jit_batch_split if JIT >= 2

only test graphs

* 1600
This commit is contained in:
chenyu
2024-11-05 14:59:04 -05:00
committed by GitHub
parent f2fa183651
commit c805e3fff5
2 changed files with 3 additions and 3 deletions

View File

@@ -7,7 +7,7 @@ from test.helpers import assert_jit_cache_len
from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad.device import Device
from tinygrad.helpers import CI, Context
from tinygrad.helpers import CI, Context, JIT
from tinygrad.dtype import dtypes
from extra.models.unet import ResBlock
@@ -352,7 +352,7 @@ class TestJit(unittest.TestCase):
@unittest.skipIf(CI and Device.DEFAULT=="METAL", "no ICB in CI, creation of graph fails")
def test_jit_batch_split(self):
if Device[Device.DEFAULT].graph is None: raise unittest.SkipTest("only test graphs")
if Device[Device.DEFAULT].graph is None or JIT >= 2: raise unittest.SkipTest("only test graphs")
# Create long jit with 83 kernels.
def f(a, b, c, d, e):

View File

@@ -1613,7 +1613,7 @@ class TestIndexing(unittest.TestCase):
return a.item()
r, et = timeit(f, a)
self.assertEqual(r, val)
self.assertLess(et, 1400)
self.assertLess(et, 1600)
def test_no_rewrite_elementwise(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]