diff --git a/test/test_jit.py b/test/test_jit.py index 2a28bd0201..7b0b5dc5f4 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -513,24 +513,6 @@ class TestJit(unittest.TestCase): with self.assertRaises(AssertionError): f(Tensor([2.0])).item() - def test_jit_item_not_supported(self): - # .item() is not jittable because the output shape depends on runtime values - @TinyJit - def f(x:Tensor) -> Tensor: return Tensor.zeros(x.sum().item()) - for _ in range(3): f(Tensor([1, 1, 1])) # sum=3 - result = f(Tensor([2, 2, 2])) # sum=6, but jit baked in n=3 - self.assertEqual(result.shape, (3,)) # wrong, should be (6,) - - def test_jit_masked_select_not_supported(self): - # masked_select uses .item() internally, so output size gets baked in - @TinyJit - def f(x:Tensor, mask:Tensor) -> Tensor: return x.masked_select(mask) - mask_2 = Tensor([True, False, True, False]) - for _ in range(3): f(Tensor([1, 2, 3, 4]), mask_2) - mask_3 = Tensor([True, True, True, False]) - result = f(Tensor([1, 2, 3, 4]), mask_3) - self.assertEqual(result.shape, (2,)) # wrong, should be (3,) - @unittest.skip("Pending multioutput implementation #3607") class TestMultioutputJit(unittest.TestCase): def _test(self, f): diff --git a/test/test_jit_footguns.py b/test/test_jit_footguns.py index a0eebe377c..bf0c596de0 100644 --- a/test/test_jit_footguns.py +++ b/test/test_jit_footguns.py @@ -217,6 +217,35 @@ class TestJitFootguns(unittest.TestCase): for _ in range(3): f(Tensor([1]), Tensor([2])) + def test_item_creates_unrealized_return(self): + """.item() in shape computation creates unrealized return with baked-in shape.""" + @TinyJit + def f(x): return Tensor.zeros(x.sum().item()) + + for _ in range(3): f(Tensor([1, 1, 1])) # captures with sum=3 + result = f(Tensor([2, 2, 2])) # sum=6, but shape is baked in + assert result.shape == (3,) # should be (6,)! + + def test_item_bakes_in_values(self): + """.item() value is baked in, causing wrong output shapes (silent failure).""" + @TinyJit + def f(x, mask): return x.masked_select(mask) + + mask_2 = Tensor([True, False, True, False]) + for _ in range(3): f(Tensor([1, 2, 3, 4]), mask_2) + mask_3 = Tensor([True, True, True, False]) + result = f(Tensor([1, 2, 3, 4]), mask_3) + assert result.shape == (2,) # should be (3,)! + + def test_tolist_bakes_in_values(self): + """.tolist() returns Python values that get baked in (silent failure).""" + @TinyJit + def f(x): return Tensor(x.tolist()) + + for _ in range(3): f(Tensor([1, 2, 3])) + result = f(Tensor([4, 5, 6])) + np.testing.assert_equal(result.numpy(), [1, 2, 3]) # should be [4,5,6]! + class TestJitCorrectBehavior(unittest.TestCase): """Behaviors that work correctly - documented for clarity."""