mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: detect invalid values during bounds measurement
This commit is contained in:
@@ -82,7 +82,14 @@ class Graph:
|
||||
continue
|
||||
|
||||
pred_results = [node_results[pred] for pred in self.ordered_preds_of(node)]
|
||||
node_results[node] = node(*pred_results)
|
||||
try:
|
||||
node_results[node] = node(*pred_results)
|
||||
except Exception as error:
|
||||
raise RuntimeError(
|
||||
"Evaluation of the graph failed\n\n"
|
||||
+ self.format(highlighted_nodes={node: ["evaluation of this node failed"]})
|
||||
) from error
|
||||
|
||||
return node_results
|
||||
|
||||
def draw(
|
||||
@@ -364,24 +371,30 @@ class Graph:
|
||||
if not isinstance(sample, tuple):
|
||||
sample = (sample,)
|
||||
|
||||
evaluation = self.evaluate(*sample)
|
||||
for node, value in evaluation.items():
|
||||
bounds[node] = {
|
||||
"min": value.min(),
|
||||
"max": value.max(),
|
||||
}
|
||||
|
||||
for sample in inputset_iterator:
|
||||
if not isinstance(sample, tuple):
|
||||
sample = (sample,)
|
||||
|
||||
index = 0
|
||||
try:
|
||||
evaluation = self.evaluate(*sample)
|
||||
for node, value in evaluation.items():
|
||||
bounds[node] = {
|
||||
"min": np.minimum(bounds[node]["min"], value.min()),
|
||||
"max": np.maximum(bounds[node]["max"], value.max()),
|
||||
"min": value.min(),
|
||||
"max": value.max(),
|
||||
}
|
||||
|
||||
for sample in inputset_iterator:
|
||||
index += 1
|
||||
if not isinstance(sample, tuple):
|
||||
sample = (sample,)
|
||||
|
||||
evaluation = self.evaluate(*sample)
|
||||
for node, value in evaluation.items():
|
||||
bounds[node] = {
|
||||
"min": np.minimum(bounds[node]["min"], value.min()),
|
||||
"max": np.maximum(bounds[node]["max"], value.max()),
|
||||
}
|
||||
|
||||
except Exception as error:
|
||||
raise RuntimeError(f"Bound measurement using inputset[{index}] failed") from error
|
||||
|
||||
return bounds
|
||||
|
||||
def update_with_bounds(self, bounds: Dict[Node, Dict[str, Union[np.integer, np.floating]]]):
|
||||
|
||||
@@ -557,11 +557,25 @@ class Tracer:
|
||||
output_value = deepcopy(self.output)
|
||||
output_value.dtype = Value.of(normalized_dtype.type(0)).dtype
|
||||
|
||||
if np.issubdtype(normalized_dtype.type, np.integer):
|
||||
|
||||
def evaluator(x, dtype):
|
||||
if np.any(np.isnan(x)):
|
||||
raise ValueError("A `NaN` value is tried to be converted to integer")
|
||||
if np.any(np.isinf(x)):
|
||||
raise ValueError("An `Inf` value is tried to be converted to integer")
|
||||
return x.astype(dtype)
|
||||
|
||||
else:
|
||||
|
||||
def evaluator(x, dtype):
|
||||
return x.astype(dtype)
|
||||
|
||||
computation = Node.generic(
|
||||
"astype",
|
||||
[self.output],
|
||||
output_value,
|
||||
lambda x, dtype: x.astype(dtype),
|
||||
evaluator,
|
||||
kwargs={"dtype": normalized_dtype.type},
|
||||
)
|
||||
return Tracer(computation, [self])
|
||||
|
||||
@@ -241,3 +241,125 @@ def test_compiler_virtual_compile(helpers):
|
||||
circuit = compiler.compile(inputset, configuration=configuration, virtual=True)
|
||||
|
||||
assert circuit.encrypt_run_decrypt(200) == 600
|
||||
|
||||
|
||||
def test_compiler_compile_bad_inputset(helpers):
|
||||
"""
|
||||
Test `compile` method of `Compiler` class with bad inputset.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
# with inf
|
||||
# --------
|
||||
|
||||
def f(x):
|
||||
return (x + np.inf).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compiler = Compiler(f, {"x": "encrypted"})
|
||||
compiler.compile(range(10), configuration=configuration)
|
||||
|
||||
assert str(excinfo.value) == "Bound measurement using inputset[0] failed"
|
||||
|
||||
assert (
|
||||
str(excinfo.value.__cause__).strip()
|
||||
== """
|
||||
|
||||
Evaluation of the graph failed
|
||||
|
||||
%0 = x # EncryptedScalar<uint1>
|
||||
%1 = subgraph(%0) # EncryptedScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed
|
||||
return %1
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%1 = subgraph(%0):
|
||||
|
||||
%0 = inf # ClearScalar<float64>
|
||||
%1 = input # EncryptedScalar<uint1>
|
||||
%2 = add(%1, %0) # EncryptedScalar<float64>
|
||||
%3 = astype(%2, dtype=int_) # EncryptedScalar<uint1>
|
||||
return %3
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
assert (
|
||||
str(excinfo.value.__cause__.__cause__).strip()
|
||||
== """
|
||||
|
||||
Evaluation of the graph failed
|
||||
|
||||
%0 = inf # ClearScalar<float64>
|
||||
%1 = input # EncryptedScalar<uint1>
|
||||
%2 = add(%1, %0) # EncryptedScalar<float64>
|
||||
%3 = astype(%2, dtype=int_) # EncryptedScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed
|
||||
return %3
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
assert (
|
||||
str(excinfo.value.__cause__.__cause__.__cause__)
|
||||
== "An `Inf` value is tried to be converted to integer"
|
||||
)
|
||||
|
||||
# with nan
|
||||
# --------
|
||||
|
||||
def g(x):
|
||||
return (x + np.nan).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compiler = Compiler(g, {"x": "encrypted"})
|
||||
compiler.compile(range(10), configuration=configuration)
|
||||
|
||||
assert str(excinfo.value) == "Bound measurement using inputset[0] failed"
|
||||
|
||||
assert (
|
||||
str(excinfo.value.__cause__).strip()
|
||||
== """
|
||||
|
||||
Evaluation of the graph failed
|
||||
|
||||
%0 = x # EncryptedScalar<uint1>
|
||||
%1 = subgraph(%0) # EncryptedScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed
|
||||
return %1
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%1 = subgraph(%0):
|
||||
|
||||
%0 = nan # ClearScalar<float64>
|
||||
%1 = input # EncryptedScalar<uint1>
|
||||
%2 = add(%1, %0) # EncryptedScalar<float64>
|
||||
%3 = astype(%2, dtype=int_) # EncryptedScalar<uint1>
|
||||
return %3
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
assert (
|
||||
str(excinfo.value.__cause__.__cause__).strip()
|
||||
== """
|
||||
|
||||
Evaluation of the graph failed
|
||||
|
||||
%0 = nan # ClearScalar<float64>
|
||||
%1 = input # EncryptedScalar<uint1>
|
||||
%2 = add(%1, %0) # EncryptedScalar<float64>
|
||||
%3 = astype(%2, dtype=int_) # EncryptedScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed
|
||||
return %3
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
assert (
|
||||
str(excinfo.value.__cause__.__cause__.__cause__)
|
||||
== "A `NaN` value is tried to be converted to integer"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user