mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
74 lines
2.1 KiB
Python
74 lines
2.1 KiB
Python
#!/usr/bin/env python
|
|
import unittest
|
|
import numpy as np
|
|
from tinygrad.tensor import Tensor, Device
|
|
from tinygrad.jit import TinyJit
|
|
|
|
@unittest.skipUnless(Device.DEFAULT == "GPU", "JIT is only for GPU")
class TestJit(unittest.TestCase):
  """Tests for functions and methods wrapped in TinyJit.

  The whole class is skipped unless Device.DEFAULT is "GPU".
  """

  def test_simple_jit(self):
    # Call a jitted two-tensor add five times with fresh random inputs and
    # compare every result with the numpy sum of the same inputs.
    @TinyJit
    def add(x, y):
      return (x+y).realize()

    for _ in range(5):
      lhs = Tensor.randn(10, 10)
      rhs = Tensor.randn(10, 10)
      out = add(lhs, rhs)
      got = out.numpy()
      want = lhs.numpy() + rhs.numpy()
      np.testing.assert_equal(got, want)

  def test_jit_shape_mismatch(self):
    # After three calls with (10, 10) inputs, a call that passes a (20, 20)
    # tensor is expected to raise AssertionError.
    @TinyJit
    def add(x, y):
      return (x+y).realize()

    for _ in range(3):
      lhs = Tensor.randn(10, 10)
      rhs = Tensor.randn(10, 10)
      out = add(lhs, rhs)

    wrong_shape = Tensor.randn(20, 20)
    with self.assertRaises(AssertionError):
      add(lhs, wrong_shape)

  def test_jit_duplicate_fail(self):
    # the jit doesn't support duplicate arguments, so passing the same
    # tensor for both parameters is expected to raise AssertionError
    @TinyJit
    def add(x, y):
      return (x+y).realize()

    only = Tensor.randn(10, 10)
    with self.assertRaises(AssertionError):
      add(only, only)

  def test_kwargs_jit(self):
    # Same check as the simple case, but the inputs are passed by keyword.
    @TinyJit
    def add_kwargs(first, second):
      return (first+second).realize()

    for _ in range(5):
      lhs = Tensor.randn(10, 10)
      rhs = Tensor.randn(10, 10)
      out = add_kwargs(first=lhs, second=rhs)
      got = out.numpy()
      want = lhs.numpy() + rhs.numpy()
      np.testing.assert_equal(got, want)

  def test_array_jit(self):
    # The second input is passed wrapped in a list.
    @TinyJit
    def add_array(x, arr):
      return (x+arr[0]).realize()

    for i in range(5):
      lhs = Tensor.randn(10, 10)
      rhs = Tensor.randn(10, 10)
      lhs.realize()
      rhs.realize()
      out = add_array(lhs, [rhs])
      got = out.numpy()
      want = lhs.numpy() + rhs.numpy()
      if i < 2:
        # the first two calls are expected to produce the correct sum
        np.testing.assert_equal(got, want)
      else:
        # should fail once jitted since jit can't handle arrays:
        # at least one element is expected to differ from the true sum
        differs = np.not_equal(got, want)
        np.testing.assert_equal(np.any(differs), True)

  def test_method_jit(self):
    # Jit a __call__ method that adds a tensor stored on the instance.
    class Fun:
      def __init__(self):
        self.a = Tensor.randn(10, 10)

      @TinyJit
      def __call__(self, b:Tensor) -> Tensor:
        total = self.a+b
        return total.realize()

    fun = Fun()
    for _ in range(5):
      rhs = Tensor.randn(10, 10)
      out = fun(rhs)
      got = out.numpy()
      want = fun.a.numpy() + rhs.numpy()
      np.testing.assert_equal(got, want)
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()