mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
make jit in jit fail on first call
also clean up structure
This commit is contained in:
@@ -526,9 +526,8 @@ class TestJitInsideJit(unittest.TestCase):
|
||||
@TinyJit
|
||||
def g(t): return f(t) * 3
|
||||
|
||||
# NOTE: first does not raise
|
||||
g(Tensor([1])).realize()
|
||||
with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
|
||||
# nested JIT raises on first call
|
||||
with self.assertRaises(JitError):
|
||||
g(Tensor([1])).realize()
|
||||
|
||||
class TestCopyInsideJit(unittest.TestCase):
|
||||
|
||||
@@ -16,7 +16,6 @@ ERRORS RAISED (lower priority - at least users know):
|
||||
non_tensor_outputs_error EASY raises JitError if return contains non-Tensor values
|
||||
positional_kwargs_cannot_mix EASY normalize positional args to kwargs using function signature
|
||||
duplicate_inputs_fail MED would need to handle aliasing in input_replace
|
||||
nested_jit_fails_on_second_call MED could fail on first call instead of second
|
||||
"""
|
||||
import unittest
|
||||
import numpy as np
|
||||
@@ -71,17 +70,6 @@ class TestJitFootguns(unittest.TestCase):
|
||||
a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize()
|
||||
np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i])
|
||||
|
||||
def test_nested_jit_fails_on_second_call(self):
|
||||
"""Nested JIT works on first call but fails on second."""
|
||||
@TinyJit
|
||||
def inner(t): return t + 1
|
||||
@TinyJit
|
||||
def outer(t): return inner(t) * 3
|
||||
|
||||
self.assertEqual(outer(Tensor([1])).realize().item(), 6) # works!
|
||||
with self.assertRaises(RuntimeError):
|
||||
outer(Tensor([2])).realize() # fails
|
||||
|
||||
def test_implicit_inputs_need_realize(self):
|
||||
"""Closure tensors must be realized before JIT call."""
|
||||
x = Tensor([0])
|
||||
|
||||
@@ -284,24 +284,27 @@ class TinyJit(Generic[ReturnType]):
|
||||
|
||||
def __call__(self, *args, **kwargs) -> ReturnType:
|
||||
input_buffers, var_vals, names, expected_input_info = _prepare_jit_inputs(args, kwargs)
|
||||
if not JIT or self.cnt == 0:
|
||||
# jit ignore
|
||||
def run(fxn):
|
||||
assert self.fxn is not None
|
||||
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
|
||||
ret = fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
|
||||
return ret
|
||||
if not JIT: ret = run(self.fxn)
|
||||
elif self.cnt == 0:
|
||||
# jit warmup
|
||||
if capturing: raise JitError("having TinyJit inside another TinyJit is not supported")
|
||||
capturing.append(self)
|
||||
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value, CAPTURING=0):
|
||||
try: ret = run(self.fxn)
|
||||
finally: capturing.clear()
|
||||
elif self.cnt == 1:
|
||||
# jit capture
|
||||
assert self.fxn is not None
|
||||
if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
|
||||
self._jit_cache: list[ExecItem] = []
|
||||
self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
|
||||
# TODO: should we always disable the memory planner here? it must be off for prune
|
||||
with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
|
||||
capturing.append(self)
|
||||
try:
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
|
||||
try: ret = run(self.fxn)
|
||||
except Exception as e: raise e
|
||||
finally: capturing.clear()
|
||||
jit_cache = self._jit_cache
|
||||
|
||||
Reference in New Issue
Block a user