that had bugs, force an order (#2411)

George Hotz
2023-11-23 15:52:16 -08:00
committed by GitHub
parent 65f4e6971b
commit 193be14b6c
5 changed files with 11 additions and 14 deletions


@@ -55,7 +55,7 @@ class TestSymbolicJit(unittest.TestCase):
       symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
       expected = f(a, b).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-    assert len(jf.jit_cache) == 2
+    assert len(jf.jit_cache) == 2 or getattr(Device[Device.DEFAULT], "graph", None)
   def test_attention(self):
     def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
@@ -68,7 +68,7 @@ class TestSymbolicJit(unittest.TestCase):
       symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
       expected = f(q, k, v).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-    assert len(jf.jit_cache) == 6
+    assert len(jf.jit_cache) == 6 or getattr(Device[Device.DEFAULT], "graph", None)
   def test_cat_dim0(self):
     def f(a, b): return a.cat(b, dim=0).realize()
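For context, both changed assertions relax the exact jit_cache length check whenever the default device exposes a graph executor, since graph batching can change how many entries remain in the cache. Below is a rough, self-contained sketch of the same guard pattern; the jitted function, shapes, expected cache length, and the import path for Device are illustrative assumptions, not taken from this commit.

# Sketch of the guard used in the asserts above; only the
# getattr(Device[Device.DEFAULT], "graph", None) check comes from the diff itself.
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.ops import Device  # Device's import path varies across tinygrad versions

@TinyJit
def jf(a, b): return (a + b).realize()  # a single elementwise kernel

for _ in range(3):  # TinyJit captures the cache after repeated calls with fixed shapes
  jf(Tensor.rand(3, 5), Tensor.rand(3, 5))

# Enforce the exact cache length only when no graph executor is batching the kernels.
assert len(jf.jit_cache) == 1 or getattr(Device[Device.DEFAULT], "graph", None)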