diff --git a/concrete/numpy/compilation/compiler.py b/concrete/numpy/compilation/compiler.py index 6189d473e..f48254b4f 100644 --- a/concrete/numpy/compilation/compiler.py +++ b/concrete/numpy/compilation/compiler.py @@ -71,9 +71,24 @@ class Compiler: f"{'are' if len(missing_args) > 1 else 'is'} not provided" ) - additional_args = parameter_encryption_statuses.keys() - signature.parameters.keys() - for arg in additional_args: - del parameter_encryption_statuses[arg] + additional_args = list(parameter_encryption_statuses) + for arg in signature.parameters.keys(): + if arg in parameter_encryption_statuses: + additional_args.remove(arg) + + if len(additional_args) != 0: + parameter_str = repr(additional_args[0]) + for arg in additional_args[1:-1]: + parameter_str += f", {repr(arg)}" + if len(additional_args) != 1: + parameter_str += f" and {repr(additional_args[-1])}" + + raise ValueError( + f"Encryption status{'es' if len(additional_args) > 1 else ''} " + f"of {parameter_str} {'are' if len(additional_args) > 1 else 'is'} provided but " + f"{'they are' if len(additional_args) > 1 else 'it is'} not a parameter " + f"of function '{function.__name__}'" + ) self.function = function # type: ignore self.parameter_encryption_statuses = { @@ -155,6 +170,25 @@ class Compiler: """ if inputset is not None: + for index, sample in enumerate(iter(inputset)): + if not isinstance(sample, tuple): + sample = (sample,) + + if len(sample) != len(self.parameter_encryption_statuses): + expected = ( + "a single value" + if len(self.parameter_encryption_statuses) == 1 + else f"a tuple of {len(self.parameter_encryption_statuses)} values" + ) + actual = ( + "a single value" if len(sample) == 1 else f"a tuple of {len(sample)} values" + ) + + raise ValueError( + f"Input #{index} of your inputset is not well formed " + f"(expected {expected} got {actual})" + ) + for input_ in inputset: self.inputset.append(input_) diff --git a/tests/compilation/test_compiler.py b/tests/compilation/test_compiler.py index 486f99e38..c0fad81d4 100644 --- a/tests/compilation/test_compiler.py +++ b/tests/compilation/test_compiler.py @@ -45,12 +45,63 @@ def test_compiler_bad_init(): "Encryption status of parameter 'x' of function 'f' is not provided" ) - # additional p + # additional a, b, c + # ------------------ + with pytest.raises(ValueError) as excinfo: + Compiler( + f, + { + "x": "encrypted", + "y": "encrypted", + "z": "encrypted", + "a": "encrypted", + "b": "encrypted", + "c": "encrypted", + }, + ) + + assert str(excinfo.value) == ( + "Encryption statuses of 'a', 'b' and 'c' are provided " + "but they are not a parameter of function 'f'" + ) + + # additional a and b + # ------------------ + + with pytest.raises(ValueError) as excinfo: + Compiler( + f, + { + "x": "encrypted", + "y": "encrypted", + "z": "encrypted", + "a": "encrypted", + "b": "encrypted", + }, + ) + + assert str(excinfo.value) == ( + "Encryption statuses of 'a' and 'b' are provided " + "but they are not a parameter of function 'f'" + ) + + # additional a # ------------ - # this is fine and `p` is just ignored + with pytest.raises(ValueError) as excinfo: + Compiler( + f, + { + "x": "encrypted", + "y": "encrypted", + "z": "encrypted", + "a": "encrypted", + }, + ) - Compiler(f, {"x": "encrypted", "y": "encrypted", "z": "clear", "p": "clear"}) + assert str(excinfo.value) == ( + "Encryption status of 'a' is provided but it is not a parameter of function 'f'" + ) def test_compiler_bad_call(): @@ -98,6 +149,9 @@ def test_compiler_bad_compile(helpers): def f(x, y, z): return x + y + z + # without inputset + # ---------------- + with pytest.raises(RuntimeError) as excinfo: compiler = Compiler( f, @@ -107,6 +161,41 @@ def test_compiler_bad_compile(helpers): assert str(excinfo.value) == "Compiling function 'f' without an inputset is not supported" + # with bad inputset at the first input + # ------------------------------------ + + with pytest.raises(ValueError) as excinfo: + compiler = Compiler( + f, + {"x": "encrypted", "y": "encrypted", "z": "clear"}, + ) + inputset = [1] + compiler.compile(inputset, configuration=configuration) + + assert str(excinfo.value) == ( + "Input #0 of your inputset is not well formed " + "(expected a tuple of 3 values got a single value)" + ) + + # with bad inputset at the second input + # ------------------------------------- + + with pytest.raises(ValueError) as excinfo: + compiler = Compiler( + f, + {"x": "encrypted", "y": "encrypted", "z": "clear"}, + ) + inputset = [(1, 2, 3), (1, 2)] + compiler.compile(inputset, configuration=configuration) + + assert str(excinfo.value) == ( + "Input #1 of your inputset is not well formed " + "(expected a tuple of 3 values got a tuple of 2 values)" + ) + + # with bad configuration + # ---------------------- + with pytest.raises(RuntimeError) as excinfo: compiler = Compiler(lambda x: x, {"x": "encrypted"}) compiler.compile(