From 7e4d2cb6a219239e0cd68a409ec9761744670fff Mon Sep 17 00:00:00 2001 From: Umut Date: Tue, 28 Feb 2023 14:13:03 +0100 Subject: [PATCH] feat: improve graph formatting with float bounds --- concrete/numpy/compilation/client.py | 8 ++++++-- .../numpy/extensions/round_bit_pattern.py | 2 +- concrete/numpy/representation/graph.py | 19 ++++++++++++++++--- tests/mlir/test_graph_converter.py | 4 ++-- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/concrete/numpy/compilation/client.py b/concrete/numpy/compilation/client.py index e51b6e1a0..0b0775cd5 100644 --- a/concrete/numpy/compilation/client.py +++ b/concrete/numpy/compilation/client.py @@ -228,7 +228,9 @@ class Client: output < (np.prod(crt_decomposition) // 2), output, -np.prod(crt_decomposition) + output, - ).astype(np.int64) + ).astype( + np.int64 + ) # type: ignore sanitized_outputs.append(sanititzed_output) @@ -242,7 +244,9 @@ class Client: output = output.astype(np.longlong) # to prevent overflows in numpy sanititzed_output = np.where( output < (2 ** (n - 1)), output, output - (2**n) - ).astype(np.int64) + ).astype( + np.int64 + ) # type: ignore sanitized_outputs.append(sanititzed_output) else: sanitized_outputs.append( diff --git a/concrete/numpy/extensions/round_bit_pattern.py b/concrete/numpy/extensions/round_bit_pattern.py index de11800d0..288479e6e 100644 --- a/concrete/numpy/extensions/round_bit_pattern.py +++ b/concrete/numpy/extensions/round_bit_pattern.py @@ -189,7 +189,7 @@ def round_bit_pattern( if isinstance(lsbs_to_remove, AutoRounder): if local._is_adjusting: if not lsbs_to_remove.is_adjusted: - raise Adjusting(lsbs_to_remove, int(np.min(x)), int(np.max(x))) + raise Adjusting(lsbs_to_remove, int(np.min(x)), int(np.max(x))) # type: ignore elif not lsbs_to_remove.is_adjusted: message = ( diff --git a/concrete/numpy/representation/graph.py b/concrete/numpy/representation/graph.py index 8e46eba1d..ef317c000 100644 --- a/concrete/numpy/representation/graph.py +++ b/concrete/numpy/representation/graph.py @@ -282,13 +282,26 @@ class Graph: if node.operation == Operation.Generic and "subgraph" in node.properties["kwargs"]: subgraphs[line] = node.properties["kwargs"]["subgraph"] + # get formatted bounds + bounds = "" + if node.bounds is not None: + bounds += "∈ [" + + lower, upper = node.bounds + assert type(lower) == type(upper) # pylint: disable=unidiomatic-typecheck + + if isinstance(lower, (float, np.float32, np.float64)): + bounds += f"{round(lower, 6)}, {round(upper, 6)}" + else: + bounds += f"{int(lower)}, {int(upper)}" + + bounds += "]" + # remember metadata of the node line_metadata.append( { "type": f"# {node.output}", - "bounds": ( - f"∈ [{node.bounds[0]}, {node.bounds[1]}]" if node.bounds is not None else "" - ), + "bounds": bounds, "tag": (f"@ {node.tag}" if node.tag != "" else ""), "location": node.location, }, diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index 72c8da990..3e48401ee 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -86,8 +86,8 @@ return %2 Function you are trying to compile cannot be converted to MLIR %0 = x # EncryptedScalar ∈ [0, 99] -%1 = sin(%0) # EncryptedScalar ∈ [-0.9999902065507035, 0.9999118601072672] -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported +%1 = sin(%0) # EncryptedScalar ∈ [-0.99999, 0.999912] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported return %1 """, # noqa: E501