jit includes tensor inputs in containers (#14043)

* jit includes tensor inputs in containers

* cleanup
This commit is contained in:
chenyu
2026-01-06 19:42:06 -05:00
committed by GitHub
parent c714881832
commit 72a3f78d19
3 changed files with 9 additions and 21 deletions

View File

@@ -184,16 +184,9 @@ class TestJit(unittest.TestCase):
def test_array_jit(self):
@TinyJit
def add_array(a, arr): return (a+arr[0]).realize()
for i in range(5):
a = Tensor.randn(10, 10)
b = Tensor.randn(10, 10)
a.realize(), b.realize()
c = add_array(a, [b])
if i >= 2:
# should fail once jitted since jit can't handle arrays
np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
else:
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
for _ in range(5):
a, b = Tensor.randn(10, 10).realize(), Tensor.randn(10, 10).realize()
np.testing.assert_allclose(add_array(a, [b]).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
assert_jit_cache_len(add_array, 1)
def test_jit_copyin(self):

View File

@@ -6,7 +6,6 @@ Each test shows behavior that works without JIT but changes with JIT.
Comments marked "should be X!" indicate the intuitively expected value.
SILENT MISMATCHES (highest priority - wrong results, no error):
tensors_in_containers_ignored EASY only checks t.__class__ is Tensor, could scan lists/dicts
class_method_shared_across_instances EASY could check if first arg is self and warn
output_buffer_reuse MED performance tradeoff, could add option or better docs
python_constants_frozen HARD inherent to tracing JITs
@@ -65,20 +64,12 @@ class TestJitFootguns(unittest.TestCase):
with self.assertRaises(JitError):
f(x, x)
def test_tensors_in_containers_ignored(self):
"""Tensors inside lists/dicts are not tracked as inputs."""
def test_tensors_in_containers(self):
@TinyJit
def f(a, arr): return (a + arr[0]).realize()
results = []
for i in range(4):
a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize()
results.append(f(a, [b]).numpy().copy())
np.testing.assert_array_equal(results[0], [1, 1, 1]) # warmup
np.testing.assert_array_equal(results[1], [2, 2, 2]) # capture
np.testing.assert_array_equal(results[2], [2, 2, 2]) # should be [3,3,3]!
np.testing.assert_array_equal(results[3], [2, 2, 2]) # should be [4,4,4]!
np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i])
def test_nested_jit_fails_on_second_call(self):
"""Nested JIT works on first call but fails on second."""

View File

@@ -228,6 +228,10 @@ class CapturedJit(Generic[ReturnType]):
def _prepare_jit_inputs(args, kwargs):
input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
# extract tensors from containers (shallow, not recursive, to avoid grabbing model weights)
for x in args + tuple(kwargs.values()):
it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
# TODO: this multi unpack stuff is not well tested.
lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])