From 02ab3eb153ab7b17448b2f23b2bb474cf7f53374 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 6 Jan 2026 10:40:43 -0500 Subject: [PATCH] test case for jit a function with item call (#14039) * test case for jit a function with item call output is silently wrong now * no dtype --- test/test_jit.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/test_jit.py b/test/test_jit.py index 7b0b5dc5f4..2a28bd0201 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -513,6 +513,24 @@ class TestJit(unittest.TestCase): with self.assertRaises(AssertionError): f(Tensor([2.0])).item() + def test_jit_item_not_supported(self): + # .item() is not jittable because the output shape depends on runtime values + @TinyJit + def f(x:Tensor) -> Tensor: return Tensor.zeros(x.sum().item()) + for _ in range(3): f(Tensor([1, 1, 1])) # sum=3 + result = f(Tensor([2, 2, 2])) # sum=6, but jit baked in n=3 + self.assertEqual(result.shape, (3,)) # wrong, should be (6,) + + def test_jit_masked_select_not_supported(self): + # masked_select uses .item() internally, so output size gets baked in + @TinyJit + def f(x:Tensor, mask:Tensor) -> Tensor: return x.masked_select(mask) + mask_2 = Tensor([True, False, True, False]) + for _ in range(3): f(Tensor([1, 2, 3, 4]), mask_2) + mask_3 = Tensor([True, True, True, False]) + result = f(Tensor([1, 2, 3, 4]), mask_3) + self.assertEqual(result.shape, (2,)) # wrong, should be (3,) + @unittest.skip("Pending multioutput implementation #3607") class TestMultioutputJit(unittest.TestCase): def _test(self, f):