mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
better error msg for TinyJit inside TinyJit (#5202)
it's possible to support TinyJit inside TinyJit, but there are edge cases like two TinyJit functions shared another TinyJit function. so just give a more precise error for now
This commit is contained in:
@@ -384,5 +384,16 @@ class TestMultioutputJit(unittest.TestCase):
|
||||
self._test(fxn)
|
||||
assert_jit_cache_len(fxn, 2)
|
||||
|
||||
class TestJitInsideJit(unittest.TestCase):
|
||||
def test_jit_jit_error(self):
|
||||
@TinyJit
|
||||
def f(t): return t + 1
|
||||
|
||||
@TinyJit
|
||||
def g(t): return f(t) * 3
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
|
||||
g(Tensor([1])).realize()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, O
|
||||
import functools, itertools, collections
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, ContextVar, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
|
||||
from tinygrad.device import Buffer, Compiled, Device
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -106,6 +106,7 @@ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
|
||||
return list({id(x):x for x in wait_nodes}.values())
|
||||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
IN_JIT = ContextVar('IN_JIT', 0)
|
||||
class TinyJit(Generic[ReturnType]):
|
||||
def __init__(self, fxn:Callable[..., ReturnType]):
|
||||
self.fxn = fxn
|
||||
@@ -145,8 +146,9 @@ class TinyJit(Generic[ReturnType]):
|
||||
[dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
|
||||
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
|
||||
if not JIT or self.cnt == 0:
|
||||
if IN_JIT: raise RuntimeError("having TinyJit inside another TinyJit is not supported")
|
||||
# jit ignore
|
||||
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
|
||||
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value, IN_JIT=1):
|
||||
self.ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
|
||||
elif self.cnt == 1:
|
||||
|
||||
Reference in New Issue
Block a user