mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
improve test assertions for jit cache len with graph executor (#2476)
* improve test assertions for jit cache len with graph executor * delete newline * unused import * another unused import
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from tinygrad.device import JITRunner
|
||||
from tinygrad.ops import LazyOp, LoadOps
|
||||
from tinygrad.nn.state import get_parameters
|
||||
|
||||
@@ -13,3 +14,13 @@ def derandomize_model(model):
|
||||
for p in get_parameters(model):
|
||||
p.lazydata = derandomize(p.lazydata)
|
||||
p.realize()
|
||||
|
||||
def assert_jit_cache_len(fxn, expected_len):
|
||||
assert len(fxn.jit_cache) > 0
|
||||
if issubclass(type(fxn.jit_cache[0].prg), JITRunner):
|
||||
assert len(fxn.jit_cache) == expected_len
|
||||
else:
|
||||
assert len(fxn.jit_cache) == 1
|
||||
# until we have a better way of typing the prg in JitItem
|
||||
assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
|
||||
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
|
||||
@@ -1,7 +1,8 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Device
|
||||
|
||||
from test.helpers import assert_jit_cache_len
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
@@ -14,7 +15,7 @@ class TestJit(unittest.TestCase):
|
||||
b = Tensor.randn(10, 10)
|
||||
c = add(a, b)
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert len(add.jit_cache) == 1
|
||||
assert_jit_cache_len(add, 1)
|
||||
|
||||
def test_jit_multiple_outputs(self):
|
||||
@TinyJit
|
||||
@@ -26,7 +27,7 @@ class TestJit(unittest.TestCase):
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(d.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(e.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert len(f.jit_cache) == 3 or (len(f.jit_cache) == 1 and getattr(Device[Device.DEFAULT], "graph", None))
|
||||
assert_jit_cache_len(f, 3)
|
||||
|
||||
def test_nothing_jitted(self):
|
||||
@TinyJit
|
||||
@@ -73,7 +74,7 @@ class TestJit(unittest.TestCase):
|
||||
b = Tensor.randn(10, 10)
|
||||
c = add_kwargs(first=a, second=b)
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert len(add_kwargs.jit_cache) == 1
|
||||
assert_jit_cache_len(add_kwargs, 1)
|
||||
|
||||
def test_array_jit(self):
|
||||
@TinyJit
|
||||
@@ -88,7 +89,7 @@ class TestJit(unittest.TestCase):
|
||||
np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
|
||||
else:
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert len(add_array.jit_cache) == 1
|
||||
assert_jit_cache_len(add_array, 1)
|
||||
|
||||
def test_method_jit(self):
|
||||
class Fun:
|
||||
@@ -102,7 +103,7 @@ class TestJit(unittest.TestCase):
|
||||
b = Tensor.randn(10, 10)
|
||||
c = fun(b)
|
||||
np.testing.assert_allclose(c.numpy(), fun.a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert len(fun.__call__.func.__self__.jit_cache) == 1
|
||||
assert_jit_cache_len(fun.__call__.func.__self__, 1)
|
||||
|
||||
def test_jit_size1_input(self):
|
||||
@TinyJit
|
||||
@@ -110,7 +111,7 @@ class TestJit(unittest.TestCase):
|
||||
a = Tensor([1, 2, 3])
|
||||
for i in range(5):
|
||||
np.testing.assert_allclose(f(a, Tensor([i])).numpy(), (a+i).numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert len(f.jit_cache) == 1
|
||||
assert_jit_cache_len(f, 1)
|
||||
|
||||
def test_jit_output_non_tensor_fail(self):
|
||||
@TinyJit
|
||||
@@ -128,7 +129,7 @@ class TestJit(unittest.TestCase):
|
||||
np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5)
|
||||
# the jit only works with Tensor outputs
|
||||
assert output2 != expect2
|
||||
assert len(f.jit_cache) == 1
|
||||
assert_jit_cache_len(f, 1)
|
||||
|
||||
@unittest.skip("random isn't working in JIT")
|
||||
def test_jit_random_regen(self):
|
||||
@@ -237,8 +238,8 @@ class TestJit(unittest.TestCase):
|
||||
# but the bad_jitted doesn't!
|
||||
np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())
|
||||
|
||||
assert len(cache.good_jitted.jit_cache) == 1
|
||||
assert len(cache.bad_jitted.jit_cache) == 1
|
||||
assert_jit_cache_len(cache.good_jitted, 1)
|
||||
assert_jit_cache_len(cache.bad_jitted, 1)
|
||||
|
||||
def test_jit_buffer_behavior(self):
|
||||
@TinyJit
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import unittest
|
||||
|
||||
from test.helpers import assert_jit_cache_len
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.tensor import Tensor
|
||||
import numpy as np
|
||||
|
||||
@unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
|
||||
@@ -16,7 +18,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy()
|
||||
expected = f(a).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 1
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_add(self):
|
||||
def f(a, b): return (a+b).realize()
|
||||
@@ -28,7 +30,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 1
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_matmul(self):
|
||||
def f(a, b): return (a@b).realize()
|
||||
@@ -40,7 +42,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 1
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_mixed_with_no_symbol_kernel(self):
|
||||
def f(a, b):
|
||||
@@ -55,7 +57,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 2 or (len(jf.jit_cache) == 1 and getattr(Device[Device.DEFAULT], "graph", None))
|
||||
assert_jit_cache_len(jf, 2)
|
||||
|
||||
def test_attention(self):
|
||||
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
|
||||
@@ -68,7 +70,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
|
||||
expected = f(q, k, v).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 6 or (len(jf.jit_cache) == 1 and getattr(Device[Device.DEFAULT], "graph", None))
|
||||
assert_jit_cache_len(jf, 6)
|
||||
|
||||
def test_cat_dim0(self):
|
||||
def f(a, b): return a.cat(b, dim=0).realize()
|
||||
@@ -80,7 +82,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 1
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_cat_dim1(self):
|
||||
def f(a, b): return a.cat(b, dim=1).realize()
|
||||
@@ -92,7 +94,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 1
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_cat_dim0_two_vars(self):
|
||||
def f(a, b): return a.cat(b, dim=0).realize()
|
||||
@@ -106,7 +108,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 1
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_cat_dim1_two_vars(self):
|
||||
def f(a, b): return a.cat(b, dim=1).realize()
|
||||
@@ -120,7 +122,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 1
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_two_vars_plus1_ij(self):
|
||||
def f(a, b): return (a@b+1).realize()
|
||||
@@ -134,7 +136,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 1
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_two_vars_plus1_ji(self):
|
||||
def f(a, b): return (a@b+1).realize()
|
||||
@@ -148,7 +150,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 1
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_jit_symbolic_shape_mismatch(self):
|
||||
@TinyJit
|
||||
@@ -175,7 +177,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(symbolic).numpy()
|
||||
expected = f(a.shrink(((3,5),(i,i+2)))).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 1
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user