diff --git a/test/test_jit.py b/test/test_jit.py index 744072981c..5ff0bb1b2b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -483,5 +483,16 @@ class TestJitInsideJit(unittest.TestCase): with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"): g(Tensor([1])).realize() +class TestCopyInsideJit(unittest.TestCase): + def test_copy_inside_jit(self): + @TinyJit + def add(x,y) -> Tensor: return x.to(Device.DEFAULT)+y + for _ in range(5): + # create a Tensor in CLANG + a = Tensor.rand(16,16,device="CLANG").realize() + b = Tensor.rand(16,16).realize() + out = add(a,b) + self.assertListEqual(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())]) + if __name__ == '__main__': unittest.main()