diff --git a/test/backend/test_jit.py b/test/backend/test_jit.py index 5f96286c03..d2fc3d16bb 100644 --- a/test/backend/test_jit.py +++ b/test/backend/test_jit.py @@ -39,6 +39,18 @@ class TestJit(unittest.TestCase): def add(a, b): return (a+b).realize() _simple_test(add) + def test_jitbeam_triggers_beam(self): + from unittest.mock import patch + from tinygrad.helpers import getenv as _getenv + @TinyJit + def add(a, b): return (a+b).realize() + a, b = Tensor.ones(10, 10).contiguous().realize(), Tensor.ones(10, 10).contiguous().realize() + with patch("tinygrad.codegen.opt.search.beam_search", wraps=lambda k,*a,**kw: k) as mock_beam: + add(a, b) + assert mock_beam.call_count == 0 + with patch("tinygrad.engine.jit.getenv", side_effect=lambda k, d=0: 1 if k == "JITBEAM" else _getenv(k, d)): add(a, b) + assert mock_beam.call_count == 1 + def test_simple_jit_reset(self): @TinyJit def add(a, b): return (a+b).realize() diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 7d10a9d2e1..6f4d0cdfac 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -322,12 +322,11 @@ class TinyJit(Generic[ReturnType]): assert self.fxn is not None if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}") self._linears: list[UOp] = [] - with Context(BEAM=getenv("JITBEAM", BEAM.value)): - capturing.append(self) - try: - ret = self.fxn(*args, **kwargs) - if len(params:=get_parameters(ret)): Tensor.realize(*params) - finally: capturing.clear() + capturing.append(self) + try: + ret = self.fxn(*args, **kwargs) + if len(params:=get_parameters(ret)): Tensor.realize(*params) + finally: capturing.clear() if not len(self._linears): raise JitError("didn't JIT anything!") _check_no_non_tensor_return(ret) if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buffers)} inputs") @@ -344,7 +343,7 @@ class TinyJit(Generic[ReturnType]): ei.run(var_vals, jit=True) del onetime_linear - jit_cache = [ei.lower() for ei in linear_to_schedule(big_linear)] + with Context(BEAM=getenv("JITBEAM", BEAM.value)): jit_cache = [ei.lower() for ei in linear_to_schedule(big_linear)] del big_linear # track inputs that are views of buffers