From 1aad4d23d1a0274599d9dfe4417b9a71f12d4320 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Tue, 9 Nov 2021 17:54:37 +0100 Subject: [PATCH] chore: be more verbose in this assert closes #729 --- concrete/numpy/compile.py | 18 ++++++++++++++++++ tests/numpy/test_compile.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index e99bdee63..f8f7e6f61 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -12,6 +12,7 @@ from ..common.common_helpers import check_op_graph_is_integer_program from ..common.compilation import CompilationArtifacts, CompilationConfiguration from ..common.data_types import Integer from ..common.debugging import format_operation_graph +from ..common.debugging.custom_assert import assert_true from ..common.fhe_circuit import FHECircuit from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS from ..common.mlir.utils import ( @@ -83,6 +84,23 @@ def _compile_numpy_function_into_op_graph_internal( OPGraph: compiled function into a graph """ + # Check function parameters + wrong_inputs = { + inp: function_parameters[inp] + for inp in function_parameters.keys() + if not isinstance(function_parameters[inp], BaseValue) + } + list_of_possible_basevalue = [ + "ClearTensor", + "EncryptedTensor", + "ClearScalar", + "EncryptedScalar", + ] + assert_true( + len(wrong_inputs.keys()) == 0, + f"wrong type for inputs {wrong_inputs}, needs to be one of {list_of_possible_basevalue}", + ) + # Add the function to compile as an artifact compilation_artifacts.add_function_to_compile(function_to_compile) diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 7fff439ff..9c998b63e 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1481,3 +1481,33 @@ def test_fail_compile_with_random_inputset(default_compilation_configuration): ) assert str(error) == expected raise + + +def test_wrong_inputs(default_compilation_configuration): + """Test compilation with faulty inputs""" + + def data_gen(args): + for prod in itertools.product(*args): + yield prod + + # x should have been something like EncryptedScalar(UnsignedInteger(3)) + x = [1, 2, 3] + input_ranges = ((0, 10),) + inputset = data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)) + dict_for_inputs = {"x": x} + + with pytest.raises(AssertionError) as excinfo: + compile_numpy_function( + lambda x: 2 * x, dict_for_inputs, inputset, default_compilation_configuration + ) + + list_of_possible_basevalue = [ + "ClearTensor", + "EncryptedTensor", + "ClearScalar", + "EncryptedScalar", + ] + assert ( + str(excinfo.value) == f"wrong type for inputs {dict_for_inputs}, " + f"needs to be one of {list_of_possible_basevalue}" + )