mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
raise when jit fxn returns non-Tensor output (#14042)
This commit is contained in:
@@ -230,20 +230,9 @@ class TestJit(unittest.TestCase):
|
||||
def test_jit_output_non_tensor_fail(self):
|
||||
@TinyJit
|
||||
def f(a, b, i): return (a+b).realize(), i
|
||||
output1, output2 = [], []
|
||||
expect1, expect2 = [], []
|
||||
for i in range(5):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
o1, o2 = f(a, b, i)
|
||||
output1.append(o1.numpy().copy())
|
||||
output2.append(o2)
|
||||
expect1.append(a.numpy().copy()+b.numpy().copy())
|
||||
expect2.append(i)
|
||||
np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5)
|
||||
# the jit only works with Tensor outputs
|
||||
assert output2 != expect2
|
||||
assert_jit_cache_len(f, 1)
|
||||
with self.assertRaises(JitError):
|
||||
for i in range(3):
|
||||
f(Tensor.randn(10, 10), Tensor.randn(10, 10), i)
|
||||
|
||||
def test_jit_random_regen(self):
|
||||
def f(a, b):
|
||||
|
||||
@@ -7,13 +7,13 @@ Comments marked "should be X!" indicate the intuitively expected value.
|
||||
|
||||
SILENT MISMATCHES (highest priority - wrong results, no error):
|
||||
tensors_in_containers_ignored EASY only checks t.__class__ is Tensor, could scan lists/dicts
|
||||
non_tensor_outputs_frozen EASY could warn/error if return contains non-Tensor values
|
||||
class_method_shared_across_instances EASY could check if first arg is self and warn
|
||||
output_buffer_reuse MED performance tradeoff, could add option or better docs
|
||||
python_constants_frozen HARD inherent to tracing JITs
|
||||
conditional_branches_frozen HARD inherent to tracing JITs
|
||||
|
||||
ERRORS RAISED (lower priority - at least users know):
|
||||
non_tensor_outputs_error EASY raises JitError if return contains non-Tensor values
|
||||
positional_kwargs_cannot_mix EASY normalize positional args to kwargs using function signature
|
||||
duplicate_inputs_fail MED would need to handle aliasing in input_replace
|
||||
nested_jit_fails_on_second_call MED could fail on first call instead of second
|
||||
@@ -49,21 +49,11 @@ class TestJitFootguns(unittest.TestCase):
|
||||
|
||||
self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6])
|
||||
|
||||
def test_non_tensor_outputs_frozen(self):
|
||||
"""Non-tensor return values are frozen at capture time."""
|
||||
def test_non_tensor_outputs_error(self):
|
||||
@TinyJit
|
||||
def f(x, mult): return (x * 2).realize(), mult * 10
|
||||
|
||||
# collect results, copying tensor values immediately (buffer reuse!)
|
||||
results = []
|
||||
for i in range(5):
|
||||
t, s = f(Tensor([i]), i)
|
||||
results.append((t.item(), s))
|
||||
|
||||
# tensor outputs work correctly
|
||||
self.assertEqual([r[0] for r in results[2:]], [4, 6, 8])
|
||||
# scalar outputs frozen at capture (i=1) - should be 20, 30, 40!
|
||||
self.assertEqual([r[1] for r in results[2:]], [10, 10, 10])
|
||||
with self.assertRaises(JitError):
|
||||
for i in range(3): f(Tensor([i]), i)
|
||||
|
||||
def test_duplicate_inputs_fail(self):
|
||||
"""JIT cannot handle the same tensor passed as multiple arguments."""
|
||||
|
||||
@@ -15,6 +15,13 @@ from weakref import WeakKeyDictionary
|
||||
class GraphException(Exception): pass
|
||||
class JitError(Exception): pass
|
||||
|
||||
def _check_no_non_tensor_return(ret):
|
||||
if ret is None or isinstance(ret, Tensor): return
|
||||
if isinstance(ret, (tuple, list, dict)):
|
||||
for item in (ret.values() if isinstance(ret, dict) else ret): _check_no_non_tensor_return(item)
|
||||
return
|
||||
raise JitError(f"JIT return contains non-Tensor value of type {type(ret).__name__}")
|
||||
|
||||
def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
|
||||
|
||||
def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]:
|
||||
@@ -296,6 +303,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
jit_cache = self._jit_cache
|
||||
del self._buffer_replace, self._jit_cache
|
||||
if not len(jit_cache): raise JitError("didn't JIT anything!")
|
||||
_check_no_non_tensor_return(ret)
|
||||
if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")
|
||||
|
||||
# track inputs that are views of buffers
|
||||
|
||||
Reference in New Issue
Block a user