feat(configuration): add option to treat warnings as errors

This commit is contained in:
Umut
2021-10-07 13:31:12 +03:00
parent 1cc7502251
commit 6affa54473
5 changed files with 68 additions and 6 deletions

View File

@@ -59,6 +59,7 @@ def _print_input_coherency_warnings(
parameters: Dict[str, Any],
parameter_index_to_parameter_name: Dict[int, str],
get_base_value_for_constant_data_func: Callable[[Any], Any],
treat_warnings_as_errors: bool,
):
"""Print coherency warning for `input_to_check` against `parameters`.
@@ -84,11 +85,20 @@ def _print_input_coherency_warnings(
parameters,
get_base_value_for_constant_data_func,
)
for problem in problems:
sys.stderr.write(
f"Warning: Input #{current_input_index} (0-indexed) "
f"is not coherent with the hinted parameters ({problem})\n",
messages = [
(
f"Input #{current_input_index} (0-indexed) "
f"is not coherent with the hinted parameters ({problem})\n"
)
for problem in problems
]
if len(messages) > 0:
if treat_warnings_as_errors:
raise ValueError(", ".join(messages))
for message in messages:
sys.stderr.write(f"Warning: {message}")
def eval_op_graph_bounds_on_inputset(
@@ -161,6 +171,7 @@ def eval_op_graph_bounds_on_inputset(
parameters,
parameter_index_to_parameter_name,
get_base_value_for_constant_data_func,
compilation_configuration.treat_warnings_as_errors,
)
first_output = op_graph.evaluate(current_input_data)
@@ -184,6 +195,7 @@ def eval_op_graph_bounds_on_inputset(
parameters,
parameter_index_to_parameter_name,
get_base_value_for_constant_data_func,
compilation_configuration.treat_warnings_as_errors,
)
current_output = op_graph.evaluate(current_input_data)

View File

@@ -7,13 +7,16 @@ class CompilationConfiguration:
dump_artifacts_on_unexpected_failures: bool
enable_topological_optimizations: bool
check_every_input_in_inputset: bool
treat_warnings_as_errors: bool
def __init__(
self,
dump_artifacts_on_unexpected_failures: bool = True,
enable_topological_optimizations: bool = True,
check_every_input_in_inputset: bool = False,
treat_warnings_as_errors: bool = False,
):
self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
self.enable_topological_optimizations = enable_topological_optimizations
self.check_every_input_in_inputset = check_every_input_in_inputset
self.treat_warnings_as_errors = treat_warnings_as_errors

View File

@@ -137,12 +137,17 @@ def _compile_numpy_function_into_op_graph_internal(
minimum_required_inputset_size = min(inputset_size_upper_limit, 10)
if inputset_size < minimum_required_inputset_size:
sys.stderr.write(
f"Warning: Provided inputset contains too few inputs "
message = (
f"Provided inputset contains too few inputs "
f"(it should have had at least {minimum_required_inputset_size} "
f"but it only had {inputset_size})\n"
)
if compilation_configuration.treat_warnings_as_errors:
raise ValueError(message)
sys.stderr.write(f"Warning: {message}")
# Add the bounds as an artifact
compilation_artifacts.add_final_operation_graph_bounds(node_bounds)

View File

@@ -445,3 +445,31 @@ def test_eval_op_graph_bounds_on_non_conformant_numpy_inputset_check_all(capsys)
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
"but got ClearTensor<Integer<unsigned, 3 bits>, shape=(3,)> which is not compatible)\n"
)
def test_eval_op_graph_bounds_on_non_conformant_inputset_treating_warnings_as_errors():
"""Test function for eval_op_graph_bounds_on_inputset with non conformant inputset and errors"""
def f(x, y):
return np.dot(x, y)
x = EncryptedTensor(UnsignedInteger(2), (3,))
y = ClearTensor(UnsignedInteger(2), (3,))
inputset = [
(np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
(np.array([3, 3, 3]), np.array([3, 3, 5])),
]
op_graph = trace_numpy_function(f, {"x": x, "y": y})
with pytest.raises(ValueError, match=".* is not coherent with the hinted parameters .*"):
configuration = CompilationConfiguration(treat_warnings_as_errors=True)
eval_op_graph_bounds_on_inputset(
op_graph,
inputset,
compilation_configuration=configuration,
min_func=numpy_min_func,
max_func=numpy_max_func,
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
)

View File

@@ -333,6 +333,20 @@ def test_small_inputset():
)
def test_small_inputset_treat_warnings_as_errors():
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
with pytest.raises(ValueError, match=".* inputset contains too few inputs .*"):
compile_numpy_function_into_op_graph(
lambda x: x + 42,
{"x": EncryptedScalar(Integer(5, is_signed=False))},
[(0,), (3,)],
CompilationConfiguration(
dump_artifacts_on_unexpected_failures=False,
treat_warnings_as_errors=True,
),
)
@pytest.mark.parametrize(
"function,params,shape,ref_graph_str",
[