mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
that had bugs, force an order (#2411)
This commit is contained in:
@@ -55,7 +55,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 2
|
||||
assert len(jf.jit_cache) == 2 or getattr(Device[Device.DEFAULT], "graph", None)
|
||||
|
||||
def test_attention(self):
|
||||
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
|
||||
@@ -68,7 +68,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
|
||||
expected = f(q, k, v).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 6
|
||||
assert len(jf.jit_cache) == 6 or getattr(Device[Device.DEFAULT], "graph", None)
|
||||
|
||||
def test_cat_dim0(self):
|
||||
def f(a, b): return a.cat(b, dim=0).realize()
|
||||
|
||||
Reference in New Issue
Block a user