raise when jit fxn returns non-Tensor output (#14042)

This commit is contained in:
chenyu
2026-01-06 12:59:20 -05:00
committed by GitHub
parent 4491ec0c9e
commit 7fb18f7e47
3 changed files with 15 additions and 28 deletions

View File

@@ -230,20 +230,9 @@ class TestJit(unittest.TestCase):
def test_jit_output_non_tensor_fail(self):
@TinyJit
def f(a, b, i): return (a+b).realize(), i
output1, output2 = [], []
expect1, expect2 = [], []
for i in range(5):
a = Tensor.randn(10, 10)
b = Tensor.randn(10, 10)
o1, o2 = f(a, b, i)
output1.append(o1.numpy().copy())
output2.append(o2)
expect1.append(a.numpy().copy()+b.numpy().copy())
expect2.append(i)
np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5)
# the jit only works with Tensor outputs
assert output2 != expect2
assert_jit_cache_len(f, 1)
with self.assertRaises(JitError):
for i in range(3):
f(Tensor.randn(10, 10), Tensor.randn(10, 10), i)
def test_jit_random_regen(self):
def f(a, b):