diff --git a/test/test_ops.py b/test/test_ops.py index 4c4589876a..e79d071502 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1128,9 +1128,12 @@ class TestOps(unittest.TestCase): def test_clip(self): helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2)) - def test_matvec(self): + def test_matvecmat(self): helper_test_op([(1,128), (128,128), (128,128)], lambda x,y,z: (x@y).relu()@z, atol=1e-4) + def test_matvec(self): + helper_test_op([(1,128), (128,128)], lambda x,y: (x@y).relu(), atol=1e-4) + # this was the failure in llama early realizing freqs_cis def test_double_slice(self): helper_test_op([(4,4)], lambda x: x[:, 1:2][1:2])