feat: mimic the exact numpy behavior for matmul

This commit is contained in:
Umut
2022-02-28 15:06:04 +03:00
parent ed28639c57
commit b71cbc8ecb
8 changed files with 108 additions and 86 deletions

View File

@@ -176,6 +176,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
ClearTensor(Integer(32, True), shape=(2, 3)),
],
Integer(32, True),
(3, 3),
),
[numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)],
numpy.array([[9, 12, 15], [19, 26, 33], [29, 40, 51]]),

View File

@@ -1681,6 +1681,61 @@ def test_compile_and_run_constant_dot_correctness(
(3,),
(0, 3),
),
pytest.param(
(5,),
(4, 5, 3),
(0, 5),
),
pytest.param(
(4, 5, 3),
(3,),
(0, 5),
),
pytest.param(
(5,),
(2, 4, 5, 3),
(0, 5),
),
pytest.param(
(2, 4, 5, 3),
(3,),
(0, 5),
),
pytest.param(
(5, 4, 3),
(3, 2),
(0, 5),
),
pytest.param(
(4, 3),
(5, 3, 2),
(0, 5),
),
pytest.param(
(2, 5, 4, 3),
(3, 2),
(0, 5),
),
pytest.param(
(5, 4, 3),
(1, 3, 2),
(0, 5),
),
pytest.param(
(1, 4, 3),
(5, 3, 2),
(0, 5),
),
pytest.param(
(5, 4, 3),
(2, 1, 3, 2),
(0, 5),
),
pytest.param(
(2, 1, 4, 3),
(5, 3, 2),
(0, 5),
),
],
)
def test_compile_and_run_matmul_correctness(