fix jitbeam not triggered (#15424)

* um

* beam

* x

* f
This commit is contained in:
nimlgen
2026-03-23 15:34:59 +08:00
committed by GitHub
parent fd3559103b
commit c74fa9bbe1
2 changed files with 18 additions and 7 deletions

View File

@@ -39,6 +39,18 @@ class TestJit(unittest.TestCase):
def add(a, b): return (a+b).realize()
_simple_test(add)
def test_jitbeam_triggers_beam(self):
from unittest.mock import patch
from tinygrad.helpers import getenv as _getenv
@TinyJit
def add(a, b): return (a+b).realize()
a, b = Tensor.ones(10, 10).contiguous().realize(), Tensor.ones(10, 10).contiguous().realize()
with patch("tinygrad.codegen.opt.search.beam_search", wraps=lambda k,*a,**kw: k) as mock_beam:
add(a, b)
assert mock_beam.call_count == 0
with patch("tinygrad.engine.jit.getenv", side_effect=lambda k, d=0: 1 if k == "JITBEAM" else _getenv(k, d)): add(a, b)
assert mock_beam.call_count == 1
def test_simple_jit_reset(self):
@TinyJit
def add(a, b): return (a+b).realize()

View File

@@ -322,12 +322,11 @@ class TinyJit(Generic[ReturnType]):
assert self.fxn is not None
if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
self._linears: list[UOp] = []
with Context(BEAM=getenv("JITBEAM", BEAM.value)):
capturing.append(self)
try:
ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(ret)): Tensor.realize(*params)
finally: capturing.clear()
capturing.append(self)
try:
ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(ret)): Tensor.realize(*params)
finally: capturing.clear()
if not len(self._linears): raise JitError("didn't JIT anything!")
_check_no_non_tensor_return(ret)
if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buffers)} inputs")
@@ -344,7 +343,7 @@ class TinyJit(Generic[ReturnType]):
ei.run(var_vals, jit=True)
del onetime_linear
jit_cache = [ei.lower() for ei in linear_to_schedule(big_linear)]
with Context(BEAM=getenv("JITBEAM", BEAM.value)): jit_cache = [ei.lower() for ei in linear_to_schedule(big_linear)]
del big_linear
# track inputs that are views of buffers