fix: adhere to ONNX spec of kernel_shape

ONNX spec:
"kernel_shape : list of ints
    The shape of the convolution kernel. If not present, should be
	inferred from input W."

We were taking the number of input/output feature maps into account,
which we realized we should not.
This commit is contained in:
youben11
2022-10-05 10:59:46 +01:00
committed by Ayoub Benaissa
parent dce0a86aa1
commit 7901d8868c
2 changed files with 3 additions and 3 deletions

View File

@@ -64,9 +64,9 @@ def conv(
Union[np.ndarray, Tracer]: evaluation result or traced computation
"""
if kernel_shape is not None and (
weight.ndim != len(kernel_shape) or not np.equal(weight.shape, kernel_shape)
(weight.ndim - 2) != len(kernel_shape) or not np.all(weight.shape[2:] == kernel_shape)
):
raise ValueError(f"expected kernel_shape to be {weight.shape}, but got {kernel_shape}")
raise ValueError(f"expected kernel_shape to be {weight.shape[2:]}, but got {kernel_shape}")
if isinstance(x, np.ndarray):
if not isinstance(weight, np.ndarray):

View File

@@ -269,7 +269,7 @@ def test_conv2d(input_shape, weight_shape, group, strides, dilations, has_bias,
1,
"NOTSET",
ValueError,
"expected kernel_shape to be (1, 1, 2, 2), but got (1, 2)",
"expected kernel_shape to be (2, 2), but got (1, 2)",
),
pytest.param(
(1, 1, 4, 4),