jit includes tensor inputs in containers (#14043)

* jit includes tensor inputs in containers

* cleanup
This commit is contained in:
chenyu
2026-01-06 19:42:06 -05:00
committed by GitHub
parent c714881832
commit 72a3f78d19
3 changed files with 9 additions and 21 deletions

View File

@@ -184,16 +184,9 @@ class TestJit(unittest.TestCase):
def test_array_jit(self):
@TinyJit
def add_array(a, arr): return (a+arr[0]).realize()
for i in range(5):
a = Tensor.randn(10, 10)
b = Tensor.randn(10, 10)
a.realize(), b.realize()
c = add_array(a, [b])
if i >= 2:
# should fail once jitted since jit can't handle arrays
np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
else:
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
for _ in range(5):
a, b = Tensor.randn(10, 10).realize(), Tensor.randn(10, 10).realize()
np.testing.assert_allclose(add_array(a, [b]).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
assert_jit_cache_len(add_array, 1)
def test_jit_copyin(self):