feat: detect invalid values during bounds measurement

This commit is contained in:
Umut
2022-07-22 13:37:05 +02:00
parent e398a4fbd0
commit d50b2c1547
3 changed files with 164 additions and 15 deletions

View File

@@ -82,7 +82,14 @@ class Graph:
continue
pred_results = [node_results[pred] for pred in self.ordered_preds_of(node)]
node_results[node] = node(*pred_results)
try:
node_results[node] = node(*pred_results)
except Exception as error:
raise RuntimeError(
"Evaluation of the graph failed\n\n"
+ self.format(highlighted_nodes={node: ["evaluation of this node failed"]})
) from error
return node_results
def draw(
@@ -364,24 +371,30 @@ class Graph:
if not isinstance(sample, tuple):
sample = (sample,)
evaluation = self.evaluate(*sample)
for node, value in evaluation.items():
bounds[node] = {
"min": value.min(),
"max": value.max(),
}
for sample in inputset_iterator:
if not isinstance(sample, tuple):
sample = (sample,)
index = 0
try:
evaluation = self.evaluate(*sample)
for node, value in evaluation.items():
bounds[node] = {
"min": np.minimum(bounds[node]["min"], value.min()),
"max": np.maximum(bounds[node]["max"], value.max()),
"min": value.min(),
"max": value.max(),
}
for sample in inputset_iterator:
index += 1
if not isinstance(sample, tuple):
sample = (sample,)
evaluation = self.evaluate(*sample)
for node, value in evaluation.items():
bounds[node] = {
"min": np.minimum(bounds[node]["min"], value.min()),
"max": np.maximum(bounds[node]["max"], value.max()),
}
except Exception as error:
raise RuntimeError(f"Bound measurement using inputset[{index}] failed") from error
return bounds
def update_with_bounds(self, bounds: Dict[Node, Dict[str, Union[np.integer, np.floating]]]):

View File

@@ -557,11 +557,25 @@ class Tracer:
output_value = deepcopy(self.output)
output_value.dtype = Value.of(normalized_dtype.type(0)).dtype
if np.issubdtype(normalized_dtype.type, np.integer):
def evaluator(x, dtype):
if np.any(np.isnan(x)):
raise ValueError("A `NaN` value is tried to be converted to integer")
if np.any(np.isinf(x)):
raise ValueError("An `Inf` value is tried to be converted to integer")
return x.astype(dtype)
else:
def evaluator(x, dtype):
return x.astype(dtype)
computation = Node.generic(
"astype",
[self.output],
output_value,
lambda x, dtype: x.astype(dtype),
evaluator,
kwargs={"dtype": normalized_dtype.type},
)
return Tracer(computation, [self])

View File

@@ -241,3 +241,125 @@ def test_compiler_virtual_compile(helpers):
circuit = compiler.compile(inputset, configuration=configuration, virtual=True)
assert circuit.encrypt_run_decrypt(200) == 600
def test_compiler_compile_bad_inputset(helpers):
"""
Test `compile` method of `Compiler` class with bad inputset.
"""
configuration = helpers.configuration()
# with inf
# --------
def f(x):
return (x + np.inf).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
compiler = Compiler(f, {"x": "encrypted"})
compiler.compile(range(10), configuration=configuration)
assert str(excinfo.value) == "Bound measurement using inputset[0] failed"
assert (
str(excinfo.value.__cause__).strip()
== """
Evaluation of the graph failed
%0 = x # EncryptedScalar<uint1>
%1 = subgraph(%0) # EncryptedScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed
return %1
Subgraphs:
%1 = subgraph(%0):
%0 = inf # ClearScalar<float64>
%1 = input # EncryptedScalar<uint1>
%2 = add(%1, %0) # EncryptedScalar<float64>
%3 = astype(%2, dtype=int_) # EncryptedScalar<uint1>
return %3
""".strip()
)
assert (
str(excinfo.value.__cause__.__cause__).strip()
== """
Evaluation of the graph failed
%0 = inf # ClearScalar<float64>
%1 = input # EncryptedScalar<uint1>
%2 = add(%1, %0) # EncryptedScalar<float64>
%3 = astype(%2, dtype=int_) # EncryptedScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed
return %3
""".strip()
)
assert (
str(excinfo.value.__cause__.__cause__.__cause__)
== "An `Inf` value is tried to be converted to integer"
)
# with nan
# --------
def g(x):
return (x + np.nan).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
compiler = Compiler(g, {"x": "encrypted"})
compiler.compile(range(10), configuration=configuration)
assert str(excinfo.value) == "Bound measurement using inputset[0] failed"
assert (
str(excinfo.value.__cause__).strip()
== """
Evaluation of the graph failed
%0 = x # EncryptedScalar<uint1>
%1 = subgraph(%0) # EncryptedScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed
return %1
Subgraphs:
%1 = subgraph(%0):
%0 = nan # ClearScalar<float64>
%1 = input # EncryptedScalar<uint1>
%2 = add(%1, %0) # EncryptedScalar<float64>
%3 = astype(%2, dtype=int_) # EncryptedScalar<uint1>
return %3
""".strip()
)
assert (
str(excinfo.value.__cause__.__cause__).strip()
== """
Evaluation of the graph failed
%0 = nan # ClearScalar<float64>
%1 = input # EncryptedScalar<uint1>
%2 = add(%1, %0) # EncryptedScalar<float64>
%3 = astype(%2, dtype=int_) # EncryptedScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed
return %3
""".strip()
)
assert (
str(excinfo.value.__cause__.__cause__.__cause__)
== "A `NaN` value is tried to be converted to integer"
)