mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
test copy inside jit [pr] (#8072)
This commit is contained in:
@@ -483,5 +483,16 @@ class TestJitInsideJit(unittest.TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
|
||||
g(Tensor([1])).realize()
|
||||
|
||||
class TestCopyInsideJit(unittest.TestCase):
|
||||
def test_copy_inside_jit(self):
|
||||
@TinyJit
|
||||
def add(x,y) -> Tensor: return x.to(Device.DEFAULT)+y
|
||||
for _ in range(5):
|
||||
# create a Tensor in CLANG
|
||||
a = Tensor.rand(16,16,device="CLANG").realize()
|
||||
b = Tensor.rand(16,16).realize()
|
||||
out = add(a,b)
|
||||
self.assertListEqual(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user