test jit tolist failure (#14040)

also moved tests to test_jit_footguns
This commit is contained in:
chenyu
2026-01-06 11:16:57 -05:00
committed by George Hotz
parent 02ab3eb153
commit b4fd0954b7
2 changed files with 29 additions and 18 deletions

View File

@@ -513,24 +513,6 @@ class TestJit(unittest.TestCase):
with self.assertRaises(AssertionError):
f(Tensor([2.0])).item()
def test_jit_item_not_supported(self):
# .item() is not jittable because the output shape depends on runtime values
@TinyJit
def f(x:Tensor) -> Tensor: return Tensor.zeros(x.sum().item())
for _ in range(3): f(Tensor([1, 1, 1])) # sum=3
result = f(Tensor([2, 2, 2])) # sum=6, but jit baked in n=3
self.assertEqual(result.shape, (3,)) # wrong, should be (6,)
def test_jit_masked_select_not_supported(self):
# masked_select uses .item() internally, so output size gets baked in
@TinyJit
def f(x:Tensor, mask:Tensor) -> Tensor: return x.masked_select(mask)
mask_2 = Tensor([True, False, True, False])
for _ in range(3): f(Tensor([1, 2, 3, 4]), mask_2)
mask_3 = Tensor([True, True, True, False])
result = f(Tensor([1, 2, 3, 4]), mask_3)
self.assertEqual(result.shape, (2,)) # wrong, should be (3,)
@unittest.skip("Pending multioutput implementation #3607")
class TestMultioutputJit(unittest.TestCase):
def _test(self, f):

View File

@@ -217,6 +217,35 @@ class TestJitFootguns(unittest.TestCase):
for _ in range(3):
f(Tensor([1]), Tensor([2]))
def test_item_creates_unrealized_return(self):
""".item() in shape computation creates unrealized return with baked-in shape."""
@TinyJit
def f(x): return Tensor.zeros(x.sum().item())
for _ in range(3): f(Tensor([1, 1, 1])) # captures with sum=3
result = f(Tensor([2, 2, 2])) # sum=6, but shape is baked in
assert result.shape == (3,) # should be (6,)!
def test_item_bakes_in_values(self):
""".item() value is baked in, causing wrong output shapes (silent failure)."""
@TinyJit
def f(x, mask): return x.masked_select(mask)
mask_2 = Tensor([True, False, True, False])
for _ in range(3): f(Tensor([1, 2, 3, 4]), mask_2)
mask_3 = Tensor([True, True, True, False])
result = f(Tensor([1, 2, 3, 4]), mask_3)
assert result.shape == (2,) # should be (3,)!
def test_tolist_bakes_in_values(self):
""".tolist() returns Python values that get baked in (silent failure)."""
@TinyJit
def f(x): return Tensor(x.tolist())
for _ in range(3): f(Tensor([1, 2, 3]))
result = f(Tensor([4, 5, 6]))
np.testing.assert_equal(result.numpy(), [1, 2, 3]) # should be [4,5,6]!
class TestJitCorrectBehavior(unittest.TestCase):
"""Behaviors that work correctly - documented for clarity."""