diff --git a/test/test_jit.py b/test/test_jit.py index c53651d5d1..8cb33191b6 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -526,9 +526,8 @@ class TestJitInsideJit(unittest.TestCase): @TinyJit def g(t): return f(t) * 3 - # NOTE: first does not raise - g(Tensor([1])).realize() - with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"): + # nested JIT raises on first call + with self.assertRaises(JitError): g(Tensor([1])).realize() class TestCopyInsideJit(unittest.TestCase): diff --git a/test/test_jit_footguns.py b/test/test_jit_footguns.py index 704653d143..9d1b9b71a3 100644 --- a/test/test_jit_footguns.py +++ b/test/test_jit_footguns.py @@ -16,7 +16,6 @@ ERRORS RAISED (lower priority - at least users know): non_tensor_outputs_error EASY raises JitError if return contains non-Tensor values positional_kwargs_cannot_mix EASY normalize positional args to kwargs using function signature duplicate_inputs_fail MED would need to handle aliasing in input_replace - nested_jit_fails_on_second_call MED could fail on first call instead of second """ import unittest import numpy as np @@ -71,17 +70,6 @@ class TestJitFootguns(unittest.TestCase): a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize() np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i]) - def test_nested_jit_fails_on_second_call(self): - """Nested JIT works on first call but fails on second.""" - @TinyJit - def inner(t): return t + 1 - @TinyJit - def outer(t): return inner(t) * 3 - - self.assertEqual(outer(Tensor([1])).realize().item(), 6) # works! - with self.assertRaises(RuntimeError): - outer(Tensor([2])).realize() # fails - def test_implicit_inputs_need_realize(self): """Closure tensors must be realized before JIT call.""" x = Tensor([0]) diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 1d5d9edb73..fbc321228a 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -284,24 +284,27 @@ class TinyJit(Generic[ReturnType]): def __call__(self, *args, **kwargs) -> ReturnType: input_buffers, var_vals, names, expected_input_info = _prepare_jit_inputs(args, kwargs) - if not JIT or self.cnt == 0: - # jit ignore + def run(fxn): assert self.fxn is not None - with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value): - ret = self.fxn(*args, **kwargs) - if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:]) + ret = fxn(*args, **kwargs) + if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:]) + return ret + if not JIT: ret = run(self.fxn) + elif self.cnt == 0: + # jit warmup + if capturing: raise JitError("having TinyJit inside another TinyJit is not supported") + capturing.append(self) + with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value, CAPTURING=0): + try: ret = run(self.fxn) + finally: capturing.clear() elif self.cnt == 1: # jit capture - assert self.fxn is not None - if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}") self._jit_cache: list[ExecItem] = [] self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary() # TODO: should we always disable the memory planner here? it must be off for prune with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)): capturing.append(self) - try: - ret = self.fxn(*args, **kwargs) - if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:]) + try: ret = run(self.fxn) except Exception as e: raise e finally: capturing.clear() jit_cache = self._jit_cache