From d50b2c1547cb28e912364e817fde4e118fd04b7c Mon Sep 17 00:00:00 2001 From: Umut Date: Fri, 22 Jul 2022 13:37:05 +0200 Subject: [PATCH] feat: detect invalid values during bounds measurement --- concrete/numpy/representation/graph.py | 41 ++++++--- concrete/numpy/tracing/tracer.py | 16 +++- tests/compilation/test_compiler.py | 122 +++++++++++++++++++++++++ 3 files changed, 164 insertions(+), 15 deletions(-) diff --git a/concrete/numpy/representation/graph.py b/concrete/numpy/representation/graph.py index 5e77ad7eb..3f8e3ffeb 100644 --- a/concrete/numpy/representation/graph.py +++ b/concrete/numpy/representation/graph.py @@ -82,7 +82,14 @@ class Graph: continue pred_results = [node_results[pred] for pred in self.ordered_preds_of(node)] - node_results[node] = node(*pred_results) + try: + node_results[node] = node(*pred_results) + except Exception as error: + raise RuntimeError( + "Evaluation of the graph failed\n\n" + + self.format(highlighted_nodes={node: ["evaluation of this node failed"]}) + ) from error + return node_results def draw( @@ -364,24 +371,30 @@ class Graph: if not isinstance(sample, tuple): sample = (sample,) - evaluation = self.evaluate(*sample) - for node, value in evaluation.items(): - bounds[node] = { - "min": value.min(), - "max": value.max(), - } - - for sample in inputset_iterator: - if not isinstance(sample, tuple): - sample = (sample,) - + index = 0 + try: evaluation = self.evaluate(*sample) for node, value in evaluation.items(): bounds[node] = { - "min": np.minimum(bounds[node]["min"], value.min()), - "max": np.maximum(bounds[node]["max"], value.max()), + "min": value.min(), + "max": value.max(), } + for sample in inputset_iterator: + index += 1 + if not isinstance(sample, tuple): + sample = (sample,) + + evaluation = self.evaluate(*sample) + for node, value in evaluation.items(): + bounds[node] = { + "min": np.minimum(bounds[node]["min"], value.min()), + "max": np.maximum(bounds[node]["max"], value.max()), + } + + except Exception as error: + raise RuntimeError(f"Bound measurement using inputset[{index}] failed") from error + return bounds def update_with_bounds(self, bounds: Dict[Node, Dict[str, Union[np.integer, np.floating]]]): diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index 7c4440381..a5b996d08 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -557,11 +557,25 @@ class Tracer: output_value = deepcopy(self.output) output_value.dtype = Value.of(normalized_dtype.type(0)).dtype + if np.issubdtype(normalized_dtype.type, np.integer): + + def evaluator(x, dtype): + if np.any(np.isnan(x)): + raise ValueError("A `NaN` value is tried to be converted to integer") + if np.any(np.isinf(x)): + raise ValueError("An `Inf` value is tried to be converted to integer") + return x.astype(dtype) + + else: + + def evaluator(x, dtype): + return x.astype(dtype) + computation = Node.generic( "astype", [self.output], output_value, - lambda x, dtype: x.astype(dtype), + evaluator, kwargs={"dtype": normalized_dtype.type}, ) return Tracer(computation, [self]) diff --git a/tests/compilation/test_compiler.py b/tests/compilation/test_compiler.py index b3915b2f1..cdc161ba4 100644 --- a/tests/compilation/test_compiler.py +++ b/tests/compilation/test_compiler.py @@ -241,3 +241,125 @@ def test_compiler_virtual_compile(helpers): circuit = compiler.compile(inputset, configuration=configuration, virtual=True) assert circuit.encrypt_run_decrypt(200) == 600 + + +def test_compiler_compile_bad_inputset(helpers): + """ + Test `compile` method of `Compiler` class with bad inputset. + """ + + configuration = helpers.configuration() + + # with inf + # -------- + + def f(x): + return (x + np.inf).astype(np.int64) + + with pytest.raises(RuntimeError) as excinfo: + compiler = Compiler(f, {"x": "encrypted"}) + compiler.compile(range(10), configuration=configuration) + + assert str(excinfo.value) == "Bound measurement using inputset[0] failed" + + assert ( + str(excinfo.value.__cause__).strip() + == """ + +Evaluation of the graph failed + +%0 = x # EncryptedScalar +%1 = subgraph(%0) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed +return %1 + +Subgraphs: + + %1 = subgraph(%0): + + %0 = inf # ClearScalar + %1 = input # EncryptedScalar + %2 = add(%1, %0) # EncryptedScalar + %3 = astype(%2, dtype=int_) # EncryptedScalar + return %3 + + """.strip() + ) + + assert ( + str(excinfo.value.__cause__.__cause__).strip() + == """ + +Evaluation of the graph failed + +%0 = inf # ClearScalar +%1 = input # EncryptedScalar +%2 = add(%1, %0) # EncryptedScalar +%3 = astype(%2, dtype=int_) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed +return %3 + + """.strip() + ) + + assert ( + str(excinfo.value.__cause__.__cause__.__cause__) + == "An `Inf` value is tried to be converted to integer" + ) + + # with nan + # -------- + + def g(x): + return (x + np.nan).astype(np.int64) + + with pytest.raises(RuntimeError) as excinfo: + compiler = Compiler(g, {"x": "encrypted"}) + compiler.compile(range(10), configuration=configuration) + + assert str(excinfo.value) == "Bound measurement using inputset[0] failed" + + assert ( + str(excinfo.value.__cause__).strip() + == """ + +Evaluation of the graph failed + +%0 = x # EncryptedScalar +%1 = subgraph(%0) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed +return %1 + +Subgraphs: + + %1 = subgraph(%0): + + %0 = nan # ClearScalar + %1 = input # EncryptedScalar + %2 = add(%1, %0) # EncryptedScalar + %3 = astype(%2, dtype=int_) # EncryptedScalar + return %3 + + """.strip() + ) + + assert ( + str(excinfo.value.__cause__.__cause__).strip() + == """ + +Evaluation of the graph failed + +%0 = nan # ClearScalar +%1 = input # EncryptedScalar +%2 = add(%1, %0) # EncryptedScalar +%3 = astype(%2, dtype=int_) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed +return %3 + + """.strip() + ) + + assert ( + str(excinfo.value.__cause__.__cause__.__cause__) + == "A `NaN` value is tried to be converted to integer" + )