diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py index 22ea92229..6ce5a0b87 100644 --- a/concrete/numpy/mlir/node_converter.py +++ b/concrete/numpy/mlir/node_converter.py @@ -427,6 +427,9 @@ class NodeConverter: type=integer_type, context=self.ctx, ) + group = IntegerAttr.get( + IntegerType.get_signless(64), self.node.properties["kwargs"]["group"] + ) has_bias = len(self.node.inputs) == 3 if has_bias: @@ -436,7 +439,13 @@ class NodeConverter: # input and weight preds = self.preds[:2] return fhelinalg.Conv2dOp( - resulting_type, *preds, bias=bias, padding=pads, strides=strides, dilations=dilations + resulting_type, + *preds, + bias=bias, + padding=pads, + strides=strides, + dilations=dilations, + group=group, ).result def _convert_conv3d(self) -> OpResult: diff --git a/concrete/onnx/convolution.py b/concrete/onnx/convolution.py index 761a49a8d..074b5fb0f 100644 --- a/concrete/onnx/convolution.py +++ b/concrete/onnx/convolution.py @@ -95,8 +95,6 @@ def conv( if not isinstance(group, int) or group <= 0: raise ValueError(f"expected group to be an integer > 0, but got {group}") - if group != 1: - raise NotImplementedError("only group == 1 is currently supported") if auto_pad not in SUPPORTED_AUTO_PAD: raise ValueError(f"auto_pad should be in {SUPPORTED_AUTO_PAD}, but got {repr(auto_pad)}") @@ -107,8 +105,8 @@ def conv( f"expected number of channel in weight to be {n_channels / group} (C / group), but got " f"{weight.shape[1]}" ) - # TODO: no cover as we don't support group != 1 for the moment - if weight.shape[0] % group != 0: # pragma: no cover + + if weight.shape[0] % group != 0: raise ValueError( f"expected number of feature maps ({weight.shape[0]}) to be a multiple of group " f"({group})" diff --git a/tests/execution/test_convolution.py b/tests/execution/test_convolution.py index 285b04b8b..d03443132 100644 --- a/tests/execution/test_convolution.py +++ b/tests/execution/test_convolution.py @@ -12,15 +12,22 @@ from concrete.numpy.tracing.tracer import Tracer @pytest.mark.parametrize( - "input_shape,weight_shape", + "input_shape,weight_shape, group", [ pytest.param( (1, 1, 4, 4), (1, 1, 2, 2), + 1, ), pytest.param( (4, 3, 4, 4), (2, 3, 2, 2), + 1, + ), + pytest.param( + (1, 6, 4, 4), + (6, 1, 2, 2), + 6, ), ], ) @@ -43,7 +50,7 @@ from concrete.numpy.tracing.tracer import Tracer False, ], ) -def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers): +def test_conv2d(input_shape, weight_shape, group, strides, dilations, has_bias, helpers): """ Test conv2d. """ @@ -59,7 +66,7 @@ def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers @cnp.compiler({"x": "encrypted"}) def function(x): - return connx.conv(x, weight, bias, strides=strides, dilations=dilations) + return connx.conv(x, weight, bias, strides=strides, dilations=dilations, group=group) inputset = [np.random.randint(0, 4, size=input_shape) for i in range(100)] circuit = function.compile(inputset, configuration) @@ -330,7 +337,7 @@ def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers "only 1D, 2D, and 3D convolutions are supported", ), pytest.param( - (1, 1, 4, 4), + (1, 2, 4, 4), (1, 1, 2, 2), (1,), (0, 0, 0, 0), @@ -339,8 +346,8 @@ def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers None, 2, "NOTSET", - NotImplementedError, - "only group == 1 is currently supported", + ValueError, + "expected number of feature maps (1) to be a multiple of group (2)", ), ], )