diff --git a/test/helpers.py b/test/helpers.py index 832f530ff8..542e0ea80c 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -15,12 +15,14 @@ def derandomize_model(model): p.realize() def assert_jit_cache_len(fxn, expected_len): - assert len(fxn.jit_cache) > 0 + if not fxn.jit_cache: + assert expected_len == 0, expected_len + return # until we have a better way of typing the prg in ExecItem if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'): - assert len(fxn.jit_cache) == expected_len + assert len(fxn.jit_cache) == expected_len, len(fxn.jit_cache) else: - assert len(fxn.jit_cache) == 1 + assert len(fxn.jit_cache) == 1, len(fxn.jit_cache) # until we have a better way of typing the prg in ExecItem assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph') assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len diff --git a/test/test_jit.py b/test/test_jit.py index f23c3d9a80..a2bef07ad3 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -6,7 +6,7 @@ from test.helpers import assert_jit_cache_len from tinygrad.tensor import Tensor from tinygrad.engine.jit import TinyJit from tinygrad.device import Device -from tinygrad.helpers import CI +from tinygrad.helpers import CI, Context from tinygrad.dtype import dtypes def _simple_test(add, extract=lambda x: x, N=10): @@ -66,6 +66,17 @@ class TestJit(unittest.TestCase): b = Tensor.randn(10, 10) add(a, b) + def test_jit_zero_does_not_jit(self): + @TinyJit + def add(a, b): return (a+b).realize() + with Context(JIT=0): + for i in range(5): + a = Tensor([i]) + b = Tensor([i]) + c = add(a, b) + np.testing.assert_allclose(c.numpy(), 2*i) + assert_jit_cache_len(add, 0) + def test_jit_shape_mismatch(self): @TinyJit def add(a, b): return (a+b).realize() diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 9972516517..ce6b68302d 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -144,7 +144,7 @@ class TinyJit(Generic[ReturnType]): var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \ [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))]) st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device] - if self.cnt == 0: + if not JIT or self.cnt == 0: # jit ignore with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value): self.ret = self.fxn(*args, **kwargs)