From 7901d8868ca419e55e97294f68f201ee6549f103 Mon Sep 17 00:00:00 2001 From: youben11 Date: Wed, 5 Oct 2022 10:59:46 +0100 Subject: [PATCH] fix: adhere to ONNX spec of kernel_shape ONNX spec: "kernel_shape : list of ints The shape of the convolution kernel. If not present, should be inferred from input W." We were taking the number of input/output feature maps into account, which we realized we should not. --- concrete/onnx/convolution.py | 4 ++-- tests/execution/test_convolution.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/concrete/onnx/convolution.py b/concrete/onnx/convolution.py index 074b5fb0f..86700b2b7 100644 --- a/concrete/onnx/convolution.py +++ b/concrete/onnx/convolution.py @@ -64,9 +64,9 @@ def conv( Union[np.ndarray, Tracer]: evaluation result or traced computation """ if kernel_shape is not None and ( - weight.ndim != len(kernel_shape) or not np.equal(weight.shape, kernel_shape) + (weight.ndim - 2) != len(kernel_shape) or not np.all(weight.shape[2:] == kernel_shape) ): - raise ValueError(f"expected kernel_shape to be {weight.shape}, but got {kernel_shape}") + raise ValueError(f"expected kernel_shape to be {weight.shape[2:]}, but got {kernel_shape}") if isinstance(x, np.ndarray): if not isinstance(weight, np.ndarray): diff --git a/tests/execution/test_convolution.py b/tests/execution/test_convolution.py index d03443132..0b380385d 100644 --- a/tests/execution/test_convolution.py +++ b/tests/execution/test_convolution.py @@ -269,7 +269,7 @@ def test_conv2d(input_shape, weight_shape, group, strides, dilations, has_bias, 1, "NOTSET", ValueError, - "expected kernel_shape to be (1, 1, 2, 2), but got (1, 2)", + "expected kernel_shape to be (2, 2), but got (1, 2)", ), pytest.param( (1, 1, 4, 4),