mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
@@ -513,24 +513,6 @@ class TestJit(unittest.TestCase):
|
||||
with self.assertRaises(AssertionError):
|
||||
f(Tensor([2.0])).item()
|
||||
|
||||
def test_jit_item_not_supported(self):
|
||||
# .item() is not jittable because the output shape depends on runtime values
|
||||
@TinyJit
|
||||
def f(x:Tensor) -> Tensor: return Tensor.zeros(x.sum().item())
|
||||
for _ in range(3): f(Tensor([1, 1, 1])) # sum=3
|
||||
result = f(Tensor([2, 2, 2])) # sum=6, but jit baked in n=3
|
||||
self.assertEqual(result.shape, (3,)) # wrong, should be (6,)
|
||||
|
||||
def test_jit_masked_select_not_supported(self):
|
||||
# masked_select uses .item() internally, so output size gets baked in
|
||||
@TinyJit
|
||||
def f(x:Tensor, mask:Tensor) -> Tensor: return x.masked_select(mask)
|
||||
mask_2 = Tensor([True, False, True, False])
|
||||
for _ in range(3): f(Tensor([1, 2, 3, 4]), mask_2)
|
||||
mask_3 = Tensor([True, True, True, False])
|
||||
result = f(Tensor([1, 2, 3, 4]), mask_3)
|
||||
self.assertEqual(result.shape, (2,)) # wrong, should be (3,)
|
||||
|
||||
@unittest.skip("Pending multioutput implementation #3607")
|
||||
class TestMultioutputJit(unittest.TestCase):
|
||||
def _test(self, f):
|
||||
|
||||
@@ -217,6 +217,35 @@ class TestJitFootguns(unittest.TestCase):
|
||||
for _ in range(3):
|
||||
f(Tensor([1]), Tensor([2]))
|
||||
|
||||
def test_item_creates_unrealized_return(self):
|
||||
""".item() in shape computation creates unrealized return with baked-in shape."""
|
||||
@TinyJit
|
||||
def f(x): return Tensor.zeros(x.sum().item())
|
||||
|
||||
for _ in range(3): f(Tensor([1, 1, 1])) # captures with sum=3
|
||||
result = f(Tensor([2, 2, 2])) # sum=6, but shape is baked in
|
||||
assert result.shape == (3,) # should be (6,)!
|
||||
|
||||
def test_item_bakes_in_values(self):
|
||||
""".item() value is baked in, causing wrong output shapes (silent failure)."""
|
||||
@TinyJit
|
||||
def f(x, mask): return x.masked_select(mask)
|
||||
|
||||
mask_2 = Tensor([True, False, True, False])
|
||||
for _ in range(3): f(Tensor([1, 2, 3, 4]), mask_2)
|
||||
mask_3 = Tensor([True, True, True, False])
|
||||
result = f(Tensor([1, 2, 3, 4]), mask_3)
|
||||
assert result.shape == (2,) # should be (3,)!
|
||||
|
||||
def test_tolist_bakes_in_values(self):
|
||||
""".tolist() returns Python values that get baked in (silent failure)."""
|
||||
@TinyJit
|
||||
def f(x): return Tensor(x.tolist())
|
||||
|
||||
for _ in range(3): f(Tensor([1, 2, 3]))
|
||||
result = f(Tensor([4, 5, 6]))
|
||||
np.testing.assert_equal(result.numpy(), [1, 2, 3]) # should be [4,5,6]!
|
||||
|
||||
|
||||
class TestJitCorrectBehavior(unittest.TestCase):
|
||||
"""Behaviors that work correctly - documented for clarity."""
|
||||
|
||||
Reference in New Issue
Block a user