diff --git a/test/test_jit.py b/test/test_jit.py index 7b0b5dc5f4..2a28bd0201 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -513,6 +513,24 @@ class TestJit(unittest.TestCase): with self.assertRaises(AssertionError): f(Tensor([2.0])).item() + def test_jit_item_not_supported(self): + # .item() is not jittable because the output shape depends on runtime values + @TinyJit + def f(x:Tensor) -> Tensor: return Tensor.zeros(x.sum().item()) + for _ in range(3): f(Tensor([1, 1, 1])) # sum=3 + result = f(Tensor([2, 2, 2])) # sum=6, but jit baked in n=3 + self.assertEqual(result.shape, (3,)) # wrong, should be (6,) + + def test_jit_masked_select_not_supported(self): + # masked_select uses .item() internally, so output size gets baked in + @TinyJit + def f(x:Tensor, mask:Tensor) -> Tensor: return x.masked_select(mask) + mask_2 = Tensor([True, False, True, False]) + for _ in range(3): f(Tensor([1, 2, 3, 4]), mask_2) + mask_3 = Tensor([True, True, True, False]) + result = f(Tensor([1, 2, 3, 4]), mask_3) + self.assertEqual(result.shape, (2,)) # wrong, should be (3,) + @unittest.skip("Pending multioutput implementation #3607") class TestMultioutputJit(unittest.TestCase): def _test(self, f):