feat: support grouped conv2D

This commit is contained in:
youben11
2022-08-16 12:17:55 +01:00
committed by Ayoub Benaissa
parent 039a632c72
commit 0aadb4ac43
3 changed files with 25 additions and 11 deletions

View File

@@ -427,6 +427,9 @@ class NodeConverter:
type=integer_type,
context=self.ctx,
)
group = IntegerAttr.get(
IntegerType.get_signless(64), self.node.properties["kwargs"]["group"]
)
has_bias = len(self.node.inputs) == 3
if has_bias:
@@ -436,7 +439,13 @@ class NodeConverter:
# input and weight
preds = self.preds[:2]
return fhelinalg.Conv2dOp(
resulting_type, *preds, bias=bias, padding=pads, strides=strides, dilations=dilations
resulting_type,
*preds,
bias=bias,
padding=pads,
strides=strides,
dilations=dilations,
group=group,
).result
def _convert_conv3d(self) -> OpResult:

View File

@@ -95,8 +95,6 @@ def conv(
if not isinstance(group, int) or group <= 0:
raise ValueError(f"expected group to be an integer > 0, but got {group}")
if group != 1:
raise NotImplementedError("only group == 1 is currently supported")
if auto_pad not in SUPPORTED_AUTO_PAD:
raise ValueError(f"auto_pad should be in {SUPPORTED_AUTO_PAD}, but got {repr(auto_pad)}")
@@ -107,8 +105,8 @@ def conv(
f"expected number of channel in weight to be {n_channels / group} (C / group), but got "
f"{weight.shape[1]}"
)
# TODO: no cover as we don't support group != 1 for the moment
if weight.shape[0] % group != 0: # pragma: no cover
if weight.shape[0] % group != 0:
raise ValueError(
f"expected number of feature maps ({weight.shape[0]}) to be a multiple of group "
f"({group})"

View File

@@ -12,15 +12,22 @@ from concrete.numpy.tracing.tracer import Tracer
@pytest.mark.parametrize(
"input_shape,weight_shape",
"input_shape,weight_shape, group",
[
pytest.param(
(1, 1, 4, 4),
(1, 1, 2, 2),
1,
),
pytest.param(
(4, 3, 4, 4),
(2, 3, 2, 2),
1,
),
pytest.param(
(1, 6, 4, 4),
(6, 1, 2, 2),
6,
),
],
)
@@ -43,7 +50,7 @@ from concrete.numpy.tracing.tracer import Tracer
False,
],
)
def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers):
def test_conv2d(input_shape, weight_shape, group, strides, dilations, has_bias, helpers):
"""
Test conv2d.
"""
@@ -59,7 +66,7 @@ def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers
@cnp.compiler({"x": "encrypted"})
def function(x):
return connx.conv(x, weight, bias, strides=strides, dilations=dilations)
return connx.conv(x, weight, bias, strides=strides, dilations=dilations, group=group)
inputset = [np.random.randint(0, 4, size=input_shape) for i in range(100)]
circuit = function.compile(inputset, configuration)
@@ -330,7 +337,7 @@ def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers
"only 1D, 2D, and 3D convolutions are supported",
),
pytest.param(
(1, 1, 4, 4),
(1, 2, 4, 4),
(1, 1, 2, 2),
(1,),
(0, 0, 0, 0),
@@ -339,8 +346,8 @@ def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers
None,
2,
"NOTSET",
NotImplementedError,
"only group == 1 is currently supported",
ValueError,
"expected number of feature maps (1) to be a multiple of group (2)",
),
],
)