fix: correct matmul reshapes for 1D cases

- add tests for these cases with non square 2D matrices
This commit is contained in:
Arthur Meyre
2021-12-21 09:52:02 +01:00
parent a87e5ab53c
commit 2dc070dd4b
2 changed files with 10 additions and 2 deletions

View File

@@ -661,11 +661,9 @@ def _on_numpy_matmul(lhs: NPTracer, rhs: NPTracer):
rhs = rhs.reshape((rhs_output.shape[0], 1))
elif lhs_output.ndim == 1:
# lhs is a vector, reshape to be 2D and give proper result
output_shape = lhs_output.shape
lhs = lhs.reshape((1, lhs_output.shape[0]))
elif rhs_output.ndim == 1:
# rhs is a vector, reshape to be 2D and give proper result
output_shape = rhs_output.shape
rhs = rhs.reshape((rhs_output.shape[0], 1))
traced_computation = MatMul(

View File

@@ -1359,6 +1359,16 @@ def test_compile_and_run_constant_dot_correctness(
(2, 3),
(-4, 3),
),
pytest.param(
(5,),
(5, 3),
(0, 4),
),
pytest.param(
(5, 3),
(3,),
(0, 4),
),
],
)
def test_compile_and_run_matmul_correctness(