raise when jit fxn returns non-Tensor output (#14042)

chenyu
2026-01-06 12:59:20 -05:00
committed by George Hotz
parent aa96d826f4
commit caa52dcbe5
3 changed files with 15 additions and 28 deletions

@@ -230,20 +230,9 @@ class TestJit(unittest.TestCase):
   def test_jit_output_non_tensor_fail(self):
     @TinyJit
     def f(a, b, i): return (a+b).realize(), i
-    output1, output2 = [], []
-    expect1, expect2 = [], []
-    for i in range(5):
-      a = Tensor.randn(10, 10)
-      b = Tensor.randn(10, 10)
-      o1, o2 = f(a, b, i)
-      output1.append(o1.numpy().copy())
-      output2.append(o2)
-      expect1.append(a.numpy().copy()+b.numpy().copy())
-      expect2.append(i)
-    np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5)
-    # the jit only works with Tensor outputs
-    assert output2 != expect2
-    assert_jit_cache_len(f, 1)
+    with self.assertRaises(JitError):
+      for i in range(3):
+        f(Tensor.randn(10, 10), Tensor.randn(10, 10), i)

   def test_jit_random_regen(self):
     def f(a, b):

@@ -7,13 +7,13 @@ Comments marked "should be X!" indicate the intuitively expected value.
 SILENT MISMATCHES (highest priority - wrong results, no error):
   tensors_in_containers_ignored          EASY  only checks t.__class__ is Tensor, could scan lists/dicts
-  non_tensor_outputs_frozen              EASY  could warn/error if return contains non-Tensor values
   class_method_shared_across_instances   EASY  could check if first arg is self and warn
   output_buffer_reuse                    MED   performance tradeoff, could add option or better docs
   python_constants_frozen                HARD  inherent to tracing JITs
   conditional_branches_frozen            HARD  inherent to tracing JITs

 ERRORS RAISED (lower priority - at least users know):
+  non_tensor_outputs_error               EASY  raises JitError if return contains non-Tensor values
   positional_kwargs_cannot_mix           EASY  normalize positional args to kwargs using function signature
   duplicate_inputs_fail                  MED   would need to handle aliasing in input_replace
   nested_jit_fails_on_second_call        MED   could fail on first call instead of second
@@ -49,21 +49,11 @@ class TestJitFootguns(unittest.TestCase):
     self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6])

-  def test_non_tensor_outputs_frozen(self):
-    """Non-tensor return values are frozen at capture time."""
+  def test_non_tensor_outputs_error(self):
     @TinyJit
     def f(x, mult): return (x * 2).realize(), mult * 10
-    # collect results, copying tensor values immediately (buffer reuse!)
-    results = []
-    for i in range(5):
-      t, s = f(Tensor([i]), i)
-      results.append((t.item(), s))
-    # tensor outputs work correctly
-    self.assertEqual([r[0] for r in results[2:]], [4, 6, 8])
-    # scalar outputs frozen at capture (i=1) - should be 20, 30, 40!
-    self.assertEqual([r[1] for r in results[2:]], [10, 10, 10])
+    with self.assertRaises(JitError):
+      for i in range(3): f(Tensor([i]), i)

   def test_duplicate_inputs_fail(self):
     """JIT cannot handle the same tensor passed as multiple arguments."""

@@ -15,6 +15,13 @@ from weakref import WeakKeyDictionary
 class GraphException(Exception): pass
 class JitError(Exception): pass

+def _check_no_non_tensor_return(ret):
+  if ret is None or isinstance(ret, Tensor): return
+  if isinstance(ret, (tuple, list, dict)):
+    for item in (ret.values() if isinstance(ret, dict) else ret): _check_no_non_tensor_return(item)
+    return
+  raise JitError(f"JIT return contains non-Tensor value of type {type(ret).__name__}")
+
 def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph

 def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]:
@@ -296,6 +303,7 @@ class TinyJit(Generic[ReturnType]):
       jit_cache = self._jit_cache
       del self._buffer_replace, self._jit_cache
       if not len(jit_cache): raise JitError("didn't JIT anything!")
+      _check_no_non_tensor_return(ret)
       if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")
       # track inputs that are views of buffers
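
Taken together, the capture path now validates the traced function's return value. A sketch of the user-visible behavior, assuming JitError is importable from tinygrad.engine.jit (the module patched above) and that capture still completes on the second call; the function name `bad` is illustrative:

from tinygrad import Tensor, TinyJit
from tinygrad.engine.jit import JitError

@TinyJit
def bad(x: Tensor):
  # the second element is a plain int, which the JIT cannot track across calls
  return (x + 1).realize(), 42

try:
  for i in range(3): bad(Tensor.randn(4))   # the check fires once capture finishes
except JitError as e:
  print(e)   # "JIT return contains non-Tensor value of type int"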