better error msg for TinyJit inside TinyJit (#5202)

it's possible to support TinyJit inside TinyJit, but there are edge cases like two TinyJit functions shared another TinyJit function. so just give a more precise error for now
This commit is contained in:
chenyu
2024-06-27 18:09:19 -04:00
committed by GitHub
parent ac748cccdb
commit 73395b998b
2 changed files with 15 additions and 2 deletions

View File

@@ -384,5 +384,16 @@ class TestMultioutputJit(unittest.TestCase):
self._test(fxn)
assert_jit_cache_len(fxn, 2)
class TestJitInsideJit(unittest.TestCase):
def test_jit_jit_error(self):
@TinyJit
def f(t): return t + 1
@TinyJit
def g(t): return f(t) * 3
with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
g(Tensor([1])).realize()
if __name__ == '__main__':
unittest.main()

View File

@@ -3,7 +3,7 @@ from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, O
import functools, itertools, collections
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, ContextVar, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
@@ -106,6 +106,7 @@ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
return list({id(x):x for x in wait_nodes}.values())
ReturnType = TypeVar('ReturnType')
IN_JIT = ContextVar('IN_JIT', 0)
class TinyJit(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType]):
self.fxn = fxn
@@ -145,8 +146,9 @@ class TinyJit(Generic[ReturnType]):
[dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
if not JIT or self.cnt == 0:
if IN_JIT: raise RuntimeError("having TinyJit inside another TinyJit is not supported")
# jit ignore
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value, IN_JIT=1):
self.ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
elif self.cnt == 1: