mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
test case for jit a function with item call (#14039)
* test case for jit a function with item call output is silently wrong now * no dtype
This commit is contained in:
@@ -513,6 +513,24 @@ class TestJit(unittest.TestCase):
|
||||
with self.assertRaises(AssertionError):
|
||||
f(Tensor([2.0])).item()
|
||||
|
||||
def test_jit_item_not_supported(self):
|
||||
# .item() is not jittable because the output shape depends on runtime values
|
||||
@TinyJit
|
||||
def f(x:Tensor) -> Tensor: return Tensor.zeros(x.sum().item())
|
||||
for _ in range(3): f(Tensor([1, 1, 1])) # sum=3
|
||||
result = f(Tensor([2, 2, 2])) # sum=6, but jit baked in n=3
|
||||
self.assertEqual(result.shape, (3,)) # wrong, should be (6,)
|
||||
|
||||
def test_jit_masked_select_not_supported(self):
|
||||
# masked_select uses .item() internally, so output size gets baked in
|
||||
@TinyJit
|
||||
def f(x:Tensor, mask:Tensor) -> Tensor: return x.masked_select(mask)
|
||||
mask_2 = Tensor([True, False, True, False])
|
||||
for _ in range(3): f(Tensor([1, 2, 3, 4]), mask_2)
|
||||
mask_3 = Tensor([True, True, True, False])
|
||||
result = f(Tensor([1, 2, 3, 4]), mask_3)
|
||||
self.assertEqual(result.shape, (2,)) # wrong, should be (3,)
|
||||
|
||||
@unittest.skip("Pending multioutput implementation #3607")
|
||||
class TestMultioutputJit(unittest.TestCase):
|
||||
def _test(self, f):
|
||||
|
||||
Reference in New Issue
Block a user