diff --git a/concrete/onnx/convolution.py b/concrete/onnx/convolution.py index 074b5fb0f..86700b2b7 100644 --- a/concrete/onnx/convolution.py +++ b/concrete/onnx/convolution.py @@ -64,9 +64,9 @@ def conv( Union[np.ndarray, Tracer]: evaluation result or traced computation """ if kernel_shape is not None and ( - weight.ndim != len(kernel_shape) or not np.equal(weight.shape, kernel_shape) + (weight.ndim - 2) != len(kernel_shape) or not np.all(weight.shape[2:] == kernel_shape) ): - raise ValueError(f"expected kernel_shape to be {weight.shape}, but got {kernel_shape}") + raise ValueError(f"expected kernel_shape to be {weight.shape[2:]}, but got {kernel_shape}") if isinstance(x, np.ndarray): if not isinstance(weight, np.ndarray): diff --git a/tests/execution/test_convolution.py b/tests/execution/test_convolution.py index d03443132..0b380385d 100644 --- a/tests/execution/test_convolution.py +++ b/tests/execution/test_convolution.py @@ -269,7 +269,7 @@ def test_conv2d(input_shape, weight_shape, group, strides, dilations, has_bias, 1, "NOTSET", ValueError, - "expected kernel_shape to be (1, 1, 2, 2), but got (1, 2)", + "expected kernel_shape to be (2, 2), but got (1, 2)", ), pytest.param( (1, 1, 4, 4),