mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix: use pads arg for torch evaluation
as we weren't supporting padding, we neglected to use them in the plain evaluation, but this result in a confusing error message for the user when trying to use padding. This fixes the problem by using padding properly during evaluation, and leaves the error up to the compiler.
This commit is contained in:
@@ -618,6 +618,17 @@ def _evaluate_conv(
|
||||
)
|
||||
torch_conv_func = cast(Callable, torch_conv_func)
|
||||
|
||||
n_dim = x.ndim - 2 # remove batch_size and channel dims
|
||||
torch_padding = []
|
||||
for dim in range(n_dim):
|
||||
if pads[dim] != pads[n_dim + dim]:
|
||||
raise ValueError(
|
||||
f"padding should be the same for the beginning of the dimension and its end, but "
|
||||
f"got {pads[dim]} in the beginning, and {pads[n_dim + dim]} at the end for "
|
||||
f"dimension {dim}"
|
||||
)
|
||||
torch_padding.append(pads[dim])
|
||||
|
||||
dtype = (
|
||||
torch.float64
|
||||
if np.issubdtype(x.dtype, np.floating)
|
||||
@@ -630,6 +641,7 @@ def _evaluate_conv(
|
||||
torch.tensor(weight, dtype=dtype),
|
||||
torch.tensor(bias, dtype=dtype),
|
||||
stride=strides,
|
||||
padding=torch_padding,
|
||||
dilation=dilations,
|
||||
groups=group,
|
||||
).numpy()
|
||||
|
||||
@@ -92,6 +92,20 @@ def test_conv2d(input_shape, weight_shape, group, strides, dilations, has_bias,
|
||||
ValueError,
|
||||
"auto_pad should be in {'NOTSET'}, but got 'VALID'",
|
||||
),
|
||||
pytest.param(
|
||||
(1, 1, 1, 4),
|
||||
(1, 1, 2, 2),
|
||||
(1,),
|
||||
(1, 0, 2, 0),
|
||||
(1, 1),
|
||||
(1, 1),
|
||||
None,
|
||||
1,
|
||||
"NOTSET",
|
||||
RuntimeError,
|
||||
"padding should be the same for the beginning of the dimension and its end, but got "
|
||||
"1 in the beginning, and 2 at the end for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
(1, 1, 4),
|
||||
(1, 1, 2),
|
||||
@@ -398,7 +412,14 @@ def test_bad_conv_compilation(
|
||||
with pytest.raises(expected_error) as excinfo:
|
||||
function.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == expected_message
|
||||
# Get the root cause error
|
||||
current_error = excinfo.value
|
||||
cause = current_error.__cause__
|
||||
while cause:
|
||||
current_error = cause
|
||||
cause = current_error.__cause__
|
||||
|
||||
assert str(current_error) == expected_message
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user