fix: JIT=0 means no JIT (#5188)

This commit is contained in:
chenyu
2024-06-27 10:31:37 -04:00
committed by GitHub
parent 3af17849bf
commit 5b8fda3c65
3 changed files with 18 additions and 5 deletions

View File

@@ -15,12 +15,14 @@ def derandomize_model(model):
p.realize()
def assert_jit_cache_len(fxn, expected_len):
assert len(fxn.jit_cache) > 0
if not fxn.jit_cache:
assert expected_len == 0, expected_len
return
# until we have a better way of typing the prg in ExecItem
if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
assert len(fxn.jit_cache) == expected_len
assert len(fxn.jit_cache) == expected_len, len(fxn.jit_cache)
else:
assert len(fxn.jit_cache) == 1
assert len(fxn.jit_cache) == 1, len(fxn.jit_cache)
# until we have a better way of typing the prg in ExecItem
assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len

View File

@@ -6,7 +6,7 @@ from test.helpers import assert_jit_cache_len
from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad.device import Device
from tinygrad.helpers import CI
from tinygrad.helpers import CI, Context
from tinygrad.dtype import dtypes
def _simple_test(add, extract=lambda x: x, N=10):
@@ -66,6 +66,17 @@ class TestJit(unittest.TestCase):
b = Tensor.randn(10, 10)
add(a, b)
def test_jit_zero_does_not_jit(self):
@TinyJit
def add(a, b): return (a+b).realize()
with Context(JIT=0):
for i in range(5):
a = Tensor([i])
b = Tensor([i])
c = add(a, b)
np.testing.assert_allclose(c.numpy(), 2*i)
assert_jit_cache_len(add, 0)
def test_jit_shape_mismatch(self):
@TinyJit
def add(a, b): return (a+b).realize()

View File

@@ -144,7 +144,7 @@ class TinyJit(Generic[ReturnType]):
var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
[dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
if self.cnt == 0:
if not JIT or self.cnt == 0:
# jit ignore
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
self.ret = self.fxn(*args, **kwargs)