mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -39,6 +39,18 @@ class TestJit(unittest.TestCase):
|
||||
def add(a, b): return (a+b).realize()
|
||||
_simple_test(add)
|
||||
|
||||
def test_jitbeam_triggers_beam(self):
|
||||
from unittest.mock import patch
|
||||
from tinygrad.helpers import getenv as _getenv
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
a, b = Tensor.ones(10, 10).contiguous().realize(), Tensor.ones(10, 10).contiguous().realize()
|
||||
with patch("tinygrad.codegen.opt.search.beam_search", wraps=lambda k,*a,**kw: k) as mock_beam:
|
||||
add(a, b)
|
||||
assert mock_beam.call_count == 0
|
||||
with patch("tinygrad.engine.jit.getenv", side_effect=lambda k, d=0: 1 if k == "JITBEAM" else _getenv(k, d)): add(a, b)
|
||||
assert mock_beam.call_count == 1
|
||||
|
||||
def test_simple_jit_reset(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
|
||||
@@ -322,12 +322,11 @@ class TinyJit(Generic[ReturnType]):
|
||||
assert self.fxn is not None
|
||||
if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
|
||||
self._linears: list[UOp] = []
|
||||
with Context(BEAM=getenv("JITBEAM", BEAM.value)):
|
||||
capturing.append(self)
|
||||
try:
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(ret)): Tensor.realize(*params)
|
||||
finally: capturing.clear()
|
||||
capturing.append(self)
|
||||
try:
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(ret)): Tensor.realize(*params)
|
||||
finally: capturing.clear()
|
||||
if not len(self._linears): raise JitError("didn't JIT anything!")
|
||||
_check_no_non_tensor_return(ret)
|
||||
if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buffers)} inputs")
|
||||
@@ -344,7 +343,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
ei.run(var_vals, jit=True)
|
||||
del onetime_linear
|
||||
|
||||
jit_cache = [ei.lower() for ei in linear_to_schedule(big_linear)]
|
||||
with Context(BEAM=getenv("JITBEAM", BEAM.value)): jit_cache = [ei.lower() for ei in linear_to_schedule(big_linear)]
|
||||
del big_linear
|
||||
|
||||
# track inputs that are views of buffers
|
||||
|
||||
Reference in New Issue
Block a user