From 6affa54473c919ae937c109fed9bce7b4b4b23dc Mon Sep 17 00:00:00 2001 From: Umut Date: Thu, 7 Oct 2021 13:31:12 +0300 Subject: [PATCH] feat(configuration): add option to treat warnings as errors --- .../bounds_measurement/inputset_eval.py | 20 ++++++++++--- concrete/common/compilation/configuration.py | 3 ++ concrete/numpy/compile.py | 9 ++++-- .../bounds_measurement/test_inputset_eval.py | 28 +++++++++++++++++++ tests/numpy/test_compile.py | 14 ++++++++++ 5 files changed, 68 insertions(+), 6 deletions(-) diff --git a/concrete/common/bounds_measurement/inputset_eval.py b/concrete/common/bounds_measurement/inputset_eval.py index 52cd6c21e..88051b441 100644 --- a/concrete/common/bounds_measurement/inputset_eval.py +++ b/concrete/common/bounds_measurement/inputset_eval.py @@ -59,6 +59,7 @@ def _print_input_coherency_warnings( parameters: Dict[str, Any], parameter_index_to_parameter_name: Dict[int, str], get_base_value_for_constant_data_func: Callable[[Any], Any], + treat_warnings_as_errors: bool, ): """Print coherency warning for `input_to_check` against `parameters`. @@ -84,11 +85,20 @@ def _print_input_coherency_warnings( parameters, get_base_value_for_constant_data_func, ) - for problem in problems: - sys.stderr.write( - f"Warning: Input #{current_input_index} (0-indexed) " - f"is not coherent with the hinted parameters ({problem})\n", + messages = [ + ( + f"Input #{current_input_index} (0-indexed) " + f"is not coherent with the hinted parameters ({problem})\n" ) + for problem in problems + ] + + if len(messages) > 0: + if treat_warnings_as_errors: + raise ValueError(", ".join(messages)) + + for message in messages: + sys.stderr.write(f"Warning: {message}") def eval_op_graph_bounds_on_inputset( @@ -161,6 +171,7 @@ def eval_op_graph_bounds_on_inputset( parameters, parameter_index_to_parameter_name, get_base_value_for_constant_data_func, + compilation_configuration.treat_warnings_as_errors, ) first_output = op_graph.evaluate(current_input_data) @@ -184,6 +195,7 @@ def eval_op_graph_bounds_on_inputset( parameters, parameter_index_to_parameter_name, get_base_value_for_constant_data_func, + compilation_configuration.treat_warnings_as_errors, ) current_output = op_graph.evaluate(current_input_data) diff --git a/concrete/common/compilation/configuration.py b/concrete/common/compilation/configuration.py index 07f909e6d..ad3bf1f86 100644 --- a/concrete/common/compilation/configuration.py +++ b/concrete/common/compilation/configuration.py @@ -7,13 +7,16 @@ class CompilationConfiguration: dump_artifacts_on_unexpected_failures: bool enable_topological_optimizations: bool check_every_input_in_inputset: bool + treat_warnings_as_errors: bool def __init__( self, dump_artifacts_on_unexpected_failures: bool = True, enable_topological_optimizations: bool = True, check_every_input_in_inputset: bool = False, + treat_warnings_as_errors: bool = False, ): self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures self.enable_topological_optimizations = enable_topological_optimizations self.check_every_input_in_inputset = check_every_input_in_inputset + self.treat_warnings_as_errors = treat_warnings_as_errors diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index a5d499996..bf8419715 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -137,12 +137,17 @@ def _compile_numpy_function_into_op_graph_internal( minimum_required_inputset_size = min(inputset_size_upper_limit, 10) if inputset_size < minimum_required_inputset_size: - sys.stderr.write( - f"Warning: Provided inputset contains too few inputs " + message = ( + f"Provided inputset contains too few inputs " f"(it should have had at least {minimum_required_inputset_size} " f"but it only had {inputset_size})\n" ) + if compilation_configuration.treat_warnings_as_errors: + raise ValueError(message) + + sys.stderr.write(f"Warning: {message}") + # Add the bounds as an artifact compilation_artifacts.add_final_operation_graph_bounds(node_bounds) diff --git a/tests/common/bounds_measurement/test_inputset_eval.py b/tests/common/bounds_measurement/test_inputset_eval.py index d977fce93..209471873 100644 --- a/tests/common/bounds_measurement/test_inputset_eval.py +++ b/tests/common/bounds_measurement/test_inputset_eval.py @@ -445,3 +445,31 @@ def test_eval_op_graph_bounds_on_non_conformant_numpy_inputset_check_all(capsys) "(expected ClearTensor, shape=(3,)> for parameter `y` " "but got ClearTensor, shape=(3,)> which is not compatible)\n" ) + + +def test_eval_op_graph_bounds_on_non_conformant_inputset_treating_warnings_as_errors(): + """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset and errors""" + + def f(x, y): + return np.dot(x, y) + + x = EncryptedTensor(UnsignedInteger(2), (3,)) + y = ClearTensor(UnsignedInteger(2), (3,)) + + inputset = [ + (np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])), + (np.array([3, 3, 3]), np.array([3, 3, 5])), + ] + + op_graph = trace_numpy_function(f, {"x": x, "y": y}) + + with pytest.raises(ValueError, match=".* is not coherent with the hinted parameters .*"): + configuration = CompilationConfiguration(treat_warnings_as_errors=True) + eval_op_graph_bounds_on_inputset( + op_graph, + inputset, + compilation_configuration=configuration, + min_func=numpy_min_func, + max_func=numpy_max_func, + get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, + ) diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index afa6bb3e1..022b8cfb2 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -333,6 +333,20 @@ def test_small_inputset(): ) +def test_small_inputset_treat_warnings_as_errors(): + """Test function compile_numpy_function_into_op_graph with an unacceptably small inputset""" + with pytest.raises(ValueError, match=".* inputset contains too few inputs .*"): + compile_numpy_function_into_op_graph( + lambda x: x + 42, + {"x": EncryptedScalar(Integer(5, is_signed=False))}, + [(0,), (3,)], + CompilationConfiguration( + dump_artifacts_on_unexpected_failures=False, + treat_warnings_as_errors=True, + ), + ) + + @pytest.mark.parametrize( "function,params,shape,ref_graph_str", [