test copy inside jit [pr] (#8072)

This commit is contained in:
George Hotz
2024-12-06 17:51:50 +08:00
committed by GitHub
parent e2fe7f0d2f
commit aae8557ada

View File

@@ -483,5 +483,16 @@ class TestJitInsideJit(unittest.TestCase):
with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
g(Tensor([1])).realize()
class TestCopyInsideJit(unittest.TestCase):
def test_copy_inside_jit(self):
@TinyJit
def add(x,y) -> Tensor: return x.to(Device.DEFAULT)+y
for _ in range(5):
# create a Tensor in CLANG
a = Tensor.rand(16,16,device="CLANG").realize()
b = Tensor.rand(16,16).realize()
out = add(a,b)
self.assertListEqual(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
if __name__ == '__main__':
unittest.main()