mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix: JIT=0 means no JIT (#5188)
This commit is contained in:
@@ -15,12 +15,14 @@ def derandomize_model(model):
|
||||
p.realize()
|
||||
|
||||
def assert_jit_cache_len(fxn, expected_len):
|
||||
assert len(fxn.jit_cache) > 0
|
||||
if not fxn.jit_cache:
|
||||
assert expected_len == 0, expected_len
|
||||
return
|
||||
# until we have a better way of typing the prg in ExecItem
|
||||
if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
|
||||
assert len(fxn.jit_cache) == expected_len
|
||||
assert len(fxn.jit_cache) == expected_len, len(fxn.jit_cache)
|
||||
else:
|
||||
assert len(fxn.jit_cache) == 1
|
||||
assert len(fxn.jit_cache) == 1, len(fxn.jit_cache)
|
||||
# until we have a better way of typing the prg in ExecItem
|
||||
assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
|
||||
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
|
||||
|
||||
@@ -6,7 +6,7 @@ from test.helpers import assert_jit_cache_len
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.helpers import CI, Context
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
def _simple_test(add, extract=lambda x: x, N=10):
|
||||
@@ -66,6 +66,17 @@ class TestJit(unittest.TestCase):
|
||||
b = Tensor.randn(10, 10)
|
||||
add(a, b)
|
||||
|
||||
def test_jit_zero_does_not_jit(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
with Context(JIT=0):
|
||||
for i in range(5):
|
||||
a = Tensor([i])
|
||||
b = Tensor([i])
|
||||
c = add(a, b)
|
||||
np.testing.assert_allclose(c.numpy(), 2*i)
|
||||
assert_jit_cache_len(add, 0)
|
||||
|
||||
def test_jit_shape_mismatch(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
|
||||
@@ -144,7 +144,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
|
||||
[dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
|
||||
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
|
||||
if self.cnt == 0:
|
||||
if not JIT or self.cnt == 0:
|
||||
# jit ignore
|
||||
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
|
||||
self.ret = self.fxn(*args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user