jit includes tensor inputs in containers (#14043)

* jit includes tensor inputs in containers

* cleanup
chenyu
2026-01-06 19:42:06 -05:00
committed by GitHub
parent c714881832
commit 72a3f78d19
3 changed files with 9 additions and 21 deletions

View File

@@ -184,16 +184,9 @@ class TestJit(unittest.TestCase):
   def test_array_jit(self):
     @TinyJit
     def add_array(a, arr): return (a+arr[0]).realize()
-    for i in range(5):
-      a = Tensor.randn(10, 10)
-      b = Tensor.randn(10, 10)
-      a.realize(), b.realize()
-      c = add_array(a, [b])
-      if i >= 2:
-        # should fail once jitted since jit can't handle arrays
-        np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
-      else:
-        np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
+    for _ in range(5):
+      a, b = Tensor.randn(10, 10).realize(), Tensor.randn(10, 10).realize()
+      np.testing.assert_allclose(add_array(a, [b]).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
     assert_jit_cache_len(add_array, 1)

   def test_jit_copyin(self):

View File

@@ -6,7 +6,6 @@ Each test shows behavior that works without JIT but changes with JIT.
 Comments marked "should be X!" indicate the intuitively expected value.

 SILENT MISMATCHES (highest priority - wrong results, no error):
-  tensors_in_containers_ignored          EASY  only checks t.__class__ is Tensor, could scan lists/dicts
   class_method_shared_across_instances   EASY  could check if first arg is self and warn
   output_buffer_reuse                    MED   performance tradeoff, could add option or better docs
   python_constants_frozen                HARD  inherent to tracing JITs
@@ -65,20 +64,12 @@ class TestJitFootguns(unittest.TestCase):
     with self.assertRaises(JitError):
       f(x, x)

-  def test_tensors_in_containers_ignored(self):
-    """Tensors inside lists/dicts are not tracked as inputs."""
+  def test_tensors_in_containers(self):
     @TinyJit
     def f(a, arr): return (a + arr[0]).realize()
-    results = []
     for i in range(4):
       a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize()
-      results.append(f(a, [b]).numpy().copy())
-    np.testing.assert_array_equal(results[0], [1, 1, 1])  # warmup
-    np.testing.assert_array_equal(results[1], [2, 2, 2])  # capture
-    np.testing.assert_array_equal(results[2], [2, 2, 2])  # should be [3,3,3]!
-    np.testing.assert_array_equal(results[3], [2, 2, 2])  # should be [4,4,4]!
+      np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i])

   def test_nested_jit_fails_on_second_call(self):
     """Nested JIT works on first call but fails on second."""

View File

@@ -228,6 +228,10 @@ class CapturedJit(Generic[ReturnType]):
 def _prepare_jit_inputs(args, kwargs):
   input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
   names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
+  # extract tensors from containers (shallow, not recursive to avoid grabbing model weights)
+  for x in args + tuple(kwargs.values()):
+    it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
+    tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
   if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
   # TODO: this multi unpack stuff is not well tested.
   lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
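
The scan added to _prepare_jit_inputs is deliberately shallow: it only looks one level into positional or keyword arguments that are lists, tuples, or dicts, so a model object passed as an argument does not get its weights captured as JIT inputs, and it deduplicates by identity so the same Tensor is not registered twice. A minimal sketch of the same extraction logic outside the JIT (shallow_container_tensors is an illustrative helper, not part of this patch):

from tinygrad import Tensor

def shallow_container_tensors(args: tuple, kwargs: dict) -> list:
  # direct Tensor arguments, mirroring input_tensors above
  tensors = [t for t in list(args) + list(kwargs.values()) if t.__class__ is Tensor]
  # one level into list/tuple/dict arguments; anything nested deeper is ignored
  for x in args + tuple(kwargs.values()):
    it = x if isinstance(x, (tuple, list)) else x.values() if isinstance(x, dict) else []
    # identity check (is) rather than ==, so equal-valued but distinct Tensors are both kept
    tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
  return tensors

a, b = Tensor([1, 2]), Tensor([3, 4])
assert len(shallow_container_tensors((a, [b]), {})) == 2   # b inside the list is found
assert len(shallow_container_tensors(([[b]],), {})) == 0   # nested container: not scanned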