mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: support grouped conv2D
This commit is contained in:
@@ -427,6 +427,9 @@ class NodeConverter:
|
||||
type=integer_type,
|
||||
context=self.ctx,
|
||||
)
|
||||
group = IntegerAttr.get(
|
||||
IntegerType.get_signless(64), self.node.properties["kwargs"]["group"]
|
||||
)
|
||||
|
||||
has_bias = len(self.node.inputs) == 3
|
||||
if has_bias:
|
||||
@@ -436,7 +439,13 @@ class NodeConverter:
|
||||
# input and weight
|
||||
preds = self.preds[:2]
|
||||
return fhelinalg.Conv2dOp(
|
||||
resulting_type, *preds, bias=bias, padding=pads, strides=strides, dilations=dilations
|
||||
resulting_type,
|
||||
*preds,
|
||||
bias=bias,
|
||||
padding=pads,
|
||||
strides=strides,
|
||||
dilations=dilations,
|
||||
group=group,
|
||||
).result
|
||||
|
||||
def _convert_conv3d(self) -> OpResult:
|
||||
|
||||
@@ -95,8 +95,6 @@ def conv(
|
||||
|
||||
if not isinstance(group, int) or group <= 0:
|
||||
raise ValueError(f"expected group to be an integer > 0, but got {group}")
|
||||
if group != 1:
|
||||
raise NotImplementedError("only group == 1 is currently supported")
|
||||
|
||||
if auto_pad not in SUPPORTED_AUTO_PAD:
|
||||
raise ValueError(f"auto_pad should be in {SUPPORTED_AUTO_PAD}, but got {repr(auto_pad)}")
|
||||
@@ -107,8 +105,8 @@ def conv(
|
||||
f"expected number of channel in weight to be {n_channels / group} (C / group), but got "
|
||||
f"{weight.shape[1]}"
|
||||
)
|
||||
# TODO: no cover as we don't support group != 1 for the moment
|
||||
if weight.shape[0] % group != 0: # pragma: no cover
|
||||
|
||||
if weight.shape[0] % group != 0:
|
||||
raise ValueError(
|
||||
f"expected number of feature maps ({weight.shape[0]}) to be a multiple of group "
|
||||
f"({group})"
|
||||
|
||||
@@ -12,15 +12,22 @@ from concrete.numpy.tracing.tracer import Tracer
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_shape,weight_shape",
|
||||
"input_shape,weight_shape, group",
|
||||
[
|
||||
pytest.param(
|
||||
(1, 1, 4, 4),
|
||||
(1, 1, 2, 2),
|
||||
1,
|
||||
),
|
||||
pytest.param(
|
||||
(4, 3, 4, 4),
|
||||
(2, 3, 2, 2),
|
||||
1,
|
||||
),
|
||||
pytest.param(
|
||||
(1, 6, 4, 4),
|
||||
(6, 1, 2, 2),
|
||||
6,
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -43,7 +50,7 @@ from concrete.numpy.tracing.tracer import Tracer
|
||||
False,
|
||||
],
|
||||
)
|
||||
def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers):
|
||||
def test_conv2d(input_shape, weight_shape, group, strides, dilations, has_bias, helpers):
|
||||
"""
|
||||
Test conv2d.
|
||||
"""
|
||||
@@ -59,7 +66,7 @@ def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return connx.conv(x, weight, bias, strides=strides, dilations=dilations)
|
||||
return connx.conv(x, weight, bias, strides=strides, dilations=dilations, group=group)
|
||||
|
||||
inputset = [np.random.randint(0, 4, size=input_shape) for i in range(100)]
|
||||
circuit = function.compile(inputset, configuration)
|
||||
@@ -330,7 +337,7 @@ def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers
|
||||
"only 1D, 2D, and 3D convolutions are supported",
|
||||
),
|
||||
pytest.param(
|
||||
(1, 1, 4, 4),
|
||||
(1, 2, 4, 4),
|
||||
(1, 1, 2, 2),
|
||||
(1,),
|
||||
(0, 0, 0, 0),
|
||||
@@ -339,8 +346,8 @@ def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers
|
||||
None,
|
||||
2,
|
||||
"NOTSET",
|
||||
NotImplementedError,
|
||||
"only group == 1 is currently supported",
|
||||
ValueError,
|
||||
"expected number of feature maps (1) to be a multiple of group (2)",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user