mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix: correct matmul reshapes for 1D cases
- add tests for these cases with non square 2D matrices
This commit is contained in:
@@ -661,11 +661,9 @@ def _on_numpy_matmul(lhs: NPTracer, rhs: NPTracer):
|
||||
rhs = rhs.reshape((rhs_output.shape[0], 1))
|
||||
elif lhs_output.ndim == 1:
|
||||
# lhs is a vector, reshape to be 2D and give proper result
|
||||
output_shape = lhs_output.shape
|
||||
lhs = lhs.reshape((1, lhs_output.shape[0]))
|
||||
elif rhs_output.ndim == 1:
|
||||
# rhs is a vector, reshape to be 2D and give proper result
|
||||
output_shape = rhs_output.shape
|
||||
rhs = rhs.reshape((rhs_output.shape[0], 1))
|
||||
|
||||
traced_computation = MatMul(
|
||||
|
||||
@@ -1359,6 +1359,16 @@ def test_compile_and_run_constant_dot_correctness(
|
||||
(2, 3),
|
||||
(-4, 3),
|
||||
),
|
||||
pytest.param(
|
||||
(5,),
|
||||
(5, 3),
|
||||
(0, 4),
|
||||
),
|
||||
pytest.param(
|
||||
(5, 3),
|
||||
(3,),
|
||||
(0, 4),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_matmul_correctness(
|
||||
|
||||
Reference in New Issue
Block a user