From 7fb18f7e471e1c3458c94cb9dda048630b684488 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Tue, 6 Jan 2026 12:59:20 -0500
Subject: [PATCH] raise when jit fxn returns non-Tensor output (#14042)

---
 test/test_jit.py          | 17 +++--------------
 test/test_jit_footguns.py | 18 ++++--------------
 tinygrad/engine/jit.py    |  8 ++++++++
 3 files changed, 15 insertions(+), 28 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index 816117f69e..265cee5fdd 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -230,20 +230,9 @@ class TestJit(unittest.TestCase):
   def test_jit_output_non_tensor_fail(self):
     @TinyJit
     def f(a, b, i): return (a+b).realize(), i
-    output1, output2 = [], []
-    expect1, expect2 = [], []
-    for i in range(5):
-      a = Tensor.randn(10, 10)
-      b = Tensor.randn(10, 10)
-      o1, o2 = f(a, b, i)
-      output1.append(o1.numpy().copy())
-      output2.append(o2)
-      expect1.append(a.numpy().copy()+b.numpy().copy())
-      expect2.append(i)
-    np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5)
-    # the jit only works with Tensor outputs
-    assert output2 != expect2
-    assert_jit_cache_len(f, 1)
+    with self.assertRaises(JitError):
+      for i in range(3):
+        f(Tensor.randn(10, 10), Tensor.randn(10, 10), i)
 
   def test_jit_random_regen(self):
     def f(a, b):
diff --git a/test/test_jit_footguns.py b/test/test_jit_footguns.py
index 8a8506b827..9b50b03a11 100644
--- a/test/test_jit_footguns.py
+++ b/test/test_jit_footguns.py
@@ -7,13 +7,13 @@ Comments marked "should be X!" indicate the intuitively expected value.
 
 SILENT MISMATCHES (highest priority - wrong results, no error):
   tensors_in_containers_ignored         EASY  only checks t.__class__ is Tensor, could scan lists/dicts
-  non_tensor_outputs_frozen             EASY  could warn/error if return contains non-Tensor values
   class_method_shared_across_instances  EASY  could check if first arg is self and warn
   output_buffer_reuse                   MED   performance tradeoff, could add option or better docs
   python_constants_frozen               HARD  inherent to tracing JITs
   conditional_branches_frozen           HARD  inherent to tracing JITs
 
 ERRORS RAISED (lower priority - at least users know):
+  non_tensor_outputs_error              EASY  raises JitError if return contains non-Tensor values
   positional_kwargs_cannot_mix          EASY  normalize positional args to kwargs using function signature
   duplicate_inputs_fail                 MED   would need to handle aliasing in input_replace
   nested_jit_fails_on_second_call       MED   could fail on first call instead of second
@@ -49,21 +49,11 @@ class TestJitFootguns(unittest.TestCase):
 
     self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6])
 
-  def test_non_tensor_outputs_frozen(self):
-    """Non-tensor return values are frozen at capture time."""
+  def test_non_tensor_outputs_error(self):
     @TinyJit
     def f(x, mult): return (x * 2).realize(), mult * 10
-
-    # collect results, copying tensor values immediately (buffer reuse!)
-    results = []
-    for i in range(5):
-      t, s = f(Tensor([i]), i)
-      results.append((t.item(), s))
-
-    # tensor outputs work correctly
-    self.assertEqual([r[0] for r in results[2:]], [4, 6, 8])
-    # scalar outputs frozen at capture (i=1) - should be 20, 30, 40!
-    self.assertEqual([r[1] for r in results[2:]], [10, 10, 10])
+    with self.assertRaises(JitError):
+      for i in range(3): f(Tensor([i]), i)
 
   def test_duplicate_inputs_fail(self):
     """JIT cannot handle the same tensor passed as multiple arguments."""
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 46a5517eec..8c2363903e 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -15,6 +15,13 @@ from weakref import WeakKeyDictionary
 class GraphException(Exception): pass
 class JitError(Exception): pass
 
+def _check_no_non_tensor_return(ret):
+  if ret is None or isinstance(ret, Tensor): return
+  if isinstance(ret, (tuple, list, dict)):
+    for item in (ret.values() if isinstance(ret, dict) else ret): _check_no_non_tensor_return(item)
+    return
+  raise JitError(f"JIT return contains non-Tensor value of type {type(ret).__name__}")
+
 def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
 
 def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]:
@@ -296,6 +303,7 @@ class TinyJit(Generic[ReturnType]):
       jit_cache = self._jit_cache
       del self._buffer_replace, self._jit_cache
       if not len(jit_cache): raise JitError("didn't JIT anything!")
+      _check_no_non_tensor_return(ret)
      if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")
 
       # track inputs that are views of buffers
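
Note: the sketch below is not part of the patch; it illustrates the caller-facing behavior the change introduces, assuming Tensor and TinyJit are importable from the top-level tinygrad package and JitError from tinygrad.engine.jit, as in the diff above. A jitted function may return Tensors (or tuples/lists/dicts of Tensors); returning anything else now raises JitError when the JIT captures the call, instead of silently freezing the non-Tensor value at capture time.

# Sketch only: demonstrates the behavior after this patch is applied.
from tinygrad import Tensor, TinyJit
from tinygrad.engine.jit import JitError

@TinyJit
def ok(x: Tensor) -> Tensor:
  return (x * 2).realize()             # Tensor-only return: still fine

@TinyJit
def bad(x: Tensor, mult: int):
  return (x * 2).realize(), mult * 10  # second element is an int, not a Tensor

for i in range(3): ok(Tensor([i]))     # captures and replays normally

try:
  for i in range(3): bad(Tensor([i]), i)
except JitError as e:
  print("caught:", e)                  # raised once the JIT checks the captured return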