make jit in jit fail on first call

also clean up structure
This commit is contained in:
Chen-Yu Yang
2026-01-06 23:34:17 -06:00
parent 87f4bc5446
commit b6ddff67a8
3 changed files with 15 additions and 25 deletions

View File

@@ -526,9 +526,8 @@ class TestJitInsideJit(unittest.TestCase):
@TinyJit
def g(t): return f(t) * 3
# NOTE: first does not raise
g(Tensor([1])).realize()
with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
# nested JIT raises on first call
with self.assertRaises(JitError):
g(Tensor([1])).realize()
class TestCopyInsideJit(unittest.TestCase):

View File

@@ -16,7 +16,6 @@ ERRORS RAISED (lower priority - at least users know):
non_tensor_outputs_error EASY raises JitError if return contains non-Tensor values
positional_kwargs_cannot_mix EASY normalize positional args to kwargs using function signature
duplicate_inputs_fail MED would need to handle aliasing in input_replace
nested_jit_fails_on_second_call MED could fail on first call instead of second
"""
import unittest
import numpy as np
@@ -71,17 +70,6 @@ class TestJitFootguns(unittest.TestCase):
a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize()
np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i])
def test_nested_jit_fails_on_second_call(self):
"""Nested JIT works on first call but fails on second."""
@TinyJit
def inner(t): return t + 1
@TinyJit
def outer(t): return inner(t) * 3
self.assertEqual(outer(Tensor([1])).realize().item(), 6) # works!
with self.assertRaises(RuntimeError):
outer(Tensor([2])).realize() # fails
def test_implicit_inputs_need_realize(self):
"""Closure tensors must be realized before JIT call."""
x = Tensor([0])

View File

@@ -284,24 +284,27 @@ class TinyJit(Generic[ReturnType]):
def __call__(self, *args, **kwargs) -> ReturnType:
input_buffers, var_vals, names, expected_input_info = _prepare_jit_inputs(args, kwargs)
if not JIT or self.cnt == 0:
# jit ignore
def run(fxn):
assert self.fxn is not None
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
ret = fxn(*args, **kwargs)
if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
return ret
if not JIT: ret = run(self.fxn)
elif self.cnt == 0:
# jit warmup
if capturing: raise JitError("having TinyJit inside another TinyJit is not supported")
capturing.append(self)
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value, CAPTURING=0):
try: ret = run(self.fxn)
finally: capturing.clear()
elif self.cnt == 1:
# jit capture
assert self.fxn is not None
if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
self._jit_cache: list[ExecItem] = []
self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
# TODO: should we always disable the memory planner here? it must be off for prune
with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
capturing.append(self)
try:
ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
try: ret = run(self.fxn)
except Exception as e: raise e
finally: capturing.clear()
jit_cache = self._jit_cache