mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
skip test_jit_batch_split if JIT >= 2 (#7561)
* skip test_jit_batch_split if JIT >= 2 only test graphs * 1600
This commit is contained in:
@@ -7,7 +7,7 @@ from test.helpers import assert_jit_cache_len
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import CI, Context
|
||||
from tinygrad.helpers import CI, Context, JIT
|
||||
from tinygrad.dtype import dtypes
|
||||
from extra.models.unet import ResBlock
|
||||
|
||||
@@ -352,7 +352,7 @@ class TestJit(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT=="METAL", "no ICB in CI, creation of graph fails")
|
||||
def test_jit_batch_split(self):
|
||||
if Device[Device.DEFAULT].graph is None: raise unittest.SkipTest("only test graphs")
|
||||
if Device[Device.DEFAULT].graph is None or JIT >= 2: raise unittest.SkipTest("only test graphs")
|
||||
|
||||
# Create long jit with 83 kernels.
|
||||
def f(a, b, c, d, e):
|
||||
|
||||
@@ -1613,7 +1613,7 @@ class TestIndexing(unittest.TestCase):
|
||||
return a.item()
|
||||
r, et = timeit(f, a)
|
||||
self.assertEqual(r, val)
|
||||
self.assertLess(et, 1400)
|
||||
self.assertLess(et, 1600)
|
||||
|
||||
def test_no_rewrite_elementwise(self):
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
|
||||
|
||||
Reference in New Issue
Block a user