mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(configuration): add option to treat warnings as errors
This commit is contained in:
@@ -59,6 +59,7 @@ def _print_input_coherency_warnings(
|
||||
parameters: Dict[str, Any],
|
||||
parameter_index_to_parameter_name: Dict[int, str],
|
||||
get_base_value_for_constant_data_func: Callable[[Any], Any],
|
||||
treat_warnings_as_errors: bool,
|
||||
):
|
||||
"""Print coherency warning for `input_to_check` against `parameters`.
|
||||
|
||||
@@ -84,11 +85,20 @@ def _print_input_coherency_warnings(
|
||||
parameters,
|
||||
get_base_value_for_constant_data_func,
|
||||
)
|
||||
for problem in problems:
|
||||
sys.stderr.write(
|
||||
f"Warning: Input #{current_input_index} (0-indexed) "
|
||||
f"is not coherent with the hinted parameters ({problem})\n",
|
||||
messages = [
|
||||
(
|
||||
f"Input #{current_input_index} (0-indexed) "
|
||||
f"is not coherent with the hinted parameters ({problem})\n"
|
||||
)
|
||||
for problem in problems
|
||||
]
|
||||
|
||||
if len(messages) > 0:
|
||||
if treat_warnings_as_errors:
|
||||
raise ValueError(", ".join(messages))
|
||||
|
||||
for message in messages:
|
||||
sys.stderr.write(f"Warning: {message}")
|
||||
|
||||
|
||||
def eval_op_graph_bounds_on_inputset(
|
||||
@@ -161,6 +171,7 @@ def eval_op_graph_bounds_on_inputset(
|
||||
parameters,
|
||||
parameter_index_to_parameter_name,
|
||||
get_base_value_for_constant_data_func,
|
||||
compilation_configuration.treat_warnings_as_errors,
|
||||
)
|
||||
|
||||
first_output = op_graph.evaluate(current_input_data)
|
||||
@@ -184,6 +195,7 @@ def eval_op_graph_bounds_on_inputset(
|
||||
parameters,
|
||||
parameter_index_to_parameter_name,
|
||||
get_base_value_for_constant_data_func,
|
||||
compilation_configuration.treat_warnings_as_errors,
|
||||
)
|
||||
|
||||
current_output = op_graph.evaluate(current_input_data)
|
||||
|
||||
@@ -7,13 +7,16 @@ class CompilationConfiguration:
|
||||
dump_artifacts_on_unexpected_failures: bool
|
||||
enable_topological_optimizations: bool
|
||||
check_every_input_in_inputset: bool
|
||||
treat_warnings_as_errors: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dump_artifacts_on_unexpected_failures: bool = True,
|
||||
enable_topological_optimizations: bool = True,
|
||||
check_every_input_in_inputset: bool = False,
|
||||
treat_warnings_as_errors: bool = False,
|
||||
):
|
||||
self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
|
||||
self.enable_topological_optimizations = enable_topological_optimizations
|
||||
self.check_every_input_in_inputset = check_every_input_in_inputset
|
||||
self.treat_warnings_as_errors = treat_warnings_as_errors
|
||||
|
||||
@@ -137,12 +137,17 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
|
||||
minimum_required_inputset_size = min(inputset_size_upper_limit, 10)
|
||||
if inputset_size < minimum_required_inputset_size:
|
||||
sys.stderr.write(
|
||||
f"Warning: Provided inputset contains too few inputs "
|
||||
message = (
|
||||
f"Provided inputset contains too few inputs "
|
||||
f"(it should have had at least {minimum_required_inputset_size} "
|
||||
f"but it only had {inputset_size})\n"
|
||||
)
|
||||
|
||||
if compilation_configuration.treat_warnings_as_errors:
|
||||
raise ValueError(message)
|
||||
|
||||
sys.stderr.write(f"Warning: {message}")
|
||||
|
||||
# Add the bounds as an artifact
|
||||
compilation_artifacts.add_final_operation_graph_bounds(node_bounds)
|
||||
|
||||
|
||||
@@ -445,3 +445,31 @@ def test_eval_op_graph_bounds_on_non_conformant_numpy_inputset_check_all(capsys)
|
||||
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<Integer<unsigned, 3 bits>, shape=(3,)> which is not compatible)\n"
|
||||
)
|
||||
|
||||
|
||||
def test_eval_op_graph_bounds_on_non_conformant_inputset_treating_warnings_as_errors():
|
||||
"""Test function for eval_op_graph_bounds_on_inputset with non conformant inputset and errors"""
|
||||
|
||||
def f(x, y):
|
||||
return np.dot(x, y)
|
||||
|
||||
x = EncryptedTensor(UnsignedInteger(2), (3,))
|
||||
y = ClearTensor(UnsignedInteger(2), (3,))
|
||||
|
||||
inputset = [
|
||||
(np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
|
||||
(np.array([3, 3, 3]), np.array([3, 3, 5])),
|
||||
]
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x, "y": y})
|
||||
|
||||
with pytest.raises(ValueError, match=".* is not coherent with the hinted parameters .*"):
|
||||
configuration = CompilationConfiguration(treat_warnings_as_errors=True)
|
||||
eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
@@ -333,6 +333,20 @@ def test_small_inputset():
|
||||
)
|
||||
|
||||
|
||||
def test_small_inputset_treat_warnings_as_errors():
|
||||
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
|
||||
with pytest.raises(ValueError, match=".* inputset contains too few inputs .*"):
|
||||
compile_numpy_function_into_op_graph(
|
||||
lambda x: x + 42,
|
||||
{"x": EncryptedScalar(Integer(5, is_signed=False))},
|
||||
[(0,), (3,)],
|
||||
CompilationConfiguration(
|
||||
dump_artifacts_on_unexpected_failures=False,
|
||||
treat_warnings_as_errors=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,params,shape,ref_graph_str",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user