fix: use pads arg for torch evaluation

As we weren't supporting padding, we neglected to use the pads argument
in the plain evaluation, which resulted in a confusing error message
when users tried to use padding. This fixes the problem by applying the
pads properly during evaluation, and leaves reporting the
unsupported-padding error up to the compiler.
youben11
2022-10-14 14:02:43 +01:00
committed by Umut
parent adede72d08
commit 5f07a72e5c
2 changed files with 34 additions and 1 deletion
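
For context, ONNX-style pads list the begin value of every spatial dimension followed by the end values, while torch's convolution padding takes a single symmetric value per dimension. A minimal sketch of the mapping the fix performs (the helper name is hypothetical; this assumes the ONNX [x1_begin, x2_begin, ..., x1_end, x2_end, ...] pads layout):

    def onnx_pads_to_torch_padding(pads):
        # ONNX: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
        # torch: one symmetric value per spatial dimension
        n_dim = len(pads) // 2
        torch_padding = []
        for dim in range(n_dim):
            if pads[dim] != pads[n_dim + dim]:
                # asymmetric padding cannot be expressed as a single value
                raise ValueError("asymmetric padding is not supported")
            torch_padding.append(pads[dim])
        return torch_padding

    assert onnx_pads_to_torch_padding([1, 1, 1, 1]) == [1, 1]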


@@ -618,6 +618,17 @@ def _evaluate_conv(
     )
     torch_conv_func = cast(Callable, torch_conv_func)
     n_dim = x.ndim - 2  # remove batch_size and channel dims
+    torch_padding = []
+    for dim in range(n_dim):
+        if pads[dim] != pads[n_dim + dim]:
+            raise ValueError(
+                f"padding should be the same for the beginning of the dimension and its end, but "
+                f"got {pads[dim]} in the beginning, and {pads[n_dim + dim]} at the end for "
+                f"dimension {dim}"
+            )
+        torch_padding.append(pads[dim])
     dtype = (
         torch.float64
         if np.issubdtype(x.dtype, np.floating)
@@ -630,6 +641,7 @@ def _evaluate_conv(
         torch.tensor(weight, dtype=dtype),
         torch.tensor(bias, dtype=dtype),
         stride=strides,
+        padding=torch_padding,
         dilation=dilations,
         groups=group,
     ).numpy()

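To illustrate the evaluation path above, here is a small standalone sketch (the shapes and values are hypothetical) of calling torch's functional convolution with per-dimension symmetric padding, mirroring the arguments passed in _evaluate_conv:

    import numpy as np
    import torch

    x = np.ones((1, 1, 4, 4))  # (batch_size, channels, height, width)
    weight = np.ones((1, 1, 2, 2))
    bias = np.zeros((1,))

    result = torch.nn.functional.conv2d(
        torch.tensor(x, dtype=torch.float64),
        torch.tensor(weight, dtype=torch.float64),
        torch.tensor(bias, dtype=torch.float64),
        stride=(1, 1),
        padding=(1, 1),  # one symmetric value per spatial dimension
        dilation=(1, 1),
        groups=1,
    ).numpy()

    # a 4x4 input, a 2x2 kernel, and padding of 1 give a 5x5 output
    assert result.shape == (1, 1, 5, 5)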

@@ -92,6 +92,20 @@ def test_conv2d(input_shape, weight_shape, group, strides, dilations, has_bias,
             ValueError,
             "auto_pad should be in {'NOTSET'}, but got 'VALID'",
         ),
+        pytest.param(
+            (1, 1, 1, 4),
+            (1, 1, 2, 2),
+            (1,),
+            (1, 0, 2, 0),
+            (1, 1),
+            (1, 1),
+            None,
+            1,
+            "NOTSET",
+            RuntimeError,
+            "padding should be the same for the beginning of the dimension and its end, but got "
+            "1 in the beginning, and 2 at the end for dimension 0",
+        ),
         pytest.param(
             (1, 1, 4),
             (1, 1, 2),
@@ -398,7 +412,14 @@ def test_bad_conv_compilation(
     with pytest.raises(expected_error) as excinfo:
         function.compile(inputset, configuration)
-    assert str(excinfo.value) == expected_message
+
+    # Get the root cause error
+    current_error = excinfo.value
+    cause = current_error.__cause__
+    while cause:
+        current_error = cause
+        cause = current_error.__cause__
+    assert str(current_error) == expected_message

 @pytest.mark.parametrize(
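
The root-cause walk in the updated test relies on Python's standard exception chaining (raise ... from ...): it follows __cause__ until it reaches the original error, then checks that error's message. A minimal sketch of the behavior being unwound (the function name is hypothetical):

    def evaluate():
        try:
            raise ValueError("root cause")
        except ValueError as error:
            # wrapping keeps the original error reachable via __cause__
            raise RuntimeError("evaluation failed") from error

    try:
        evaluate()
    except RuntimeError as error:
        current_error = error
        while current_error.__cause__:
            current_error = current_error.__cause__
        assert str(current_error) == "root cause"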