Files
tinygrad/test/test_jit.py
2023-02-11 10:10:02 -08:00

42 lines
1.3 KiB
Python

#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from extra.jit import TinyJit
@unittest.skipUnless(Device.DEFAULT == "GPU", "JIT is only for GPU")
class TestJit(unittest.TestCase):
    """Checks TinyJit-wrapped additions against numpy over three calls each.

    The whole case is skipped unless ``Device.DEFAULT`` is ``"GPU"``.
    """

    def test_simple_jit(self):
        # Positional tensor arguments: each of the three calls must match
        # numpy's elementwise sum exactly.
        @TinyJit
        def add(a, b):
            return a+b

        for _ in range(3):
            lhs = Tensor.randn(10, 10)
            rhs = Tensor.randn(10, 10)
            out = add(lhs, rhs)
            np.testing.assert_equal(out.numpy(), lhs.numpy()+rhs.numpy())

    def test_kwargs_jit(self):
        # Same check as the positional case, but the tensors are handed to
        # the jitted function by keyword.
        @TinyJit
        def add_kwargs(first, second):
            return first+second

        for _ in range(3):
            lhs = Tensor.randn(10, 10)
            rhs = Tensor.randn(10, 10)
            out = add_kwargs(first=lhs, second=rhs)
            np.testing.assert_equal(out.numpy(), lhs.numpy()+rhs.numpy())

    def test_array_jit(self):
        # Tensors are passed wrapped in a list. The first two calls must
        # match numpy; the third is asserted to differ.
        @TinyJit
        def add_array(arr):
            return arr[0]+arr[1]

        for step in range(3):
            lhs = Tensor.randn(10, 10)
            rhs = Tensor.randn(10, 10)
            # Both inputs are realized explicitly before the jitted call.
            lhs.realize()
            rhs.realize()
            out = add_array([lhs, rhs])
            got = out.numpy()
            want = lhs.numpy()+rhs.numpy()
            if step != 2:
                np.testing.assert_equal(got, want)
                continue
            # should fail once jitted since jit can't handle arrays:
            # at least one element has to disagree with the numpy sum.
            np.testing.assert_equal(np.any(np.not_equal(got, want)), True)
# Run the test cases in this module only when the file is executed directly.
if __name__ == '__main__': unittest.main()