jit includes tensor inputs in containers (#14043)

* jit includes tensor inputs in containers

* cleanup
This commit is contained in:
chenyu
2026-01-06 19:42:06 -05:00
committed by GitHub
parent c714881832
commit 72a3f78d19
3 changed files with 9 additions and 21 deletions

View File

@@ -184,16 +184,9 @@ class TestJit(unittest.TestCase):
def test_array_jit(self):
@TinyJit
def add_array(a, arr): return (a+arr[0]).realize()
for i in range(5):
a = Tensor.randn(10, 10)
b = Tensor.randn(10, 10)
a.realize(), b.realize()
c = add_array(a, [b])
if i >= 2:
# should fail once jitted since jit can't handle arrays
np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
else:
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
for _ in range(5):
a, b = Tensor.randn(10, 10).realize(), Tensor.randn(10, 10).realize()
np.testing.assert_allclose(add_array(a, [b]).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
assert_jit_cache_len(add_array, 1)
def test_jit_copyin(self):

View File

@@ -6,7 +6,6 @@ Each test shows behavior that works without JIT but changes with JIT.
Comments marked "should be X!" indicate the intuitively expected value.
SILENT MISMATCHES (highest priority - wrong results, no error):
tensors_in_containers_ignored EASY only checks t.__class__ is Tensor, could scan lists/dicts
class_method_shared_across_instances EASY could check if first arg is self and warn
output_buffer_reuse MED performance tradeoff, could add option or better docs
python_constants_frozen HARD inherent to tracing JITs
@@ -65,20 +64,12 @@ class TestJitFootguns(unittest.TestCase):
with self.assertRaises(JitError):
f(x, x)
def test_tensors_in_containers_ignored(self):
"""Tensors inside lists/dicts are not tracked as inputs."""
def test_tensors_in_containers(self):
@TinyJit
def f(a, arr): return (a + arr[0]).realize()
results = []
for i in range(4):
a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize()
results.append(f(a, [b]).numpy().copy())
np.testing.assert_array_equal(results[0], [1, 1, 1]) # warmup
np.testing.assert_array_equal(results[1], [2, 2, 2]) # capture
np.testing.assert_array_equal(results[2], [2, 2, 2]) # should be [3,3,3]!
np.testing.assert_array_equal(results[3], [2, 2, 2]) # should be [4,4,4]!
np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i])
def test_nested_jit_fails_on_second_call(self):
"""Nested JIT works on first call but fails on second."""