From 72a3f78d1976708d24c0535824d5d319b6f9a8cc Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 6 Jan 2026 19:42:06 -0500 Subject: [PATCH] jit includes tensor inputs in containers (#14043) * jit includes tensor inputs in containers * cleanup --- test/test_jit.py | 13 +++---------- test/test_jit_footguns.py | 13 ++----------- tinygrad/engine/jit.py | 4 ++++ 3 files changed, 9 insertions(+), 21 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 7720d36dfc..c53651d5d1 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -184,16 +184,9 @@ class TestJit(unittest.TestCase): def test_array_jit(self): @TinyJit def add_array(a, arr): return (a+arr[0]).realize() - for i in range(5): - a = Tensor.randn(10, 10) - b = Tensor.randn(10, 10) - a.realize(), b.realize() - c = add_array(a, [b]) - if i >= 2: - # should fail once jitted since jit can't handle arrays - np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5) - else: - np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5) + for _ in range(5): + a, b = Tensor.randn(10, 10).realize(), Tensor.randn(10, 10).realize() + np.testing.assert_allclose(add_array(a, [b]).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5) assert_jit_cache_len(add_array, 1) def test_jit_copyin(self): diff --git a/test/test_jit_footguns.py b/test/test_jit_footguns.py index f54838e54a..704653d143 100644 --- a/test/test_jit_footguns.py +++ b/test/test_jit_footguns.py @@ -6,7 +6,6 @@ Each test shows behavior that works without JIT but changes with JIT. Comments marked "should be X!" indicate the intuitively expected value. SILENT MISMATCHES (highest priority - wrong results, no error): - tensors_in_containers_ignored EASY only checks t.__class__ is Tensor, could scan lists/dicts class_method_shared_across_instances EASY could check if first arg is self and warn output_buffer_reuse MED performance tradeoff, could add option or better docs python_constants_frozen HARD inherent to tracing JITs @@ -65,20 +64,12 @@ class TestJitFootguns(unittest.TestCase): with self.assertRaises(JitError): f(x, x) - def test_tensors_in_containers_ignored(self): - """Tensors inside lists/dicts are not tracked as inputs.""" + def test_tensors_in_containers(self): @TinyJit def f(a, arr): return (a + arr[0]).realize() - - results = [] for i in range(4): a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize() - results.append(f(a, [b]).numpy().copy()) - - np.testing.assert_array_equal(results[0], [1, 1, 1]) # warmup - np.testing.assert_array_equal(results[1], [2, 2, 2]) # capture - np.testing.assert_array_equal(results[2], [2, 2, 2]) # should be [3,3,3]! - np.testing.assert_array_equal(results[3], [2, 2, 2]) # should be [4,4,4]! + np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i]) def test_nested_jit_fails_on_second_call(self): """Nested JIT works on first call but fails on second.""" diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 00906c61f5..9dbbdf6a34 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -228,6 +228,10 @@ class CapturedJit(Generic[ReturnType]): def _prepare_jit_inputs(args, kwargs): input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor] names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors] + # extract tensors from containers (shallow, not recursive to avoid grabbing model weights) + for x in args + tuple(kwargs.values()): + it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else [] + tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)] if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors) # TODO: this multi unpack stuff is not well tested. lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])