test case for jit a function with item call (#14039)

* test case for jit a function with item call

output is silently wrong now

* no dtype
This commit is contained in:
chenyu
2026-01-06 10:40:43 -05:00
committed by GitHub
parent 02084f5376
commit b699b9f763

View File

@@ -513,6 +513,24 @@ class TestJit(unittest.TestCase):
with self.assertRaises(AssertionError):
f(Tensor([2.0])).item()
def test_jit_item_not_supported(self):
# .item() is not jittable because the output shape depends on runtime values
@TinyJit
def f(x:Tensor) -> Tensor: return Tensor.zeros(x.sum().item())
for _ in range(3): f(Tensor([1, 1, 1])) # sum=3
result = f(Tensor([2, 2, 2])) # sum=6, but jit baked in n=3
self.assertEqual(result.shape, (3,)) # wrong, should be (6,)
def test_jit_masked_select_not_supported(self):
# masked_select uses .item() internally, so output size gets baked in
@TinyJit
def f(x:Tensor, mask:Tensor) -> Tensor: return x.masked_select(mask)
mask_2 = Tensor([True, False, True, False])
for _ in range(3): f(Tensor([1, 2, 3, 4]), mask_2)
mask_3 = Tensor([True, True, True, False])
result = f(Tensor([1, 2, 3, 4]), mask_3)
self.assertEqual(result.shape, (2,)) # wrong, should be (3,)
@unittest.skip("Pending multioutput implementation #3607")
class TestMultioutputJit(unittest.TestCase):
def _test(self, f):