improve test assertions for jit cache len with graph executor (#2476)

* improve test assertions for jit cache len with graph executor

* delete newline

* unused import

* another unused import
This commit is contained in:
mmmkkaaayy
2023-11-27 23:02:45 -08:00
committed by GitHub
parent 28a67106ca
commit ddb6a33ae5
4 changed files with 38 additions and 24 deletions

View File

@@ -1,3 +1,4 @@
from tinygrad.device import JITRunner
from tinygrad.ops import LazyOp, LoadOps
from tinygrad.nn.state import get_parameters
@@ -13,3 +14,13 @@ def derandomize_model(model):
for p in get_parameters(model):
p.lazydata = derandomize(p.lazydata)
p.realize()
def assert_jit_cache_len(fxn, expected_len):
  """Assert that a jitted function cached exactly `expected_len` kernels.

  Without a graph executor the JIT cache holds one JITRunner per kernel,
  so its length is compared to `expected_len` directly. With a graph
  executor all kernels are batched into a single cache item whose `prg`
  is a *Graph object wrapping the real per-kernel cache.

  Args:
    fxn: a TinyJit-wrapped callable that has already been executed enough
      times for its cache to be populated.
    expected_len: the expected number of cached kernels.

  Raises:
    AssertionError: if the cache is empty or the kernel count differs.
  """
  assert len(fxn.jit_cache) > 0, "jit_cache is empty; was the jitted function executed?"
  # isinstance is the idiomatic equivalent of issubclass(type(x), ...)
  if isinstance(fxn.jit_cache[0].prg, JITRunner):
    assert len(fxn.jit_cache) == expected_len
  else:
    # graph executor path: everything is fused into one graph cache item
    assert len(fxn.jit_cache) == 1
    # until we have a better way of typing the prg in JitItem
    assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
    assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len

View File

@@ -1,7 +1,8 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad import Device
from test.helpers import assert_jit_cache_len
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
@@ -14,7 +15,7 @@ class TestJit(unittest.TestCase):
b = Tensor.randn(10, 10)
c = add(a, b)
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
assert len(add.jit_cache) == 1
assert_jit_cache_len(add, 1)
def test_jit_multiple_outputs(self):
@TinyJit
@@ -26,7 +27,7 @@ class TestJit(unittest.TestCase):
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
np.testing.assert_allclose(d.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5)
np.testing.assert_allclose(e.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5)
assert len(f.jit_cache) == 3 or (len(f.jit_cache) == 1 and getattr(Device[Device.DEFAULT], "graph", None))
assert_jit_cache_len(f, 3)
def test_nothing_jitted(self):
@TinyJit
@@ -73,7 +74,7 @@ class TestJit(unittest.TestCase):
b = Tensor.randn(10, 10)
c = add_kwargs(first=a, second=b)
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
assert len(add_kwargs.jit_cache) == 1
assert_jit_cache_len(add_kwargs, 1)
def test_array_jit(self):
@TinyJit
@@ -88,7 +89,7 @@ class TestJit(unittest.TestCase):
np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
else:
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
assert len(add_array.jit_cache) == 1
assert_jit_cache_len(add_array, 1)
def test_method_jit(self):
class Fun:
@@ -102,7 +103,7 @@ class TestJit(unittest.TestCase):
b = Tensor.randn(10, 10)
c = fun(b)
np.testing.assert_allclose(c.numpy(), fun.a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
assert len(fun.__call__.func.__self__.jit_cache) == 1
assert_jit_cache_len(fun.__call__.func.__self__, 1)
def test_jit_size1_input(self):
@TinyJit
@@ -110,7 +111,7 @@ class TestJit(unittest.TestCase):
a = Tensor([1, 2, 3])
for i in range(5):
np.testing.assert_allclose(f(a, Tensor([i])).numpy(), (a+i).numpy(), atol=1e-4, rtol=1e-5)
assert len(f.jit_cache) == 1
assert_jit_cache_len(f, 1)
def test_jit_output_non_tensor_fail(self):
@TinyJit
@@ -128,7 +129,7 @@ class TestJit(unittest.TestCase):
np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5)
# the jit only works with Tensor outputs
assert output2 != expect2
assert len(f.jit_cache) == 1
assert_jit_cache_len(f, 1)
@unittest.skip("random isn't working in JIT")
def test_jit_random_regen(self):
@@ -237,8 +238,8 @@ class TestJit(unittest.TestCase):
# but the bad_jitted doesn't!
np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())
assert len(cache.good_jitted.jit_cache) == 1
assert len(cache.bad_jitted.jit_cache) == 1
assert_jit_cache_len(cache.good_jitted, 1)
assert_jit_cache_len(cache.bad_jitted, 1)
def test_jit_buffer_behavior(self):
@TinyJit

View File

@@ -1,8 +1,10 @@
import unittest
from test.helpers import assert_jit_cache_len
from tinygrad.jit import TinyJit
from tinygrad.helpers import getenv
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor, Device
from tinygrad.tensor import Tensor
import numpy as np
@unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
@@ -16,7 +18,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy()
expected = f(a).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
assert_jit_cache_len(jf, 1)
def test_add(self):
def f(a, b): return (a+b).realize()
@@ -28,7 +30,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
assert_jit_cache_len(jf, 1)
def test_matmul(self):
def f(a, b): return (a@b).realize()
@@ -40,7 +42,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
assert_jit_cache_len(jf, 1)
def test_mixed_with_no_symbol_kernel(self):
def f(a, b):
@@ -55,7 +57,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 2 or (len(jf.jit_cache) == 1 and getattr(Device[Device.DEFAULT], "graph", None))
assert_jit_cache_len(jf, 2)
def test_attention(self):
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
@@ -68,7 +70,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
expected = f(q, k, v).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 6 or (len(jf.jit_cache) == 1 and getattr(Device[Device.DEFAULT], "graph", None))
assert_jit_cache_len(jf, 6)
def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()
@@ -80,7 +82,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
assert_jit_cache_len(jf, 1)
def test_cat_dim1(self):
def f(a, b): return a.cat(b, dim=1).realize()
@@ -92,7 +94,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
assert_jit_cache_len(jf, 1)
def test_cat_dim0_two_vars(self):
def f(a, b): return a.cat(b, dim=0).realize()
@@ -106,7 +108,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
assert_jit_cache_len(jf, 1)
def test_cat_dim1_two_vars(self):
def f(a, b): return a.cat(b, dim=1).realize()
@@ -120,7 +122,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
assert_jit_cache_len(jf, 1)
def test_two_vars_plus1_ij(self):
def f(a, b): return (a@b+1).realize()
@@ -134,7 +136,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
assert_jit_cache_len(jf, 1)
def test_two_vars_plus1_ji(self):
def f(a, b): return (a@b+1).realize()
@@ -148,7 +150,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
assert_jit_cache_len(jf, 1)
def test_jit_symbolic_shape_mismatch(self):
@TinyJit
@@ -175,7 +177,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(symbolic).numpy()
expected = f(a.shrink(((3,5),(i,i+2)))).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
assert_jit_cache_len(jf, 1)
if __name__ == '__main__':
unittest.main()