diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 32dd4f9ab..727ed6e47 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -661,11 +661,9 @@ def _on_numpy_matmul(lhs: NPTracer, rhs: NPTracer): rhs = rhs.reshape((rhs_output.shape[0], 1)) elif lhs_output.ndim == 1: # lhs is a vector, reshape to be 2D and give proper result - output_shape = lhs_output.shape lhs = lhs.reshape((1, lhs_output.shape[0])) elif rhs_output.ndim == 1: # rhs is a vector, reshape to be 2D and give proper result - output_shape = rhs_output.shape rhs = rhs.reshape((rhs_output.shape[0], 1)) traced_computation = MatMul( diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index ed23f208e..f32f5f205 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1359,6 +1359,16 @@ def test_compile_and_run_constant_dot_correctness( (2, 3), (-4, 3), ), + pytest.param( + (5,), + (5, 3), + (0, 4), + ), + pytest.param( + (5, 3), + (3,), + (0, 4), + ), ], ) def test_compile_and_run_matmul_correctness(