diff --git a/test/test_ops.py b/test/test_ops.py index 6c8baaa203..3967bf29a1 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -92,6 +92,10 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65)], _mish_pytorch, Tensor.mish, atol=1e-4) def test_dot(self): helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) + def test_matmul(self): + helper_test_op([(65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) + helper_test_op([(512), (512,802)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) + helper_test_op([(802), (802,512)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) def test_broadcastdot(self): helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) def test_multidot(self):