diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index 8aaf03c58..a0087b4ac 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -22,6 +22,7 @@ from .extensions import ( one, ones, round_bit_pattern, + tag, univariate, zero, zeros, diff --git a/concrete/numpy/compilation/artifacts.py b/concrete/numpy/compilation/artifacts.py index c04b3b04e..f99548206 100644 --- a/concrete/numpy/compilation/artifacts.py +++ b/concrete/numpy/compilation/artifacts.py @@ -7,11 +7,9 @@ import platform import shutil import subprocess from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union -import networkx as nx - -from ..representation import Graph, Node +from ..representation import Graph DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts") @@ -25,14 +23,9 @@ class DebugArtifacts: source_code: Optional[str] parameter_encryption_statuses: Dict[str, str] - textual_representations_of_graphs: Dict[str, List[str]] - final_graph: Optional[Graph] - bounds_of_the_final_graph: Optional[Dict[Node, Dict[str, Any]]] - mlir_to_compile: Optional[str] - client_parameters: Optional[bytes] def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY): @@ -40,14 +33,9 @@ class DebugArtifacts: self.source_code = None self.parameter_encryption_statuses = {} - self.textual_representations_of_graphs = {} - self.final_graph = None - self.bounds_of_the_final_graph = None - self.mlir_to_compile = None - self.client_parameters = None def add_source_code(self, function: Union[str, Callable]): @@ -100,18 +88,6 @@ class DebugArtifacts: self.final_graph = graph - def add_final_graph_bounds(self, bounds: Dict[Node, Dict[str, Any]]): - """ - Add bounds of the latest computation graph. - - Args: - bounds (Dict[Node, Dict[str, Any]]): - bounds of the latest computation graph - """ - - assert self.final_graph is not None - self.bounds_of_the_final_graph = bounds - def add_mlir_to_compile(self, mlir: str): """ Add textual representation of the resulting MLIR. @@ -201,13 +177,6 @@ class DebugArtifacts: with open(output_path, "w", encoding="utf-8") as f: f.write(f"{representation}\n") - if self.bounds_of_the_final_graph is not None: - assert self.final_graph is not None - with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f: - for index, node in enumerate(nx.topological_sort(self.final_graph.graph)): - bounds = self.bounds_of_the_final_graph.get(node) - f.write(f"%{index} :: [{bounds['min']}, {bounds['max']}]\n") - if self.mlir_to_compile is not None: assert self.final_graph is not None with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f: diff --git a/concrete/numpy/compilation/compiler.py b/concrete/numpy/compilation/compiler.py index 4181fc61a..cda05c60e 100644 --- a/concrete/numpy/compilation/compiler.py +++ b/concrete/numpy/compilation/compiler.py @@ -278,10 +278,8 @@ class Compiler: assert self.graph is not None bounds = self.graph.measure_bounds(self.inputset) - if self.artifacts is not None: - self.artifacts.add_final_graph_bounds(bounds) - self.graph.update_with_bounds(bounds) + if self.artifacts is not None: self.artifacts.add_graph("final", self.graph) diff --git a/concrete/numpy/compilation/utils.py b/concrete/numpy/compilation/utils.py index a0245ad99..7aeb3dd1c 100644 --- a/concrete/numpy/compilation/utils.py +++ b/concrete/numpy/compilation/utils.py @@ -557,14 +557,19 @@ def convert_subgraph_to_subgraph_node( variable_input_nodes = [node for node in start_nodes if node.operation != Operation.Constant] if len(variable_input_nodes) != 1: - base_highlighted_nodes = {node: ["within this subgraph"] for node in all_nodes} + base_highlighted_nodes = { + node: ["within this subgraph", node.location] for node in all_nodes + } for variable_input_node in variable_input_nodes: - base_highlighted_nodes[variable_input_node] = ["this is one of the input nodes"] + base_highlighted_nodes[variable_input_node] = [ + "this is one of the input nodes", + variable_input_node.location, + ] raise RuntimeError( "A subgraph within the function you are trying to compile cannot be fused " "because it has multiple input nodes\n\n" - + graph.format(highlighted_nodes=base_highlighted_nodes) + + graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False) ) variable_input_node = variable_input_nodes[0] @@ -577,6 +582,10 @@ def convert_subgraph_to_subgraph_node( subgraph_variable_input_node = Node.input("input", deepcopy(variable_input_node.output)) nx_subgraph.add_node(subgraph_variable_input_node) + subgraph_variable_input_node.location = variable_input_node.location + subgraph_variable_input_node.tag = variable_input_node.tag + subgraph_variable_input_node.created_at = variable_input_node.created_at + variable_input_node_successors = { node: None for node in all_nodes if node in nx_graph.succ[variable_input_node] } @@ -592,6 +601,10 @@ def convert_subgraph_to_subgraph_node( **new_edge_data, ) + original_location = terminal_node.location + original_tag = terminal_node.tag + original_created_at = terminal_node.created_at + subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node}) subgraph_node = Node.generic( "subgraph", @@ -604,6 +617,10 @@ def convert_subgraph_to_subgraph_node( }, ) + subgraph_node.location = original_location + subgraph_node.tag = original_tag + subgraph_node.created_at = original_created_at + return subgraph_node, variable_input_node @@ -635,8 +652,11 @@ def check_subgraph_fusability( if subgraph is not fusable """ - base_highlighted_nodes = {node: ["within this subgraph"] for node in all_nodes} - base_highlighted_nodes[variable_input_node] = ["with this input node"] + base_highlighted_nodes = {node: ["within this subgraph", node.location] for node in all_nodes} + base_highlighted_nodes[variable_input_node] = [ + "with this input node", + variable_input_node.location, + ] non_constant_nodes = (node for node in all_nodes if node.operation != Operation.Constant) for node in non_constant_nodes: @@ -644,19 +664,22 @@ def check_subgraph_fusability( continue if not node.is_fusable: - base_highlighted_nodes[node] = ["this node is not fusable"] + base_highlighted_nodes[node] = ["this node is not fusable", node.location] raise RuntimeError( "A subgraph within the function you are trying to compile cannot be fused " "because of a node, which is marked explicitly as non-fusable\n\n" - + graph.format(highlighted_nodes=base_highlighted_nodes) + + graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False) ) if node.output.shape != variable_input_node.output.shape: - base_highlighted_nodes[node] = ["this node has a different shape than the input node"] + base_highlighted_nodes[node] = [ + "this node has a different shape than the input node", + node.location, + ] raise RuntimeError( "A subgraph within the function you are trying to compile cannot be fused " "because of a node, which is has a different shape than the input node\n\n" - + graph.format(highlighted_nodes=base_highlighted_nodes) + + graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False) ) return True diff --git a/concrete/numpy/extensions/__init__.py b/concrete/numpy/extensions/__init__.py index 65229552d..7648218d9 100644 --- a/concrete/numpy/extensions/__init__.py +++ b/concrete/numpy/extensions/__init__.py @@ -6,5 +6,6 @@ from .array import array from .ones import one, ones from .round_bit_pattern import AutoRounder, round_bit_pattern from .table import LookupTable +from .tag import tag from .univariate import univariate from .zeros import zero, zeros diff --git a/concrete/numpy/extensions/tag.py b/concrete/numpy/extensions/tag.py new file mode 100644 index 000000000..904493652 --- /dev/null +++ b/concrete/numpy/extensions/tag.py @@ -0,0 +1,24 @@ +""" +Declaration of `tag` context manager, to allow tagging certain nodes. +""" + +import threading +from contextlib import contextmanager + +tag_context = threading.local() +tag_context.stack = [] + + +@contextmanager +def tag(name: str): + """ + Introduce a new tag to the tag stack. + + Can be nested, and the resulting tag will be `tag1.tag2`. + """ + + tag_context.stack.append(name) + try: + yield + finally: + tag_context.stack.pop() diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index 6965ce1a6..dd4c7c8bb 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -204,7 +204,7 @@ class GraphConverter: if len(graph.output_nodes) > 1: offending_nodes.update( { - node: ["only a single output is supported"] + node: ["only a single output is supported", node.location] for node in graph.output_nodes.values() } ) @@ -213,7 +213,7 @@ class GraphConverter: for node in graph.graph.nodes: reason = GraphConverter._check_node_convertibility(graph, node, virtual) if reason is not None: - offending_nodes[node] = [reason] + offending_nodes[node] = [reason, node.location] if len(offending_nodes) != 0: raise RuntimeError( @@ -257,14 +257,16 @@ class GraphConverter: if max_bit_width > MAXIMUM_TLU_BIT_WIDTH: offending_nodes[first_tlu_node] = [ f"table lookups are only supported on circuits with " - f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers" + f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers", + first_tlu_node.location, ] if first_signed_node is not None and max_bit_width > MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS: offending_nodes[first_signed_node] = [ f"signed integers are only supported " f"up to {MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS}-bits " - f"on circuits with table lookups" + f"on circuits with table lookups", + first_signed_node.location, ] if len(offending_nodes) != 0: diff --git a/concrete/numpy/representation/graph.py b/concrete/numpy/representation/graph.py index 5cdf40434..0c21b741a 100644 --- a/concrete/numpy/representation/graph.py +++ b/concrete/numpy/representation/graph.py @@ -2,6 +2,7 @@ Declaration of `Graph` class. """ +import re from copy import deepcopy from typing import Any, Dict, Iterable, List, Optional, Tuple, Union @@ -25,11 +26,14 @@ class Graph: input_indices: Dict[Node, int] + is_direct: bool + def __init__( self, graph: nx.MultiDiGraph, input_nodes: Dict[int, Node], output_nodes: Dict[int, Node], + is_direct: bool = False, ): self.graph = graph @@ -38,6 +42,8 @@ class Graph: self.input_indices = {node: index for index, node in input_nodes.items()} + self.is_direct = is_direct + self.prune_useless_nodes() def __call__( @@ -82,7 +88,10 @@ class Graph: except Exception as error: raise RuntimeError( "Evaluation of the graph failed\n\n" - + self.format(highlighted_nodes={node: ["evaluation of this node failed"]}) + + self.format( + highlighted_nodes={node: ["evaluation of this node failed"]}, + show_bounds=False, + ) ) from error return node_results @@ -91,6 +100,10 @@ class Graph: self, maximum_constant_length: int = 25, highlighted_nodes: Optional[Dict[Node, List[str]]] = None, + show_types: bool = True, + show_bounds: bool = True, + show_tags: bool = True, + show_locations: bool = False, ) -> str: """ Get the textual representation of the `Graph`. @@ -102,11 +115,28 @@ class Graph: highlighted_nodes (Optional[Dict[Node, List[str]]], default = None): nodes to be highlighted and their corresponding messages + show_types (bool, default = True): + whether to show types of nodes + + show_bounds (bool, default = True): + whether to show bounds of nodes + + show_tags (bool, default = True): + whether to show tags of nodes + + show_locations (bool, default = False): + whether to show line information of nodes + Returns: str: textual representation of the `Graph` """ + # pylint: disable=too-many-branches,too-many-locals,too-many-statements + + if self.is_direct: + show_bounds = False + # node -> identifier # e.g., id_map[node1] = 2 # means line for node1 is in this form %2 = node1.format(...) @@ -115,9 +145,9 @@ class Graph: # lines that will be merged at the end lines: List[str] = [] - # type information to add to each line - # (for alingment, this is done after lines are determined) - type_informations: List[str] = [] + # metadata to add to each line + # (for alignment, this is done after lines are determined) + line_metadata: List[Dict[str, str]] = [] # default highlighted nodes is empty highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {} @@ -130,7 +160,7 @@ class Graph: subgraphs: Dict[str, Graph] = {} # format nodes - for node in nx.topological_sort(self.graph): + for node in nx.lexicographical_topological_sort(self.graph): # assign a unique id to outputs of node id_map[node] = len(id_map) @@ -160,8 +190,17 @@ class Graph: if node.operation == Operation.Generic and "subgraph" in node.properties["kwargs"]: subgraphs[line] = node.properties["kwargs"]["subgraph"] - # remember type information of the node - type_informations.append(str(node.output)) + # remember metadata of the node + line_metadata.append( + { + "type": f"# {node.output}", + "bounds": ( + f"∈ [{node.bounds[0]}, {node.bounds[1]}]" if node.bounds is not None else "" + ), + "tag": (f"@ {node.tag}" if node.tag != "" else ""), + "location": node.location, + }, + ) # align = signs # @@ -182,11 +221,28 @@ class Graph: " " * (longest_length_before_equals_sign - length_before_equals_sign) ) + line - # add type information - longest_line_length = max(len(line) for line in lines) - for i, line in enumerate(lines): - lines[i] += " " * (longest_line_length - len(line)) - lines[i] += f" # {type_informations[i]}" + # determine which metadata to show + shown_metadata_keys = [] + if show_types: + shown_metadata_keys.append("type") + if show_bounds: + shown_metadata_keys.append("bounds") + if show_tags: + shown_metadata_keys.append("tag") + if show_locations: + shown_metadata_keys.append("location") + + # show requested metadata + indent = 8 + for metadata_key in shown_metadata_keys: + longest_line_length = max(len(line) for line in lines) + lines = [ + line + (" " * ((longest_line_length - len(line)) + indent)) + metadata[metadata_key] + for line, metadata in zip(lines, line_metadata) + ] + + # strip whitespaces + lines = [line.rstrip() for line in lines] # add highlights (this is done in reverse to keep indices consistent) for i in reversed(range(len(lines))): @@ -209,13 +265,23 @@ class Graph: result += "\n\n" result += "Subgraphs:" for line, subgraph in subgraphs.items(): - subgraph_lines = subgraph.format(maximum_constant_length).split("\n") + subgraph_lines = subgraph.format( + maximum_constant_length=maximum_constant_length, + highlighted_nodes={}, + show_types=show_types, + show_bounds=False, # doesn't make sense as we don't measure bounds in subgraphs + show_tags=show_tags, + show_locations=show_locations, + ).split("\n") + result += "\n\n" result += f" {line}:\n\n" result += "\n".join(f" {line}" for line in subgraph_lines) return result + # pylint: enable=too-many-branches,too-many-locals,too-many-statements + def measure_bounds( self, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]], @@ -300,6 +366,8 @@ class Graph: min_bound = bounds[node]["min"] max_bound = bounds[node]["max"] + node.bounds = (min_bound, max_bound) + new_value = deepcopy(node.output) if isinstance(min_bound, np.integer): @@ -384,17 +452,135 @@ class Graph: useless_nodes = [node for node in self.graph.nodes() if node not in useful_nodes] self.graph.remove_nodes_from(useless_nodes) - def maximum_integer_bit_width(self) -> int: + def query_nodes( + self, + tag_filter: Optional[Union[str, List[str], re.Pattern]] = None, + operation_filter: Optional[Union[str, List[str], re.Pattern]] = None, + ) -> List[Node]: + """ + Query nodes within the graph. + + Filters work like so: + str -> nodes without exact match is skipped + List[str] -> nodes without exact match with one of the strings in the list is skipped + re.Pattern -> nodes without pattern match is skipped + + Args: + tag_filter (Optional[Union[str, List[str], re.Pattern]], default = None): + filter for tags + + operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None): + filter for operations + + Returns: + List[Node]: + filtered nodes + """ + + def match_text_filter(text_filter, text): + if text_filter is None: + return True + + if isinstance(text_filter, str): + return text == text_filter + + if isinstance(text_filter, re.Pattern): + return text_filter.match(text) + + return any(text == alternative for alternative in text_filter) + + def get_operation_name(node): + result: str + + if node.operation == Operation.Input: + result = "input" + elif node.operation == Operation.Constant: + result = "constant" + else: + result = node.properties["name"] + + return result + + return [ + node + for node in self.graph.nodes() + if ( + match_text_filter(tag_filter, node.tag) + and match_text_filter(operation_filter, get_operation_name(node)) + ) + ] + + def maximum_integer_bit_width( + self, + tag_filter: Optional[Union[str, List[str], re.Pattern]] = None, + operation_filter: Optional[Union[str, List[str], re.Pattern]] = None, + ) -> int: """ Get maximum integer bit-width within the graph. + Only nodes after filtering will be used to calculate the result. + + Args: + tag_filter (Optional[Union[str, List[str], re.Pattern]], default = None): + filter for tags + + operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None): + filter for operations + Returns: int: - maximum integer bit-width within the graph (-1 is there are no integer nodes) + maximum integer bit-width within the graph + if there are no integer nodes matching the query, result is -1 """ - result = -1 - for node in self.graph.nodes(): - if isinstance(node.output.dtype, Integer): - result = max(result, node.output.dtype.bit_width) + filtered_bit_widths = ( + node.output.dtype.bit_width + for node in self.query_nodes(tag_filter, operation_filter) + if isinstance(node.output.dtype, Integer) + ) + return max(filtered_bit_widths, default=-1) + + def integer_range( + self, + tag_filter: Optional[Union[str, List[str], re.Pattern]] = None, + operation_filter: Optional[Union[str, List[str], re.Pattern]] = None, + ) -> Optional[Tuple[int, int]]: + """ + Get integer range of the graph. + + Only nodes after filtering will be used to calculate the result. + + Args: + tag_filter (Optional[Union[str, List[str], re.Pattern]], default = None): + filter for tags + + operation_filter (Optional[Union[str, List[str], re.Pattern]], default = None): + filter for operations + + Returns: + Optional[Tuple[int, int]]: + minimum and maximum integer value observed during inputset evaluation + if there are no integer nodes matching the query, result is None + """ + + result: Optional[Tuple[int, int]] = None + + if not self.is_direct: + filtered_bounds = ( + node.bounds + for node in self.query_nodes(tag_filter, operation_filter) + if isinstance(node.output.dtype, Integer) and node.bounds is not None + ) + for min_bound, max_bound in filtered_bounds: + assert isinstance(min_bound, np.integer) and isinstance(max_bound, np.integer) + + if result is None: + result = (int(min_bound), int(max_bound)) + else: + old_min_bound, old_max_bound = result # pylint: disable=unpacking-non-sequence + result = ( + min(old_min_bound, int(min_bound)), + max(old_max_bound, int(max_bound)), + ) + return result diff --git a/concrete/numpy/representation/node.py b/concrete/numpy/representation/node.py index 5bb105b9f..ece8d9271 100644 --- a/concrete/numpy/representation/node.py +++ b/concrete/numpy/representation/node.py @@ -2,6 +2,9 @@ Declaration of `Node` class. """ +import os +import time +import traceback from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -25,8 +28,13 @@ class Node: operation: Operation evaluator: Callable + bounds: Optional[Tuple[Union[int, float], Union[int, float]]] properties: Dict[str, Any] + location: str + tag: str + created_at: float + @staticmethod def constant(constant: Any) -> "Node": """ @@ -145,8 +153,44 @@ class Node: self.operation = operation self.evaluator = evaluator # type: ignore + self.bounds = None self.properties = properties if properties is not None else {} + # pylint: disable=cyclic-import,import-outside-toplevel + + import concrete.numpy as cnp + + cnp_directory = os.path.dirname(cnp.__file__) + + import concrete.onnx as coonx + + coonx_directory = os.path.dirname(coonx.__file__) + + # pylint: enable=cyclic-import,import-outside-toplevel + + for frame in reversed(traceback.extract_stack()): + if frame.filename == "<__array_function__ internals>": + continue + + if frame.filename.startswith(cnp_directory): + continue + + if frame.filename.startswith(coonx_directory): + continue + + self.location = f"{frame.filename}:{frame.lineno}" + break + + # pylint: disable=cyclic-import,import-outside-toplevel + + from ..extensions.tag import tag_context + + self.tag = ".".join(tag_context.stack) + + # pylint: enable=cyclic-import,import-outside-toplevel + + self.created_at = time.time() + def __call__(self, *args: List[Any]) -> Union[np.bool_, np.integer, np.floating, np.ndarray]: def generic_error_message() -> str: result = f"Evaluation of {self.operation.value} '{self.label()}' node" @@ -361,3 +405,6 @@ class Node: "subtract", "zeros", ] + + def __lt__(self, other) -> bool: + return self.created_at < other.created_at diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index 94584efbe..399cedcc1 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -151,7 +151,7 @@ class Tracer: output_idx: tracer.computation for output_idx, tracer in enumerate(output_tracers) } - return Graph(graph, input_nodes, output_nodes) + return Graph(graph, input_nodes, output_nodes, is_direct) def __init__(self, computation: Node, input_tracers: List["Tracer"]): self.computation = computation diff --git a/tests/compilation/test_artifacts.py b/tests/compilation/test_artifacts.py index 82e436dda..380323140 100644 --- a/tests/compilation/test_artifacts.py +++ b/tests/compilation/test_artifacts.py @@ -43,7 +43,6 @@ def test_artifacts_export(helpers): assert (tmpdir / "3.after-fusing.graph.txt").exists() assert (tmpdir / "4.final.graph.txt").exists() - assert (tmpdir / "bounds.txt").exists() assert (tmpdir / "mlir.txt").exists() assert (tmpdir / "client_parameters.json").exists() @@ -60,6 +59,5 @@ def test_artifacts_export(helpers): assert (tmpdir / "3.after-fusing.graph.txt").exists() assert (tmpdir / "4.final.graph.txt").exists() - assert (tmpdir / "bounds.txt").exists() assert (tmpdir / "mlir.txt").exists() assert (tmpdir / "client_parameters.json").exists() diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 1d693c655..9574add59 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -25,16 +25,7 @@ def test_circuit_str(helpers): inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(100)] circuit = f.compile(inputset, configuration.fork(p_error=6e-5)) - assert str(circuit) == ( - """ - -%0 = x # EncryptedScalar -%1 = y # EncryptedScalar -%2 = add(%0, %1) # EncryptedScalar -return %2 - - """.strip() - ) + assert str(circuit) == circuit.graph.format() def test_circuit_feedback(helpers): diff --git a/tests/compilation/test_compiler.py b/tests/compilation/test_compiler.py index a5b422a58..f8ebbfec5 100644 --- a/tests/compilation/test_compiler.py +++ b/tests/compilation/test_compiler.py @@ -262,9 +262,8 @@ def test_compiler_compile_bad_inputset(helpers): assert str(excinfo.value) == "Bound measurement using inputset[0] failed" - assert ( - str(excinfo.value.__cause__).strip() - == """ + helpers.check_str( + """ Evaluation of the graph failed @@ -277,29 +276,30 @@ Subgraphs: %1 = subgraph(%0): - %0 = inf # ClearScalar - %1 = input # EncryptedScalar - %2 = add(%1, %0) # EncryptedScalar + %0 = input # EncryptedScalar + %1 = inf # ClearScalar + %2 = add(%0, %1) # EncryptedScalar %3 = astype(%2, dtype=int_) # EncryptedScalar return %3 - """.strip() + """.strip(), + str(excinfo.value.__cause__).strip(), ) - assert ( - str(excinfo.value.__cause__.__cause__).strip() - == """ + helpers.check_str( + """ Evaluation of the graph failed -%0 = inf # ClearScalar -%1 = input # EncryptedScalar -%2 = add(%1, %0) # EncryptedScalar +%0 = input # EncryptedScalar +%1 = inf # ClearScalar +%2 = add(%0, %1) # EncryptedScalar %3 = astype(%2, dtype=int_) # EncryptedScalar ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed return %3 - """.strip() + """.strip(), + str(excinfo.value.__cause__.__cause__).strip(), ) assert ( @@ -319,9 +319,8 @@ return %3 assert str(excinfo.value) == "Bound measurement using inputset[0] failed" - assert ( - str(excinfo.value.__cause__).strip() - == """ + helpers.check_str( + """ Evaluation of the graph failed @@ -334,29 +333,30 @@ Subgraphs: %1 = subgraph(%0): - %0 = nan # ClearScalar - %1 = input # EncryptedScalar - %2 = add(%1, %0) # EncryptedScalar + %0 = input # EncryptedScalar + %1 = nan # ClearScalar + %2 = add(%0, %1) # EncryptedScalar %3 = astype(%2, dtype=int_) # EncryptedScalar return %3 - """.strip() + """.strip(), + str(excinfo.value.__cause__).strip(), ) - assert ( - str(excinfo.value.__cause__.__cause__).strip() - == """ + helpers.check_str( + """ Evaluation of the graph failed -%0 = nan # ClearScalar -%1 = input # EncryptedScalar -%2 = add(%1, %0) # EncryptedScalar +%0 = input # EncryptedScalar +%1 = nan # ClearScalar +%2 = add(%0, %1) # EncryptedScalar %3 = astype(%2, dtype=int_) # EncryptedScalar ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ evaluation of this node failed return %3 - """.strip() + """.strip(), + str(excinfo.value.__cause__.__cause__).strip(), ) assert ( diff --git a/tests/compilation/test_decorators.py b/tests/compilation/test_decorators.py index 2356a05f8..6c3aae783 100644 --- a/tests/compilation/test_decorators.py +++ b/tests/compilation/test_decorators.py @@ -48,9 +48,9 @@ def test_compiler_verbose_trace(helpers, capsys): f""" Computation Graph ------------------------------------------------- +------------------------------------------------------------------ {str(list(artifacts.textual_representations_of_graphs.values())[-1][-1])} ------------------------------------------------- +------------------------------------------------------------------ """.strip() ) @@ -112,19 +112,19 @@ def test_compiler_verbose_virtual_compile(helpers, capsys): f""" Computation Graph ------------------------------------------------- +------------------------------------------------------------------ {list(artifacts.textual_representations_of_graphs.values())[-1][-1]} ------------------------------------------------- +------------------------------------------------------------------ MLIR ------------------------------------------------- +------------------------------------------------------------------ Virtual circuits don't have MLIR. ------------------------------------------------- +------------------------------------------------------------------ Optimizer ------------------------------------------------- +------------------------------------------------------------------ Virtual circuits don't have optimizer output. ------------------------------------------------- +------------------------------------------------------------------ """.strip() ) @@ -140,7 +140,6 @@ def test_circuit(helpers): return x + 42 helpers.check_str( - str(circuit1), """ %0 = x # EncryptedScalar @@ -149,6 +148,7 @@ def test_circuit(helpers): return %2 """.strip(), + str(circuit1), ) # ====================================================================== @@ -158,7 +158,6 @@ return %2 return x + 42 helpers.check_str( - str(circuit2), """ %0 = x # EncryptedTensor @@ -167,6 +166,7 @@ return %2 return %2 """.strip(), + str(circuit2), ) # ====================================================================== @@ -179,7 +179,6 @@ return %2 return cnp.univariate(square, outputs=cnp.uint7)(x) helpers.check_str( - str(circuit3), """ %0 = x # EncryptedScalar @@ -187,6 +186,7 @@ return %2 return %1 """.strip(), + str(circuit3), ) # ====================================================================== @@ -196,7 +196,6 @@ return %1 return ((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(cnp.uint3) helpers.check_str( - str(circuit4), """ %0 = x # EncryptedScalar @@ -207,18 +206,19 @@ Subgraphs: %1 = subgraph(%0): - %0 = 2 # ClearScalar - %1 = 2 # ClearScalar - %2 = input # EncryptedScalar - %3 = sin(%2) # EncryptedScalar - %4 = cos(%2) # EncryptedScalar - %5 = power(%3, %0) # EncryptedScalar - %6 = power(%4, %1) # EncryptedScalar - %7 = add(%5, %6) # EncryptedScalar + %0 = input # EncryptedScalar + %1 = sin(%0) # EncryptedScalar + %2 = 2 # ClearScalar + %3 = power(%1, %2) # EncryptedScalar + %4 = cos(%0) # EncryptedScalar + %5 = 2 # ClearScalar + %6 = power(%4, %5) # EncryptedScalar + %7 = add(%3, %6) # EncryptedScalar %8 = astype(%7) # EncryptedScalar return %8 """.strip(), + str(circuit4), ) # ====================================================================== @@ -228,7 +228,6 @@ Subgraphs: return x + 42 helpers.check_str( - str(circuit5), """ %0 = x # EncryptedScalar @@ -237,6 +236,7 @@ Subgraphs: return %2 """.strip(), + str(circuit5), ) diff --git a/tests/conftest.py b/tests/conftest.py index b43a226df..b0b40d61d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ Configuration of `pytest`. """ import json +import os import random from pathlib import Path from typing import Any, Callable, Dict, List, Tuple, Union @@ -11,6 +12,10 @@ import numpy as np import pytest import concrete.numpy as cnp +import tests + +tests_directory = os.path.dirname(tests.__file__) + INSECURE_KEY_CACHE_LOCATION = None @@ -279,6 +284,14 @@ Actual Output actual str """ + # remove error line information + # there are explicit tests to make sure the line information is correct + # however, it would have been very hard to keep the other tests up to date + + actual = "\n".join( + line for line in actual.splitlines() if not line.strip().startswith(tests_directory) + ) + assert ( actual.strip() == expected.strip() ), f""" diff --git a/tests/execution/test_maxpool.py b/tests/execution/test_maxpool.py index b396ff758..7c17ee915 100644 --- a/tests/execution/test_maxpool.py +++ b/tests/execution/test_maxpool.py @@ -338,7 +338,7 @@ def test_bad_maxpool_special(helpers): def clear_input(x): return connx.maxpool(x, kernel_shape=(4, 3, 2)) - inputset = [np.random.randint(0, 10, size=(1, 1, 10, 10, 10)) for i in range(100)] + inputset = [np.zeros((1, 1, 10, 10, 10), dtype=np.int64)] with pytest.raises(RuntimeError) as excinfo: clear_input.compile(inputset, helpers.configuration()) @@ -348,9 +348,9 @@ def test_bad_maxpool_special(helpers): Function you are trying to compile cannot be converted to MLIR -%0 = x # ClearTensor -%1 = maxpool(%0, kernel_shape=(4, 3, 2), strides=(1, 1, 1), pads=(0, 0, 0, 0, 0, 0), dilations=(1, 1, 1), ceil_mode=False) # ClearTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted maxpool is supported +%0 = x # ClearTensor ∈ [0, 0] +%1 = maxpool(%0, kernel_shape=(4, 3, 2), strides=(1, 1, 1), pads=(0, 0, 0, 0, 0, 0), dilations=(1, 1, 1), ceil_mode=False) # ClearTensor ∈ [0, 0] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted maxpool is supported return %1 """.strip(), # noqa: E501 diff --git a/tests/execution/test_others.py b/tests/execution/test_others.py index ca7627313..a917e243d 100644 --- a/tests/execution/test_others.py +++ b/tests/execution/test_others.py @@ -702,31 +702,31 @@ def test_others_bad_fusing(helpers): A subgraph within the function you are trying to compile cannot be fused because it has multiple input nodes - %0 = 10 # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph - %1 = 10 # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph - %2 = 2 # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph + %0 = x # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes + %1 = y # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes + %2 = sin(%0) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph %3 = 2 # ClearScalar ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph - %4 = x # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes - %5 = y # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes - %6 = sin(%4) # EncryptedScalar + %4 = power(%2, %3) # EncryptedScalar ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph - %7 = cos(%5) # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph - %8 = power(%6, %2) # EncryptedScalar + %5 = 10 # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph + %6 = multiply(%5, %4) # EncryptedScalar ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph - %9 = power(%7, %3) # ClearScalar + %7 = cos(%1) # ClearScalar ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph -%10 = multiply(%0, %8) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph -%11 = multiply(%1, %9) # ClearScalar + %8 = 2 # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph + %9 = power(%7, %8) # ClearScalar ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph -%12 = add(%10, %11) # EncryptedScalar +%10 = 10 # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph +%11 = multiply(%10, %9) # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph +%12 = add(%6, %11) # EncryptedScalar ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph %13 = astype(%12, dtype=int_) # EncryptedScalar ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph diff --git a/tests/execution/test_round_bit_pattern.py b/tests/execution/test_round_bit_pattern.py index c14b493a7..0d5f1f3a1 100644 --- a/tests/execution/test_round_bit_pattern.py +++ b/tests/execution/test_round_bit_pattern.py @@ -201,7 +201,7 @@ def test_auto_rounding(helpers): return %4 """, - str(circuit3), + str(circuit3.graph.format(show_bounds=False)), ) diff --git a/tests/extensions/test_tag.py b/tests/extensions/test_tag.py new file mode 100644 index 000000000..1a14e2c5f --- /dev/null +++ b/tests/extensions/test_tag.py @@ -0,0 +1,64 @@ +""" +Tests of 'tag' extension. +""" + +import numpy as np + +import concrete.numpy as cnp + + +def test_tag(helpers): + """ + Test tag extension. + """ + + def g(z): + with cnp.tag("def"): + a = 120 - z + b = a // 4 + return b + + @cnp.compiler({"x": "encrypted"}) + def f(x): + with cnp.tag("abc"): + x = x * 2 + with cnp.tag("foo"): + y = x + 42 + z = np.sqrt(y).astype(np.int64) + + return g(z + 3) * 2 + + inputset = range(10) + circuit = f.trace(inputset, configuration=helpers.configuration()) + + helpers.check_str( + """ + + %0 = x # EncryptedScalar + %1 = 2 # ClearScalar @ abc + %2 = multiply(%0, %1) # EncryptedScalar @ abc + %3 = 42 # ClearScalar @ abc.foo + %4 = add(%2, %3) # EncryptedScalar @ abc.foo + %5 = subgraph(%4) # EncryptedScalar @ abc + %6 = 3 # ClearScalar + %7 = add(%5, %6) # EncryptedScalar + %8 = 120 # ClearScalar @ def + %9 = subtract(%8, %7) # EncryptedScalar @ def +%10 = 4 # ClearScalar @ def +%11 = floor_divide(%9, %10) # EncryptedScalar @ def +%12 = 2 # ClearScalar +%13 = multiply(%11, %12) # EncryptedScalar +return %13 + +Subgraphs: + + %5 = subgraph(%4): + + %0 = input # EncryptedScalar @ abc.foo + %1 = sqrt(%0) # EncryptedScalar @ abc + %2 = astype(%1, dtype=int_) # EncryptedScalar @ abc + return %2 + + """.strip(), + circuit.format(show_bounds=False), + ) diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index 7bce4cc67..bfcfa6ac6 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -26,18 +26,18 @@ def assign(x): pytest.param( lambda x, y: (x - y, x + y), {"x": "encrypted", "y": "clear"}, - [(np.random.randint(0, 2**3), np.random.randint(0, 2**3)) for _ in range(100)], + [(0, 0), (7, 7), (0, 7), (7, 0)], RuntimeError, """ Function you are trying to compile cannot be converted to MLIR -%0 = x # EncryptedScalar -%1 = y # ClearScalar -%2 = subtract(%0, %1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only a single output is supported -%3 = add(%0, %1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only a single output is supported +%0 = x # EncryptedScalar ∈ [0, 7] +%1 = y # ClearScalar ∈ [0, 7] +%2 = subtract(%0, %1) # EncryptedScalar ∈ [-7, 7] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only a single output is supported +%3 = add(%0, %1) # EncryptedScalar ∈ [0, 14] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only a single output is supported return (%2, %3) """, # noqa: E501 @@ -51,8 +51,8 @@ return (%2, %3) Function you are trying to compile cannot be converted to MLIR -%0 = x # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted signed integer inputs are supported +%0 = x # ClearScalar ∈ [-10, 9] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted signed integer inputs are supported return %0 """, # noqa: E501 @@ -66,12 +66,12 @@ return %0 Function you are trying to compile cannot be converted to MLIR -%0 = x # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer inputs are supported -%1 = 1.5 # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported -%2 = multiply(%0, %1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported +%0 = x # EncryptedScalar ∈ [0.0, 247.5] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer inputs are supported +%1 = 1.5 # ClearScalar ∈ [1.5, 1.5] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported +%2 = multiply(%0, %1) # EncryptedScalar ∈ [0.0, 371.25] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported return %2 """, # noqa: E501 @@ -85,9 +85,9 @@ return %2 Function you are trying to compile cannot be converted to MLIR -%0 = x # EncryptedScalar -%1 = sin(%0) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported +%0 = x # EncryptedScalar ∈ [0, 99] +%1 = sin(%0) # EncryptedScalar ∈ [-0.9999902065507035, 0.9999118601072672] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported return %1 """, # noqa: E501 @@ -107,10 +107,10 @@ return %1 Function you are trying to compile cannot be converted to MLIR -%0 = x # EncryptedTensor -%1 = y # ClearTensor -%2 = concatenate((%0, %1)) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only all encrypted concatenate is supported +%0 = x # EncryptedTensor ∈ [0, 7] +%1 = y # ClearTensor ∈ [0, 7] +%2 = concatenate((%0, %1)) # EncryptedTensor ∈ [0, 7] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only all encrypted concatenate is supported return %2 """, # noqa: E501 @@ -130,10 +130,10 @@ return %2 Function you are trying to compile cannot be converted to MLIR -%0 = x # EncryptedTensor -%1 = w # EncryptedTensor -%2 = conv1d(%0, %1, [0], pads=(0, 0), strides=(1,), dilations=(1,), group=1) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv1d with encrypted input and clear weight is supported +%0 = x # EncryptedTensor ∈ [0, 1] +%1 = w # EncryptedTensor ∈ [0, 1] +%2 = conv1d(%0, %1, [0], pads=(0, 0), strides=(1,), dilations=(1,), group=1) # EncryptedTensor ∈ [0, 1] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv1d with encrypted input and clear weight is supported return %2 """, # noqa: E501 @@ -153,10 +153,10 @@ return %2 Function you are trying to compile cannot be converted to MLIR -%0 = x # EncryptedTensor -%1 = w # EncryptedTensor -%2 = conv2d(%0, %1, [0], pads=(0, 0, 0, 0), strides=(1, 1), dilations=(1, 1), group=1) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv2d with encrypted input and clear weight is supported +%0 = x # EncryptedTensor ∈ [0, 1] +%1 = w # EncryptedTensor ∈ [0, 1] +%2 = conv2d(%0, %1, [0], pads=(0, 0, 0, 0), strides=(1, 1), dilations=(1, 1), group=1) # EncryptedTensor ∈ [0, 1] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv2d with encrypted input and clear weight is supported return %2 """, # noqa: E501 @@ -176,10 +176,10 @@ return %2 Function you are trying to compile cannot be converted to MLIR -%0 = x # EncryptedTensor -%1 = w # EncryptedTensor -%2 = conv3d(%0, %1, [0], pads=(0, 0, 0, 0, 0, 0), strides=(1, 1, 1), dilations=(1, 1, 1), group=1) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv3d with encrypted input and clear weight is supported +%0 = x # EncryptedTensor ∈ [0, 1] +%1 = w # EncryptedTensor ∈ [0, 1] +%2 = conv3d(%0, %1, [0], pads=(0, 0, 0, 0, 0, 0), strides=(1, 1, 1), dilations=(1, 1, 1), group=1) # EncryptedTensor ∈ [0, 1] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only conv3d with encrypted input and clear weight is supported return %2 """, # noqa: E501 @@ -187,22 +187,16 @@ return %2 pytest.param( lambda x, y: np.dot(x, y), {"x": "encrypted", "y": "encrypted"}, - [ - ( - np.random.randint(0, 2**2, size=(1,)), - np.random.randint(0, 2**2, size=(1,)), - ) - for _ in range(100) - ], + [([0], [0]), ([3], [3]), ([3], [0]), ([0], [3]), ([1], [1])], RuntimeError, """ Function you are trying to compile cannot be converted to MLIR -%0 = x # EncryptedTensor -%1 = y # EncryptedTensor -%2 = dot(%0, %1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only dot product between encrypted and clear is supported +%0 = x # EncryptedTensor ∈ [0, 3] +%1 = y # EncryptedTensor ∈ [0, 3] +%2 = dot(%0, %1) # EncryptedScalar ∈ [0, 9] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only dot product between encrypted and clear is supported return %2 """, # noqa: E501 @@ -210,15 +204,15 @@ return %2 pytest.param( lambda x: x[0], {"x": "clear"}, - [np.random.randint(0, 2**3, size=(4,)) for _ in range(100)], + [[0, 1, 2, 3], [7, 6, 5, 4]], RuntimeError, """ Function you are trying to compile cannot be converted to MLIR -%0 = x # ClearTensor -%1 = %0[0] # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted indexing supported +%0 = x # ClearTensor ∈ [0, 7] +%1 = %0[0] # ClearScalar ∈ [0, 7] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted indexing supported return %1 """, # noqa: E501 @@ -228,8 +222,8 @@ return %1 {"x": "encrypted", "y": "encrypted"}, [ ( - np.random.randint(0, 2**2, size=(1, 1)), - np.random.randint(0, 2**2, size=(1, 1)), + np.random.randint(0, 2**1, size=(1, 1)), + np.random.randint(0, 2**1, size=(1, 1)), ) for _ in range(100) ], @@ -238,10 +232,10 @@ return %1 Function you are trying to compile cannot be converted to MLIR -%0 = x # EncryptedTensor -%1 = y # EncryptedTensor -%2 = matmul(%0, %1) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only matrix multiplication between encrypted and clear is supported +%0 = x # EncryptedTensor ∈ [0, 1] +%1 = y # EncryptedTensor ∈ [0, 1] +%2 = matmul(%0, %1) # EncryptedTensor ∈ [0, 1] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only matrix multiplication between encrypted and clear is supported return %2 """, # noqa: E501 @@ -249,16 +243,16 @@ return %2 pytest.param( lambda x, y: x * y, {"x": "encrypted", "y": "encrypted"}, - [(np.random.randint(0, 2**3), np.random.randint(0, 2**3)) for _ in range(100)], + [(0, 0), (7, 7), (0, 7), (7, 0)], RuntimeError, """ Function you are trying to compile cannot be converted to MLIR -%0 = x # EncryptedScalar -%1 = y # EncryptedScalar -%2 = multiply(%0, %1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only multiplication between encrypted and clear is supported +%0 = x # EncryptedScalar ∈ [0, 7] +%1 = y # EncryptedScalar ∈ [0, 7] +%2 = multiply(%0, %1) # EncryptedScalar ∈ [0, 49] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only multiplication between encrypted and clear is supported return %2 """, # noqa: E501 @@ -266,15 +260,15 @@ return %2 pytest.param( lambda x: -x, {"x": "clear"}, - [np.random.randint(0, 2**3) for _ in range(100)], + [0, 7], RuntimeError, """ Function you are trying to compile cannot be converted to MLIR -%0 = x # ClearScalar -%1 = negative(%0) # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted negation is supported +%0 = x # ClearScalar ∈ [0, 7] +%1 = negative(%0) # ClearScalar ∈ [-7, 0] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted negation is supported return %1 """, # noqa: E501 @@ -288,9 +282,9 @@ return %1 Function you are trying to compile cannot be converted to MLIR -%0 = x # ClearTensor -%1 = reshape(%0, newshape=(3, 2)) # ClearTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted reshape is supported +%0 = x # ClearTensor ∈ [0, 7] +%1 = reshape(%0, newshape=(3, 2)) # ClearTensor ∈ [0, 7] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted reshape is supported return %1 """, # noqa: E501 @@ -304,9 +298,9 @@ return %1 Function you are trying to compile cannot be converted to MLIR -%0 = x # ClearTensor -%1 = sum(%0) # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted sum is supported +%0 = x # ClearTensor ∈ [0, 1] +%1 = sum(%0) # ClearScalar ∈ [0, 1] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted sum is supported return %1 """, # noqa: E501 @@ -314,16 +308,16 @@ return %1 pytest.param( lambda x: np.maximum(x, np.array([3])), {"x": "clear"}, - [np.random.randint(0, 2, size=(1,)) for _ in range(100)], + [[0], [1]], RuntimeError, """ Function you are trying to compile cannot be converted to MLIR -%0 = x # ClearTensor -%1 = [3] # ClearTensor -%2 = maximum(%0, %1) # ClearTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of the operands must be encrypted +%0 = x # ClearTensor ∈ [0, 1] +%1 = [3] # ClearTensor ∈ [3, 3] +%2 = maximum(%0, %1) # ClearTensor ∈ [3, 3] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of the operands must be encrypted return %2 """, # noqa: E501 @@ -331,15 +325,15 @@ return %2 pytest.param( lambda x: np.transpose(x), {"x": "clear"}, - [np.random.randint(0, 2, size=(3, 2)) for _ in range(100)], + [np.random.randint(0, 2, size=(3, 2)) for _ in range(10)], RuntimeError, """ Function you are trying to compile cannot be converted to MLIR -%0 = x # ClearTensor -%1 = transpose(%0) # ClearTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted transpose is supported +%0 = x # ClearTensor ∈ [0, 1] +%1 = transpose(%0) # ClearTensor ∈ [0, 1] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted transpose is supported return %1 """, # noqa: E501 @@ -347,15 +341,15 @@ return %1 pytest.param( lambda x: np.broadcast_to(x, shape=(3, 2)), {"x": "clear"}, - [np.random.randint(0, 2, size=(2,)) for _ in range(100)], + [np.random.randint(0, 2, size=(2,)) for _ in range(10)], RuntimeError, """ Function you are trying to compile cannot be converted to MLIR -%0 = x # ClearTensor -%1 = broadcast_to(%0, shape=(3, 2)) # ClearTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted broadcasting is supported +%0 = x # ClearTensor ∈ [0, 1] +%1 = broadcast_to(%0, shape=(3, 2)) # ClearTensor ∈ [0, 1] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted broadcasting is supported return %1 """, # noqa: E501 @@ -363,19 +357,18 @@ return %1 pytest.param( assign, {"x": "clear"}, - [np.random.randint(0, 2, size=(3,)) for _ in range(100)], + [np.random.randint(0, 2, size=(3,)) for _ in range(10)], RuntimeError, """ Function you are trying to compile cannot be converted to MLIR -%0 = x # ClearTensor -%1 = 0 # ClearScalar -%2 = (%0[0] = %1) # ClearTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only assignment to encrypted tensors are supported +%0 = x # ClearTensor ∈ [0, 1] +%1 = 0 # ClearScalar ∈ [0, 0] +%2 = (%0[0] = %1) # ClearTensor ∈ [0, 1] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only assignment to encrypted tensors are supported return %2 - """, # noqa: E501 ), pytest.param( @@ -387,21 +380,21 @@ return %2 Function you are trying to compile cannot be converted to MLIR: -%0 = x # EncryptedScalar -%1 = 300 # ClearScalar -%2 = add(%0, %1) # EncryptedScalar -%3 = subgraph(%2) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 16-bit integers +%0 = x # EncryptedScalar ∈ [200000, 200000] +%1 = 300 # ClearScalar ∈ [300, 300] +%2 = add(%0, %1) # EncryptedScalar ∈ [200300, 200300] +%3 = subgraph(%2) # EncryptedScalar ∈ [9, 9] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 16-bit integers return %3 Subgraphs: %3 = subgraph(%2): - %0 = 10 # ClearScalar - %1 = input # EncryptedScalar - %2 = sin(%1) # EncryptedScalar - %3 = multiply(%0, %2) # EncryptedScalar + %0 = input # EncryptedScalar + %1 = sin(%0) # EncryptedScalar + %2 = 10 # ClearScalar + %3 = multiply(%2, %1) # EncryptedScalar %4 = absolute(%3) # EncryptedScalar %5 = astype(%4, dtype=int_) # EncryptedScalar return %5 @@ -417,21 +410,21 @@ Subgraphs: Function you are trying to compile cannot be converted to MLIR: -%0 = x # EncryptedScalar -%1 = 300 # ClearScalar -%2 = add(%0, %1) # EncryptedScalar -%3 = subgraph(%2) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integers are only supported up to 8-bits on circuits with table lookups +%0 = x # EncryptedScalar ∈ [1024, 2047] +%1 = 300 # ClearScalar ∈ [300, 300] +%2 = add(%0, %1) # EncryptedScalar ∈ [1324, 2347] +%3 = subgraph(%2) # EncryptedScalar ∈ [-9, 9] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integers are only supported up to 8-bits on circuits with table lookups return %3 Subgraphs: %3 = subgraph(%2): - %0 = 10 # ClearScalar - %1 = input # EncryptedScalar - %2 = sin(%1) # EncryptedScalar - %3 = multiply(%0, %2) # EncryptedScalar + %0 = input # EncryptedScalar + %1 = sin(%0) # EncryptedScalar + %2 = 10 # ClearScalar + %3 = multiply(%2, %1) # EncryptedScalar %4 = astype(%3, dtype=int_) # EncryptedScalar return %4 diff --git a/tests/representation/test_graph.py b/tests/representation/test_graph.py index 79aaf85e4..01ba84938 100644 --- a/tests/representation/test_graph.py +++ b/tests/representation/test_graph.py @@ -2,37 +2,175 @@ Tests of `Graph` class. """ +import os +import re + +import numpy as np import pytest import concrete.numpy as cnp +import tests + +tests_directory = os.path.dirname(tests.__file__) + + +def g(z): + """ + Example function with a tag. + """ + + with cnp.tag("def"): + a = 120 - z + b = a // 4 + return b + + +def f(x): + """ + Example function with nested tags. + """ + + with cnp.tag("abc"): + x = x * 2 + with cnp.tag("foo"): + y = x + 42 + z = np.sqrt(y).astype(np.int64) + + return g(z + 3) * 2 @pytest.mark.parametrize( - "function,inputset,expected_result", + "function,inputset,tag_filter,operation_filter,expected_result", [ pytest.param( lambda x: x + 1, range(5), + None, + None, 3, ), pytest.param( lambda x: x + 42, range(10), + None, + None, 6, ), pytest.param( lambda x: x + 42, range(50), + None, + None, 7, ), pytest.param( lambda x: x + 1.2, [1.5, 4.2], + None, + None, + -1, + ), + pytest.param( + f, + range(10), + None, + None, + 7, + ), + pytest.param( + f, + range(10), + "", + None, + 6, + ), + pytest.param( + f, + range(10), + "abc", + None, + 5, + ), + pytest.param( + f, + range(10), + ["abc", "def"], + None, + 7, + ), + pytest.param( + f, + range(10), + re.compile(".*b.*"), + None, + 6, + ), + pytest.param( + f, + range(10), + None, + "input", + 4, + ), + pytest.param( + f, + range(10), + None, + "constant", + 7, + ), + pytest.param( + f, + range(10), + None, + "subgraph", + 3, + ), + pytest.param( + f, + range(10), + None, + "add", + 6, + ), + pytest.param( + f, + range(10), + None, + ["subgraph", "add"], + 6, + ), + pytest.param( + f, + range(10), + None, + re.compile("sub.*"), + 7, + ), + pytest.param( + f, + range(10), + "abc.foo", + "add", + 6, + ), + pytest.param( + f, + range(10), + "abc", + "floor_divide", -1, ), ], ) -def test_graph_maximum_integer_bit_width(function, inputset, expected_result, helpers): +def test_graph_maximum_integer_bit_width( + function, + inputset, + tag_filter, + operation_filter, + expected_result, + helpers, +): """ Test `maximum_integer_bit_width` method of `Graph` class. """ @@ -42,6 +180,192 @@ def test_graph_maximum_integer_bit_width(function, inputset, expected_result, he compiler = cnp.Compiler(function, {"x": "encrypted"}) graph = compiler.trace(inputset, configuration) - print(graph.format()) + assert graph.maximum_integer_bit_width(tag_filter, operation_filter) == expected_result - assert graph.maximum_integer_bit_width() == expected_result + +@pytest.mark.parametrize( + "function,inputset,tag_filter,operation_filter,expected_result", + [ + pytest.param( + lambda x: x + 42, + range(-10, 10), + None, + None, + (-10, 51), + ), + pytest.param( + lambda x: x + 1.2, + [1.5, 4.2], + None, + None, + None, + ), + pytest.param( + f, + range(10), + None, + None, + (0, 120), + ), + pytest.param( + f, + range(10), + "", + None, + (0, 54), + ), + pytest.param( + f, + range(10), + "abc", + None, + (0, 18), + ), + pytest.param( + f, + range(10), + ["abc", "def"], + None, + (0, 120), + ), + pytest.param( + f, + range(10), + re.compile(".*b.*"), + None, + (0, 60), + ), + pytest.param( + f, + range(10), + None, + "input", + (0, 9), + ), + pytest.param( + f, + range(10), + None, + "constant", + (2, 120), + ), + pytest.param( + f, + range(10), + None, + "subgraph", + (6, 7), + ), + pytest.param( + f, + range(10), + None, + "add", + (9, 60), + ), + pytest.param( + f, + range(10), + None, + ["subgraph", "add"], + (6, 60), + ), + pytest.param( + f, + range(10), + None, + re.compile("sub.*"), + (6, 111), + ), + pytest.param( + f, + range(10), + "abc.foo", + "add", + (42, 60), + ), + pytest.param( + f, + range(10), + "abc", + "floor_divide", + None, + ), + ], +) +def test_graph_integer_range( + function, + inputset, + tag_filter, + operation_filter, + expected_result, + helpers, +): + """ + Test `integer_range` method of `Graph` class. + """ + + configuration = helpers.configuration() + + compiler = cnp.Compiler(function, {"x": "encrypted"}) + graph = compiler.trace(inputset, configuration) + + assert graph.integer_range(tag_filter, operation_filter) == expected_result + + +def test_graph_format_show_lines(helpers): + """ + Test `format` method of `Graph` class with show_lines=True. + """ + + configuration = helpers.configuration() + + compiler = cnp.Compiler(f, {"x": "encrypted"}) + graph = compiler.trace(range(10), configuration) + + # pylint: disable=line-too-long + expected = f""" + + %0 = x # EncryptedScalar ∈ [0, 9] {tests_directory}/representation/test_graph.py:324 + %1 = 2 # ClearScalar ∈ [2, 2] @ abc {tests_directory}/representation/test_graph.py:34 + %2 = multiply(%0, %1) # EncryptedScalar ∈ [0, 18] @ abc {tests_directory}/representation/test_graph.py:34 + %3 = 42 # ClearScalar ∈ [42, 42] @ abc.foo {tests_directory}/representation/test_graph.py:36 + %4 = add(%2, %3) # EncryptedScalar ∈ [42, 60] @ abc.foo {tests_directory}/representation/test_graph.py:36 + %5 = subgraph(%4) # EncryptedScalar ∈ [6, 7] @ abc {tests_directory}/representation/test_graph.py:37 + %6 = 3 # ClearScalar ∈ [3, 3] {tests_directory}/representation/test_graph.py:39 + %7 = add(%5, %6) # EncryptedScalar ∈ [9, 10] {tests_directory}/representation/test_graph.py:39 + %8 = 120 # ClearScalar ∈ [120, 120] @ def {tests_directory}/representation/test_graph.py:23 + %9 = subtract(%8, %7) # EncryptedScalar ∈ [110, 111] @ def {tests_directory}/representation/test_graph.py:23 +%10 = 4 # ClearScalar ∈ [4, 4] @ def {tests_directory}/representation/test_graph.py:24 +%11 = floor_divide(%9, %10) # EncryptedScalar ∈ [27, 27] @ def {tests_directory}/representation/test_graph.py:24 +%12 = 2 # ClearScalar ∈ [2, 2] {tests_directory}/representation/test_graph.py:39 +%13 = multiply(%11, %12) # EncryptedScalar ∈ [54, 54] {tests_directory}/representation/test_graph.py:39 +return %13 + +Subgraphs: + + %5 = subgraph(%4): + + %0 = input # EncryptedScalar @ abc.foo {tests_directory}/representation/test_graph.py:36 + %1 = sqrt(%0) # EncryptedScalar @ abc {tests_directory}/representation/test_graph.py:37 + %2 = astype(%1, dtype=int_) # EncryptedScalar @ abc {tests_directory}/representation/test_graph.py:37 + return %2 + + """ # noqa: E501 + # pylint: enable=line-too-long + + actual = graph.format(show_locations=True) + + assert ( + actual.strip() == expected.strip() + ), f""" + +Expected Output +=============== +{expected} + +Actual Output +============= +{actual} + + """