mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix: adhere to ONNX spec of kernel_shape
ONNX spec:
"kernel_shape : list of ints
The shape of the convolution kernel. If not present, should be
inferred from input W."
We were taking the number of input/output feature maps into account,
which we realized we should not.
This commit is contained in:
@@ -64,9 +64,9 @@ def conv(
|
||||
Union[np.ndarray, Tracer]: evaluation result or traced computation
|
||||
"""
|
||||
if kernel_shape is not None and (
|
||||
weight.ndim != len(kernel_shape) or not np.equal(weight.shape, kernel_shape)
|
||||
(weight.ndim - 2) != len(kernel_shape) or not np.all(weight.shape[2:] == kernel_shape)
|
||||
):
|
||||
raise ValueError(f"expected kernel_shape to be {weight.shape}, but got {kernel_shape}")
|
||||
raise ValueError(f"expected kernel_shape to be {weight.shape[2:]}, but got {kernel_shape}")
|
||||
|
||||
if isinstance(x, np.ndarray):
|
||||
if not isinstance(weight, np.ndarray):
|
||||
|
||||
@@ -269,7 +269,7 @@ def test_conv2d(input_shape, weight_shape, group, strides, dilations, has_bias,
|
||||
1,
|
||||
"NOTSET",
|
||||
ValueError,
|
||||
"expected kernel_shape to be (1, 1, 2, 2), but got (1, 2)",
|
||||
"expected kernel_shape to be (2, 2), but got (1, 2)",
|
||||
),
|
||||
pytest.param(
|
||||
(1, 1, 4, 4),
|
||||
|
||||
Reference in New Issue
Block a user