From 6fec590e651a35966f7409dce1657dbd35836ca8 Mon Sep 17 00:00:00 2001 From: Umut Date: Wed, 10 Nov 2021 16:49:17 +0300 Subject: [PATCH] refactor(debugging): re-write graph formatting --- concrete/common/compilation/artifacts.py | 2 +- concrete/common/data_types/floats.py | 3 + concrete/common/data_types/integers.py | 3 + concrete/common/debugging/__init__.py | 2 +- concrete/common/debugging/drawing.py | 2 +- concrete/common/debugging/formatting.py | 124 +++++ concrete/common/debugging/printing.py | 146 ------ concrete/common/extensions/table.py | 3 + concrete/common/fhe_circuit.py | 2 +- concrete/common/helpers/formatting_helpers.py | 47 ++ concrete/common/mlir/utils.py | 10 +- concrete/common/operator_graph.py | 18 + concrete/common/optimization/topological.py | 3 +- .../common/representation/intermediate.py | 118 +++-- concrete/numpy/compile.py | 4 +- concrete/numpy/tracing.py | 21 +- .../bounds_measurement/test_inputset_eval.py | 32 +- tests/common/debugging/test_formatting.py | 78 +++ tests/common/debugging/test_printing.py | 94 ---- .../common/optimization/test_float_fusing.py | 136 +++--- tests/common/test_fhe_circuit.py | 2 +- tests/numpy/test_compile.py | 271 ++++++----- tests/numpy/test_compile_constant_indexing.py | 3 +- tests/numpy/test_debugging.py | 449 ------------------ tests/numpy/test_tracing.py | 64 ++- 25 files changed, 653 insertions(+), 984 deletions(-) create mode 100644 concrete/common/debugging/formatting.py delete mode 100644 concrete/common/debugging/printing.py create mode 100644 concrete/common/helpers/formatting_helpers.py create mode 100644 tests/common/debugging/test_formatting.py delete mode 100644 tests/common/debugging/test_printing.py delete mode 100644 tests/numpy/test_debugging.py diff --git a/concrete/common/compilation/artifacts.py b/concrete/common/compilation/artifacts.py index 35e7a8d4c..b3b04b06d 100644 --- a/concrete/common/compilation/artifacts.py +++ b/concrete/common/compilation/artifacts.py @@ -85,7 +85,7 @@ class CompilationArtifacts: """ drawing = draw_graph(operation_graph) - textual_representation = format_operation_graph(operation_graph, show_data_types=True) + textual_representation = format_operation_graph(operation_graph) self.drawings_of_operation_graphs[name] = drawing self.textual_representations_of_operation_graphs[name] = textual_representation diff --git a/concrete/common/data_types/floats.py b/concrete/common/data_types/floats.py index a537d28bb..1e574f31f 100644 --- a/concrete/common/data_types/floats.py +++ b/concrete/common/data_types/floats.py @@ -21,6 +21,9 @@ class Float(base.BaseDataType): def __repr__(self) -> str: return f"{self.__class__.__name__}<{self.bit_width} bits>" + def __str__(self) -> str: + return f"float{self.bit_width}" + def __eq__(self, other: object) -> bool: return isinstance(other, self.__class__) and self.bit_width == other.bit_width diff --git a/concrete/common/data_types/integers.py b/concrete/common/data_types/integers.py index ed9654972..29e33bc70 100644 --- a/concrete/common/data_types/integers.py +++ b/concrete/common/data_types/integers.py @@ -23,6 +23,9 @@ class Integer(base.BaseDataType): signed_str = "signed" if self.is_signed else "unsigned" return f"{self.__class__.__name__}<{signed_str}, {self.bit_width} bits>" + def __str__(self) -> str: + return f"{('int' if self.is_signed else 'uint')}{self.bit_width}" + def __eq__(self, other: object) -> bool: return ( isinstance(other, self.__class__) diff --git a/concrete/common/debugging/__init__.py b/concrete/common/debugging/__init__.py index 2532226ab..2c3d2fcac 100644 --- a/concrete/common/debugging/__init__.py +++ b/concrete/common/debugging/__init__.py @@ -1,4 +1,4 @@ """Module for debugging.""" from .custom_assert import assert_true from .drawing import draw_graph -from .printing import format_operation_graph +from .formatting import format_operation_graph diff --git a/concrete/common/debugging/drawing.py b/concrete/common/debugging/drawing.py index a40db6e6b..b7d46c20e 100644 --- a/concrete/common/debugging/drawing.py +++ b/concrete/common/debugging/drawing.py @@ -84,7 +84,7 @@ def draw_graph( attributes = { node: { - "label": node.label(), + "label": node.text_for_drawing(), "color": get_color(node, output_nodes), "penwidth": 2, # double thickness for circles "peripheries": 2 if node in output_nodes else 1, # double circle for output nodes diff --git a/concrete/common/debugging/formatting.py b/concrete/common/debugging/formatting.py new file mode 100644 index 000000000..6e1ddcc92 --- /dev/null +++ b/concrete/common/debugging/formatting.py @@ -0,0 +1,124 @@ +"""Functions to format operation graphs for debugging purposes.""" + +from typing import Dict, List, Optional, Tuple + +import networkx as nx + +from ..debugging.custom_assert import assert_true +from ..operator_graph import OPGraph +from ..representation.intermediate import IntermediateNode + + +def format_operation_graph( + opgraph: OPGraph, + maximum_constant_length: int = 25, + highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None, +) -> str: + """Format an operation graph. + + Args: + opgraph (OPGraph): + the operation graph to format + + maximum_constant_length (int): + maximum length of the constant throughout the formatting + + highlighted_nodes (Optional[Dict[IntermediateNode, List[str]]] = None): + the dict of nodes and their corresponding messages which will be highlighted + + Returns: + str: formatted operation graph + """ + assert_true(isinstance(opgraph, OPGraph)) + + # (node, output_index) -> identifier + # e.g., id_map[(node1, 0)] = 2 and id_map[(node1, 1)] = 3 + # means line for node1 is in this form (%2, %3) = node1.format(...) + id_map: Dict[Tuple[IntermediateNode, int], int] = {} + + # lines that will be merged at the end + lines: List[str] = [] + + # type information to add to each line (for alingment, this is done after lines are determined) + type_informations: List[str] = [] + + # default highlighted nodes is empty + highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {} + + # highlight information for lines, this is required because highlights are added to lines + # after their type information is added and we only have line numbers, not nodes + highlighted_lines: Dict[int, List[str]] = {} + + for node in nx.topological_sort(opgraph.graph): + # assign a unique id to outputs of node + assert_true(len(node.outputs) > 0) + for i in range(len(node.outputs)): + id_map[(node, i)] = len(id_map) + + # remember highlights of the node + if node in highlighted_nodes: + highlighted_lines[len(lines)] = highlighted_nodes[node] + + # extract predecessors and their ids + predecessors = [] + for predecessor, output_idx in opgraph.get_ordered_inputs_of(node): + predecessors.append(f"%{id_map[(predecessor, output_idx)]}") + + # start the build the line for the node + line = "" + + # add output information to the line + outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs))) + line += outputs if len(node.outputs) == 1 else f"({outputs})" + + # add node information to the line + line += " = " + line += node.text_for_formatting(predecessors, maximum_constant_length) + + # append line to list of lines + lines.append(line) + + # remember type information of the node + types = ", ".join(str(output) for output in node.outputs) + type_informations.append(types if len(node.outputs) == 1 else f"({types})") + + # align = signs + # + # e.g., + # + # %1 = ... + # %2 = ... + # ... + # %8 = ... + # %9 = ... + # %10 = ... + # %11 = ... + # ... + longest_length_before_equals_sign = max(len(line.split("=")[0]) for line in lines) + for i, line in enumerate(lines): + length_before_equals_sign = len(line.split("=")[0]) + lines[i] = (" " * (longest_length_before_equals_sign - length_before_equals_sign)) + line + + # add type informations + longest_line_length = max(len(line) for line in lines) + for i, line in enumerate(lines): + lines[i] += " " * (longest_line_length - len(line)) + lines[i] += f" # {type_informations[i]}" + + # add highlights (this is done in reverse to keep indices consistent) + for i in reversed(range(len(lines))): + if i in highlighted_lines: + for j, message in enumerate(highlighted_lines[i]): + highlight = "^" if j == 0 else " " + lines.insert(i + 1 + j, f"{highlight * len(lines[i])} {message}") + + # add return information + # (if there is a single return, it's in the form `return %id` + # (otherwise, it's in the form `return (%id1, %id2, ..., %idN)` + returns: List[str] = [] + for node in opgraph.output_nodes.values(): + outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs))) + returns.append(outputs if len(node.outputs) == 1 else f"({outputs})") + lines.append("return " + returns[0] if len(returns) == 1 else f"({', '.join(returns)})") + + return "\n".join(lines) diff --git a/concrete/common/debugging/printing.py b/concrete/common/debugging/printing.py deleted file mode 100644 index 795879d1e..000000000 --- a/concrete/common/debugging/printing.py +++ /dev/null @@ -1,146 +0,0 @@ -"""functions to print the different graphs we can generate in the package, eg to debug.""" - -from typing import Any, Dict, List, Optional - -import networkx as nx - -from ..debugging.custom_assert import assert_true -from ..operator_graph import OPGraph -from ..representation.intermediate import ( - Constant, - GenericFunction, - IndexConstant, - Input, - IntermediateNode, -) - - -def output_data_type_to_string(node): - """Return the datatypes of the outputs of the node. - - Args: - node: a graph node - - Returns: - str: a string representing the datatypes of the outputs of the node - - """ - return ", ".join([str(o) for o in node.outputs]) - - -def shorten_a_constant(constant_data: str): - """Return a constant (if small) or an extra of the constant (if too large). - - Args: - constant (str): The constant we want to shorten - - Returns: - str: a string to represent the constant - """ - - content = str(constant_data).replace("\n", "") - # if content is longer than 25 chars, only show the first and the last 10 chars of it - # 25 is selected using the spaces available before data type information - short_content = f"{content[:10]} ... {content[-10:]}" if len(content) > 25 else content - return short_content - - -def format_operation_graph( - opgraph: OPGraph, - show_data_types: bool = False, - highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None, -) -> str: - """Return a string representing a graph. - - Args: - opgraph (OPGraph): The graph that we want to draw - show_data_types (bool, optional): Whether or not showing data_types of nodes, eg to see - their width. Defaults to False. - highlighted_nodes (Optional[Dict[IntermediateNode, List[str]]], optional): The dict of nodes - which will be highlighted and their corresponding messages. Defaults to None. - - Returns: - str: a string to print or save in a file - """ - assert_true(isinstance(opgraph, OPGraph)) - - highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {} - - list_of_nodes_which_are_outputs = list(opgraph.output_nodes.values()) - graph = opgraph.graph - - returned_str = "" - - i = 0 - map_table: Dict[Any, int] = {} - - for node in nx.topological_sort(graph): - - # TODO: #640 - # This code doesn't work with more than a single output. For more outputs, - # we would need to change the way the destination are created: currently, - # they only are done by incrementing i - assert_true(len(node.outputs) == 1) - - if isinstance(node, Input): - what_to_print = node.input_name - elif isinstance(node, Constant): - to_show = shorten_a_constant(node.constant_data) - what_to_print = f"Constant({to_show})" - else: - - base_name = node.__class__.__name__ - - if isinstance(node, GenericFunction): - base_name = node.op_name - - what_to_print = base_name + "(" - - # Find all the names of the current predecessors of the node - list_of_arg_name = [] - - for pred, index_list in graph.pred[node].items(): - for index in index_list.values(): - # Remark that we keep the index of the predecessor and its - # name, to print sources in the right order, which is - # important for eg non commutative operations - list_of_arg_name += [(index["input_idx"], str(map_table[pred]))] - - # Some checks, because the previous algorithm is not clear - assert_true(len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name))) - list_of_arg_name.sort() - assert_true([x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name)))) - - prefix_to_add_to_what_to_print = "" - suffix_to_add_to_what_to_print = "" - - # Then, just print the predecessors in the right order - what_to_print += prefix_to_add_to_what_to_print - what_to_print += ", ".join(["%" + x[1] for x in list_of_arg_name]) - what_to_print += suffix_to_add_to_what_to_print - what_to_print += ( - f"{node.label().replace('value', '')}" if isinstance(node, IndexConstant) else "" - ) - what_to_print += ")" - - # This code doesn't work with more than a single output - new_line = f"%{i} = {what_to_print}" - - # Manage datatypes - if show_data_types: - new_line = f"{new_line: <50s} # {output_data_type_to_string(node)}" - - returned_str += f"{new_line}\n" - - if node in highlighted_nodes: - new_line_len = len(new_line) - message = f"\n{' ' * new_line_len} ".join(highlighted_nodes[node]) - returned_str += f"{'^' * new_line_len} {message}\n" - - map_table[node] = i - i += 1 - - return_part = ", ".join(["%" + str(map_table[n]) for n in list_of_nodes_which_are_outputs]) - returned_str += f"return({return_part})\n" - - return returned_str diff --git a/concrete/common/extensions/table.py b/concrete/common/extensions/table.py index 0380c188e..e87fb2a05 100644 --- a/concrete/common/extensions/table.py +++ b/concrete/common/extensions/table.py @@ -30,6 +30,9 @@ class LookupTable: self.table = table self.output_dtype = make_integer_to_hold(table, force_signed=False) + def __repr__(self): + return str(list(self.table)) + def __getitem__(self, key: Union[int, Iterable, BaseTracer]): # if a tracer is used for indexing, # we need to create an `GenericFunction` node diff --git a/concrete/common/fhe_circuit.py b/concrete/common/fhe_circuit.py index fbfd20353..b96be2f24 100644 --- a/concrete/common/fhe_circuit.py +++ b/concrete/common/fhe_circuit.py @@ -21,7 +21,7 @@ class FHECircuit: self.engine = engine def __str__(self): - return format_operation_graph(self.opgraph, show_data_types=True) + return format_operation_graph(self.opgraph) def draw( self, diff --git a/concrete/common/helpers/formatting_helpers.py b/concrete/common/helpers/formatting_helpers.py new file mode 100644 index 000000000..6121e444b --- /dev/null +++ b/concrete/common/helpers/formatting_helpers.py @@ -0,0 +1,47 @@ +"""Helpers for formatting functionality.""" + +from typing import Any, Dict, Hashable + +import numpy + +from ..debugging.custom_assert import assert_true + +SPECIAL_OBJECT_MAPPING: Dict[Any, str] = { + numpy.float32: "float32", + numpy.float64: "float64", + numpy.int8: "int8", + numpy.int16: "int16", + numpy.int32: "int32", + numpy.int64: "int64", + numpy.uint8: "uint8", + numpy.uint16: "uint16", + numpy.uint32: "uint32", + numpy.uint64: "uint64", +} + + +def format_constant(constant: Any, maximum_length: int = 45) -> str: + """Format a constant. + + Args: + constant (Any): the constant to format + maximum_length (int): maximum length of the resulting string + + Returns: + str: the formatted constant + """ + + if isinstance(constant, Hashable) and constant in SPECIAL_OBJECT_MAPPING: + return SPECIAL_OBJECT_MAPPING[constant] + + # maximum_length should not be smaller than 7 characters because + # the constant will be formatted to `x ... y` + # where x and y are part of the constant and they are at least 1 character + assert_true(maximum_length >= 7) + + content = str(constant).replace("\n", "") + if len(content) > maximum_length: + from_start = (maximum_length - 5) // 2 + from_end = (maximum_length - 5) - from_start + content = f"{content[:from_start]} ... {content[-from_end:]}" + return content diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index 4ddb6db09..c464d9d43 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -96,9 +96,7 @@ def check_node_compatibility_with_mlir( # e.g., `np.absolute is not supported for the time being` return f"{node.op_name} is not supported for the time being" else: - return ( - f"{node.op_name} of kind {node.op_kind.value} is not supported for the time being" - ) + return f"{node.op_name} is not supported for the time being" elif isinstance(node, intermediate.Dot): # constraints for dot product assert_true(len(inputs) == 2) @@ -206,10 +204,8 @@ def update_bit_width_for_mlir(op_graph: OPGraph): raise RuntimeError( f"max_bit_width of some nodes is too high for the current version of " f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB}) " - f"which is not compatible with:\n" - + format_operation_graph( - op_graph, show_data_types=True, highlighted_nodes=offending_nodes - ) + f"which is not compatible with:\n\n" + + format_operation_graph(op_graph, highlighted_nodes=offending_nodes) ) _set_all_bit_width(op_graph, max_bit_width) diff --git a/concrete/common/operator_graph.py b/concrete/common/operator_graph.py index a6c27a46c..ebd062209 100644 --- a/concrete/common/operator_graph.py +++ b/concrete/common/operator_graph.py @@ -131,6 +131,24 @@ class OPGraph: idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values()) return [idx_to_pred[i] for i in range(len(idx_to_pred))] + def get_ordered_inputs_of(self, node: IntermediateNode) -> List[Tuple[IntermediateNode, int]]: + """Get node inputs ordered by their indices. + + Args: + node (IntermediateNode): the node for which we want the ordered inputs + + Returns: + List[Tuple[IntermediateNode, int]]: the ordered list of inputs + """ + + idx_to_inp: Dict[int, Tuple[IntermediateNode, int]] = {} + for pred in self.graph.pred[node]: + edge_data = self.graph.get_edge_data(pred, node) + idx_to_inp.update( + (data["input_idx"], (pred, data["output_idx"])) for data in edge_data.values() + ) + return [idx_to_inp[i] for i in range(len(idx_to_inp))] + def evaluate(self, inputs: Dict[int, Any]) -> Dict[IntermediateNode, Any]: """Evaluate a graph and get intermediate values for all nodes. diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index 3e284f289..272039604 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -133,10 +133,9 @@ def convert_float_subgraph_to_fused_node( printable_graph = format_operation_graph( float_subgraph_as_op_graph, - show_data_types=True, highlighted_nodes=node_with_issues_for_fusing, ) - message = f"The following subgraph is not fusable:\n{printable_graph}" + message = f"The following subgraph is not fusable:\n\n{printable_graph}" logger.warning(message) return None diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index fd0bdeffb..6c6362122 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -16,6 +16,7 @@ from ..data_types.dtypes_helpers import ( from ..data_types.integers import Integer from ..debugging.custom_assert import assert_true from ..helpers import indexing_helpers +from ..helpers.formatting_helpers import format_constant from ..helpers.python_helpers import catch, update_and_return_dict from ..values import ( BaseValue, @@ -64,6 +65,27 @@ class IntermediateNode(ABC): self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])] + def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str: + """Get the formatted node (used in formatting opgraph). + + Args: + predecessors (List[str]): predecessor names to this node + _maximum_constant_length (int): desired maximum constant length + + Returns: + str: the formatted node + """ + + return f"{self.__class__.__name__.lower()}({', '.join(predecessors)})" + + @abstractmethod + def text_for_drawing(self) -> str: + """Get the label of the node (used in drawing opgraph). + + Returns: + str: the label of the node + """ + @abstractmethod def evaluate(self, inputs: Dict[int, Any]) -> Any: """Simulate what the represented computation would output for the given inputs. @@ -93,15 +115,6 @@ class IntermediateNode(ABC): """ return cls.n_in() > 1 - @abstractmethod - def label(self) -> str: - """Get the label of the node. - - Returns: - str: the label of the node - - """ - class Add(IntermediateNode): """Addition between two values.""" @@ -110,12 +123,12 @@ class Add(IntermediateNode): __init__ = IntermediateNode._init_binary + def text_for_drawing(self) -> str: + return "+" + def evaluate(self, inputs: Dict[int, Any]) -> Any: return inputs[0] + inputs[1] - def label(self) -> str: - return "+" - class Sub(IntermediateNode): """Subtraction between two values.""" @@ -124,12 +137,12 @@ class Sub(IntermediateNode): __init__ = IntermediateNode._init_binary + def text_for_drawing(self) -> str: + return "-" + def evaluate(self, inputs: Dict[int, Any]) -> Any: return inputs[0] - inputs[1] - def label(self) -> str: - return "-" - class Mul(IntermediateNode): """Multiplication between two values.""" @@ -138,12 +151,12 @@ class Mul(IntermediateNode): __init__ = IntermediateNode._init_binary + def text_for_drawing(self) -> str: + return "*" + def evaluate(self, inputs: Dict[int, Any]) -> Any: return inputs[0] * inputs[1] - def label(self) -> str: - return "*" - class Input(IntermediateNode): """Node representing an input of the program.""" @@ -164,12 +177,16 @@ class Input(IntermediateNode): self.program_input_idx = program_input_idx self.outputs = [deepcopy(self.inputs[0])] + def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str: + assert_true(len(predecessors) == 0) + return self.input_name + + def text_for_drawing(self) -> str: + return self.input_name + def evaluate(self, inputs: Dict[int, Any]) -> Any: return inputs[0] - def label(self) -> str: - return self.input_name - class Constant(IntermediateNode): """Node representing a constant of the program.""" @@ -191,6 +208,13 @@ class Constant(IntermediateNode): self._constant_data = constant_data self.outputs = [base_value_class(is_encrypted=False)] + def text_for_formatting(self, predecessors: List[str], maximum_constant_length: int) -> str: + assert_true(len(predecessors) == 0) + return format_constant(self.constant_data, maximum_constant_length) + + def text_for_drawing(self) -> str: + return format_constant(self.constant_data) + def evaluate(self, inputs: Dict[int, Any]) -> Any: return self.constant_data @@ -203,9 +227,6 @@ class Constant(IntermediateNode): """ return self._constant_data - def label(self) -> str: - return str(self.constant_data) - class IndexConstant(IntermediateNode): """Node representing a constant indexing in the program. @@ -242,19 +263,17 @@ class IndexConstant(IntermediateNode): else ClearTensor(output_dtype, output_shape) ] - def evaluate(self, inputs: Dict[int, Any]) -> Any: - return inputs[0][self.index] - - def label(self) -> str: - """Label of the node to show during drawings. - - It can be used for some other places after `"value"` below is replaced by `""`. - This note will no longer be necessary after #707 is addressed. - - """ + def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str: + assert_true(len(predecessors) == 1) elements = [indexing_helpers.format_indexing_element(element) for element in self.index] index = ", ".join(elements) - return f"value[{index}]" + return f"{predecessors[0]}[{index}]" + + def text_for_drawing(self) -> str: + return self.text_for_formatting(["value"], 0) # 0 is unused + + def evaluate(self, inputs: Dict[int, Any]) -> Any: + return inputs[0][self.index] def flood_replace_none_values(table: list): @@ -335,15 +354,26 @@ class GenericFunction(IntermediateNode): self.op_name = op_name if op_name is not None else self.__class__.__name__ + def text_for_formatting(self, predecessors: List[str], maximum_constant_length: int) -> str: + all_args = deepcopy(predecessors) + + all_args.extend(format_constant(value, maximum_constant_length) for value in self.op_args) + all_args.extend( + f"{name}={format_constant(value, maximum_constant_length)}" + for name, value in self.op_kwargs.items() + ) + + return f"{self.op_name}({', '.join(all_args)})" + + def text_for_drawing(self) -> str: + return self.op_name + def evaluate(self, inputs: Dict[int, Any]) -> Any: # This is the continuation of the mypy bug workaround assert self.arbitrary_func is not None ordered_inputs = [inputs[idx] for idx in range(len(inputs))] return self.arbitrary_func(*ordered_inputs, *self.op_args, **self.op_kwargs) - def label(self) -> str: - return self.op_name - def get_table(self, ordered_preds: List[IntermediateNode]) -> List[Any]: """Get the table for the current input value of this GenericFunction. @@ -466,14 +496,14 @@ class Dot(IntermediateNode): self.outputs = [output_scalar_value(output_dtype)] self.evaluation_function = delegate_evaluation_function + def text_for_drawing(self) -> str: + return "dot" + def evaluate(self, inputs: Dict[int, Any]) -> Any: # This is the continuation of the mypy bug workaround assert self.evaluation_function is not None return self.evaluation_function(inputs[0], inputs[1]) - def label(self) -> str: - return "dot" - class MatMul(IntermediateNode): """Return the node representing a matrix multiplication.""" @@ -513,8 +543,8 @@ class MatMul(IntermediateNode): self.outputs = [output_value] + def text_for_drawing(self) -> str: + return "matmul" + def evaluate(self, inputs: Dict[int, Any]) -> Any: return inputs[0] @ inputs[1] - - def label(self) -> str: - return "@" diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index f8f7e6f61..2414c88cb 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -262,9 +262,7 @@ def prepare_op_graph_for_mlir(op_graph): if offending_nodes is not None: raise RuntimeError( "function you are trying to compile isn't supported for MLIR lowering\n\n" - + format_operation_graph( - op_graph, show_data_types=True, highlighted_nodes=offending_nodes - ) + + format_operation_graph(op_graph, highlighted_nodes=offending_nodes) ) # Update bit_width for MLIR diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index f395448e8..c3cdb9843 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -104,7 +104,7 @@ class NPTracer(BaseTracer): output_value=generic_function_output_value, op_kind="TLU", op_kwargs={"dtype": normalized_numpy_dtype.type}, - op_name=f"astype({normalized_numpy_dtype})", + op_name="astype", ) output_tracer = self.__class__([self], traced_computation=traced_computation, output_idx=0) return output_tracer @@ -320,7 +320,7 @@ class NPTracer(BaseTracer): output_value=generic_function_output_value, op_kind="Memory", op_kwargs=deepcopy(kwargs), - op_name="np.transpose", + op_name="transpose", op_attributes={"fusable": transpose_is_fusable}, ) output_tracer = self.__class__( @@ -367,7 +367,7 @@ class NPTracer(BaseTracer): output_value=generic_function_output_value, op_kind="Memory", op_kwargs=deepcopy(kwargs), - op_name="np.ravel", + op_name="ravel", op_attributes={"fusable": ravel_is_fusable}, ) output_tracer = self.__class__( @@ -405,12 +405,9 @@ class NPTracer(BaseTracer): assert_true(isinstance(first_arg_output, TensorValue)) first_arg_output = cast(TensorValue, first_arg_output) - newshape = deepcopy(arg1) - - if isinstance(newshape, int): - # Make numpy.reshape(x, (170)) and numpy.reshape(x, 170) work, while classical form is - # numpy.reshape(x, (170,)) - newshape = (newshape,) + # Make numpy.reshape(x, (170)) and numpy.reshape(x, 170) work, + # while classical form is numpy.reshape(x, (170,)) + newshape = deepcopy(arg1) if not isinstance(arg1, int) else (arg1,) # Check shape compatibility assert_true( @@ -435,7 +432,7 @@ class NPTracer(BaseTracer): output_value=generic_function_output_value, op_kind="Memory", op_kwargs={"newshape": newshape}, - op_name="np.reshape", + op_name="reshape", op_attributes={"fusable": reshape_is_fusable}, ) output_tracer = self.__class__( @@ -575,7 +572,7 @@ def _get_unary_fun(function: numpy.ufunc): # dynamically # pylint: disable=protected-access return lambda *input_tracers, **kwargs: NPTracer._unary_operator( - function, f"np.{function.__name__}", *input_tracers, **kwargs + function, f"{function.__name__}", *input_tracers, **kwargs ) # pylint: enable=protected-access @@ -587,7 +584,7 @@ def _get_binary_fun(function: numpy.ufunc): # dynamically # pylint: disable=protected-access return lambda *input_tracers, **kwargs: NPTracer._binary_operator( - function, f"np.{function.__name__}", *input_tracers, **kwargs + function, f"{function.__name__}", *input_tracers, **kwargs ) # pylint: enable=protected-access diff --git a/tests/common/bounds_measurement/test_inputset_eval.py b/tests/common/bounds_measurement/test_inputset_eval.py index 085ea865c..ddf1fa975 100644 --- a/tests/common/bounds_measurement/test_inputset_eval.py +++ b/tests/common/bounds_measurement/test_inputset_eval.py @@ -328,11 +328,11 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys): captured = capsys.readouterr() assert ( captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " - "(expected EncryptedTensor, shape=(3,)> for parameter `x` " - "but got EncryptedTensor, shape=(4,)> which is not compatible)\n" + "(expected EncryptedTensor for parameter `x` " + "but got EncryptedTensor which is not compatible)\n" "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " - "(expected ClearTensor, shape=(3,)> for parameter `y` " - "but got ClearTensor, shape=(4,)> which is not compatible)\n" + "(expected ClearTensor for parameter `y` " + "but got ClearTensor which is not compatible)\n" ) @@ -365,14 +365,14 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_check_all(capsys): captured = capsys.readouterr() assert ( captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " - "(expected EncryptedTensor, shape=(3,)> for parameter `x` " - "but got EncryptedTensor, shape=(4,)> which is not compatible)\n" + "(expected EncryptedTensor for parameter `x` " + "but got EncryptedTensor which is not compatible)\n" "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " - "(expected ClearTensor, shape=(3,)> for parameter `y` " - "but got ClearTensor, shape=(4,)> which is not compatible)\n" + "(expected ClearTensor for parameter `y` " + "but got ClearTensor which is not compatible)\n" "Warning: Input #1 (0-indexed) is not coherent with the hinted parameters " - "(expected ClearTensor, shape=(3,)> for parameter `y` " - "but got ClearTensor, shape=(3,)> which is not compatible)\n" + "(expected ClearTensor for parameter `y` " + "but got ClearTensor which is not compatible)\n" ) @@ -436,14 +436,14 @@ def test_eval_op_graph_bounds_on_non_conformant_numpy_inputset_check_all(capsys) captured = capsys.readouterr() assert ( captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " - "(expected EncryptedTensor, shape=(3,)> for parameter `x` " - "but got EncryptedTensor, shape=(4,)> which is not compatible)\n" + "(expected EncryptedTensor for parameter `x` " + "but got EncryptedTensor which is not compatible)\n" "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " - "(expected ClearTensor, shape=(3,)> for parameter `y` " - "but got ClearTensor, shape=(4,)> which is not compatible)\n" + "(expected ClearTensor for parameter `y` " + "but got ClearTensor which is not compatible)\n" "Warning: Input #1 (0-indexed) is not coherent with the hinted parameters " - "(expected ClearTensor, shape=(3,)> for parameter `y` " - "but got ClearTensor, shape=(3,)> which is not compatible)\n" + "(expected ClearTensor for parameter `y` " + "but got ClearTensor which is not compatible)\n" ) diff --git a/tests/common/debugging/test_formatting.py b/tests/common/debugging/test_formatting.py new file mode 100644 index 000000000..340478017 --- /dev/null +++ b/tests/common/debugging/test_formatting.py @@ -0,0 +1,78 @@ +"""Test file for formatting""" + +from concrete.common.data_types.integers import Integer +from concrete.common.debugging import format_operation_graph +from concrete.common.values import EncryptedScalar +from concrete.numpy.compile import compile_numpy_function_into_op_graph + + +def test_format_operation_graph_with_multiple_edges(default_compilation_configuration): + """Test format_operation_graph with multiple edges""" + + def function(x): + return x + x + + opgraph = compile_numpy_function_into_op_graph( + function, + {"x": EncryptedScalar(Integer(4, True))}, + [(i,) for i in range(0, 10)], + default_compilation_configuration, + ) + + formatted_graph = format_operation_graph(opgraph) + assert ( + formatted_graph + == """ + +%0 = x # EncryptedScalar +%1 = add(%0, %0) # EncryptedScalar +return %1 + +""".strip() + ) + + +def test_format_operation_graph_with_offending_nodes(default_compilation_configuration): + """Test format_operation_graph with offending nodes""" + + def function(x): + return x + 42 + + opgraph = compile_numpy_function_into_op_graph( + function, + {"x": EncryptedScalar(Integer(7, True))}, + [(i,) for i in range(-5, 5)], + default_compilation_configuration, + ) + + highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]} + formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip() + assert ( + formatted_graph + == """ + +%0 = x # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo +%1 = 42 # ClearScalar +%2 = add(%0, %1) # EncryptedScalar +return %2 + +""".strip() + ) + + highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]} + formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip() + assert ( + formatted_graph + == """ + +%0 = x # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo +%1 = 42 # ClearScalar +%2 = add(%0, %1) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar + baz +return %2 + +""".strip() + ) diff --git a/tests/common/debugging/test_printing.py b/tests/common/debugging/test_printing.py deleted file mode 100644 index 9e7bc0cae..000000000 --- a/tests/common/debugging/test_printing.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Test file for printing""" - -from concrete.common.data_types.integers import Integer -from concrete.common.debugging import format_operation_graph -from concrete.common.values import EncryptedScalar -from concrete.numpy.compile import compile_numpy_function_into_op_graph - - -def test_format_operation_graph_with_offending_nodes(default_compilation_configuration): - """Test format_operation_graph with offending nodes""" - - def function(x): - return x + 42 - - opgraph = compile_numpy_function_into_op_graph( - function, - {"x": EncryptedScalar(Integer(7, True))}, - [(i,) for i in range(-5, 5)], - default_compilation_configuration, - ) - - highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]} - - without_types = format_operation_graph( - opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes - ).strip() - with_types = format_operation_graph( - opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes - ).strip() - - assert ( - without_types - == """ - -%0 = x -^^^^^^ foo -%1 = Constant(42) -%2 = Add(%0, %1) -return(%2) - -""".strip() - ) - - assert ( - with_types - == """ - -%0 = x # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo -%1 = Constant(42) # ClearScalar> -%2 = Add(%0, %1) # EncryptedScalar> -return(%2) - -""".strip() - ) - - highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]} - - without_types = format_operation_graph( - opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes - ).strip() - with_types = format_operation_graph( - opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes - ).strip() - - assert ( - without_types - == """ - -%0 = x -^^^^^^ foo -%1 = Constant(42) -%2 = Add(%0, %1) -^^^^^^^^^^^^^^^^ bar - baz -return(%2) - -""".strip() - ) - - assert ( - with_types - == """ - -%0 = x # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo -%1 = Constant(42) # ClearScalar> -%2 = Add(%0, %1) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar - baz -return(%2) - -""".strip() - ) diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index 1d7175dc6..554077b4c 100644 --- a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -157,88 +157,118 @@ def get_func_params_int32(func, scalar=True): no_fuse_unhandled, False, get_func_params_int32(no_fuse_unhandled), - """The following subgraph is not fusable: -%0 = x # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing) -%1 = Constant(0.7) # ClearScalar> -%2 = y # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing) -%3 = Constant(1.3) # ClearScalar> -%4 = Add(%0, %1) # EncryptedScalar> -%5 = Add(%2, %3) # EncryptedScalar> -%6 = Add(%4, %5) # EncryptedScalar> -%7 = astype(int32)(%6) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot fuse here as the subgraph has 2 variable inputs -return(%7)""", # noqa: E501 # pylint: disable=line-too-long + """ + +The following subgraph is not fusable: + +%0 = x # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing) +%1 = 0.7 # ClearScalar +%2 = y # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing) +%3 = 1.3 # ClearScalar +%4 = add(%0, %1) # EncryptedScalar +%5 = add(%2, %3) # EncryptedScalar +%6 = add(%4, %5) # EncryptedScalar +%7 = astype(%6, dtype=int32) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot fuse here as the subgraph has 2 variable inputs +return %7 + + """.strip(), # noqa: E501 # pylint: disable=line-too-long id="no_fuse_unhandled", ), pytest.param( no_fuse_dot, False, {"x": EncryptedTensor(Integer(32, True), (10,))}, - """The following subgraph is not fusable: -%0 = x # EncryptedTensor, shape=(10,)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10,) -%1 = Constant([1.33 1.33 ... 1.33 1.33]) # ClearTensor, shape=(10,)> -%2 = Dot(%0, %1) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,) -%3 = astype(int32)(%2) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,) -return(%3)""", # noqa: E501 # pylint: disable=line-too-long + """ + +The following subgraph is not fusable: + +%0 = x # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10,) +%1 = [1.33 1.33 ... 1.33 1.33] # ClearTensor +%2 = dot(%0, %1) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,) +%3 = astype(%2, dtype=int32) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,) +return %3 + + """.strip(), # noqa: E501 # pylint: disable=line-too-long id="no_fuse_dot", ), pytest.param( ravel_cases, False, {"x": EncryptedTensor(Integer(32, True), (10, 20))}, - """The following subgraph is not fusable: -%0 = x # EncryptedTensor, shape=(10, 20)> -%1 = astype(float64)(%0) # EncryptedTensor, shape=(10, 20)> -%2 = np.ravel(%1) # EncryptedTensor, shape=(200,)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable -%3 = astype(int32)(%2) # EncryptedTensor, shape=(200,)> -return(%3)""", # noqa: E501 # pylint: disable=line-too-long + """ + +The following subgraph is not fusable: + +%0 = x # EncryptedTensor +%1 = astype(%0, dtype=float64) # EncryptedTensor +%2 = ravel(%1) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable +%3 = astype(%2, dtype=int32) # EncryptedTensor +return %3 + + """.strip(), # noqa: E501 # pylint: disable=line-too-long id="no_fuse_explicitely_ravel", ), pytest.param( transpose_cases, False, {"x": EncryptedTensor(Integer(32, True), (10, 20))}, - """The following subgraph is not fusable: -%0 = x # EncryptedTensor, shape=(10, 20)> -%1 = astype(float64)(%0) # EncryptedTensor, shape=(10, 20)> -%2 = np.transpose(%1) # EncryptedTensor, shape=(20, 10)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable -%3 = astype(int32)(%2) # EncryptedTensor, shape=(20, 10)> -return(%3)""", # noqa: E501 # pylint: disable=line-too-long + """ + +The following subgraph is not fusable: + +%0 = x # EncryptedTensor +%1 = astype(%0, dtype=float64) # EncryptedTensor +%2 = transpose(%1) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable +%3 = astype(%2, dtype=int32) # EncryptedTensor +return %3 + + """.strip(), # noqa: E501 # pylint: disable=line-too-long id="no_fuse_explicitely_transpose", ), pytest.param( lambda x: reshape_cases(x, (20, 10)), False, {"x": EncryptedTensor(Integer(32, True), (10, 20))}, - """The following subgraph is not fusable: -%0 = x # EncryptedTensor, shape=(10, 20)> -%1 = astype(float64)(%0) # EncryptedTensor, shape=(10, 20)> -%2 = np.reshape(%1) # EncryptedTensor, shape=(20, 10)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable -%3 = astype(int32)(%2) # EncryptedTensor, shape=(20, 10)> -return(%3)""", # noqa: E501 # pylint: disable=line-too-long + """ + +The following subgraph is not fusable: + +%0 = x # EncryptedTensor +%1 = astype(%0, dtype=float64) # EncryptedTensor +%2 = reshape(%1, newshape=(20, 10)) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable +%3 = astype(%2, dtype=int32) # EncryptedTensor +return %3 + + """.strip(), # noqa: E501 # pylint: disable=line-too-long id="no_fuse_explicitely_reshape", ), pytest.param( no_fuse_big_constant_3_10_10, False, {"x": EncryptedTensor(Integer(32, True), (10, 10))}, - """The following subgraph is not fusable: -%0 = Constant([[[1. 1. 1 ... . 1. 1.]]]) # ClearTensor, shape=(3, 10, 10)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10) -%1 = x # EncryptedTensor, shape=(10, 10)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10) -%2 = astype(float64)(%1) # EncryptedTensor, shape=(10, 10)> -%3 = Add(%2, %0) # EncryptedTensor, shape=(3, 10, 10)> -%4 = astype(int32)(%3) # EncryptedTensor, shape=(3, 10, 10)> -return(%4)""", # noqa: E501 # pylint: disable=line-too-long + """ + +The following subgraph is not fusable: + +%0 = [[[1. 1. 1 ... . 1. 1.]]] # ClearTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10) +%1 = x # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10) +%2 = astype(%1, dtype=float64) # EncryptedTensor +%3 = add(%2, %0) # EncryptedTensor +%4 = astype(%3, dtype=int32) # EncryptedTensor +return %4 + + """.strip(), # noqa: E501 # pylint: disable=line-too-long id="no_fuse_big_constant_3_10_10", ), pytest.param( @@ -322,7 +352,7 @@ def test_fuse_float_operations( else: assert fused_num_nodes == orig_num_nodes captured = capfd.readouterr() - assert warning_message in remove_color_codes(captured.err) + assert warning_message in (output := remove_color_codes(captured.err)), output for input_ in [0, 2, 42, 44]: inputs = () diff --git a/tests/common/test_fhe_circuit.py b/tests/common/test_fhe_circuit.py index 70315cfda..348667885 100644 --- a/tests/common/test_fhe_circuit.py +++ b/tests/common/test_fhe_circuit.py @@ -17,7 +17,7 @@ def test_circuit_str(default_compilation_configuration): inputset = [(i,) for i in range(2 ** 3)] circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration) - assert str(circuit) == format_operation_graph(circuit.opgraph, show_data_types=True) + assert str(circuit) == format_operation_graph(circuit.opgraph) def test_circuit_draw(default_compilation_configuration): diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index b2f92c30d..6dc1fe5b7 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -495,7 +495,7 @@ def test_compile_function_multiple_outputs( # when we have the converter, we can check the MLIR draw_graph(op_graph, show=False) - str_of_the_graph = format_operation_graph(op_graph, show_data_types=True) + str_of_the_graph = format_operation_graph(op_graph) print(f"\n{str_of_the_graph}\n") @@ -960,7 +960,7 @@ def test_compile_function_with_direct_tlu(default_compilation_configuration): default_compilation_configuration, ) - str_of_the_graph = format_operation_graph(op_graph, show_data_types=True) + str_of_the_graph = format_operation_graph(op_graph) print(f"\n{str_of_the_graph}\n") @@ -991,14 +991,16 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura [(i,) for i in range(8)], ( """ + function you are trying to compile isn't supported for MLIR lowering -%0 = Constant(1) # ClearScalar> -%1 = x # EncryptedScalar> -%2 = Sub(%0, %1) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported -return(%2) -""".lstrip() # noqa: E501 +%0 = 1 # ClearScalar +%1 = x # EncryptedScalar +%2 = sub(%0, %1) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported +return %2 + + """.strip() # noqa: E501 ), ), pytest.param( @@ -1021,16 +1023,18 @@ return(%2) ], ( """ + function you are trying to compile isn't supported for MLIR lowering -%0 = x # EncryptedTensor, shape=(1,)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported -%1 = y # EncryptedTensor, shape=(1,)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported -%2 = Dot(%0, %1) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer dot product is supported -return(%2) -""".lstrip() # noqa: E501 +%0 = x # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported +%1 = y # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported +%2 = dot(%0, %1) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer dot product is supported +return %2 + + """.strip() # noqa: E501 ), ), pytest.param( @@ -1039,14 +1043,16 @@ return(%2) [(numpy.random.randint(-4, 2 ** 2, size=(2, 2)),) for i in range(10)], ( """ + function you are trying to compile isn't supported for MLIR lowering -%0 = x # EncryptedTensor, shape=(2, 2)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported -%1 = IndexConstant(%0[0]) # EncryptedTensor, shape=(2,)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ indexing is not supported for the time being -return(%1) -""".lstrip() # noqa: E501 +%0 = x # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported +%1 = %0[0] # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ indexing is not supported for the time being +return %1 + + """.strip() # noqa: E501 ), ), pytest.param( @@ -1055,28 +1061,30 @@ return(%1) [(numpy.array(i), numpy.array(i)) for i in range(10)], ( """ + function you are trying to compile isn't supported for MLIR lowering -%0 = Constant(1.5) # ClearScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported -%1 = x # EncryptedScalar> -%2 = Constant(2.8) # ClearScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported -%3 = y # EncryptedScalar> -%4 = Constant(9.3) # ClearScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported -%5 = Add(%1, %2) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported -%6 = Add(%3, %4) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported -%7 = Sub(%5, %6) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer subtraction is supported -%8 = Mul(%7, %0) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported -%9 = astype(int32)(%8) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ astype(int32) is not supported without fusing -return(%9) -""".lstrip() # noqa: E501 +%0 = 1.5 # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported +%1 = x # EncryptedScalar +%2 = 2.8 # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported +%3 = y # EncryptedScalar +%4 = 9.3 # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported +%5 = add(%1, %2) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported +%6 = add(%3, %4) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported +%7 = sub(%5, %6) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer subtraction is supported +%8 = mul(%7, %0) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported +%9 = astype(%8, dtype=int32) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ astype is not supported without fusing +return %9 + + """.strip() # noqa: E501 ), ), pytest.param( @@ -1084,13 +1092,17 @@ return(%9) {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))}, [(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)], ( - "function you are trying to compile isn't supported for MLIR lowering\n" - "\n" - "%0 = x # EncryptedTensor, shape=(3, 2)>\n" # noqa: E501 - "%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor, shape=(2, 3)>\n" # noqa: E501 - "%2 = MatMul(%0, %1) # EncryptedTensor, shape=(3, 3)>\n" # noqa: E501 - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501 - "return(%2)\n" + """ + +function you are trying to compile isn't supported for MLIR lowering + +%0 = x # EncryptedTensor +%1 = [[1 1 1] [1 1 1]] # ClearTensor +%2 = matmul(%0, %1) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being +return %2 + + """.strip() # noqa: E501 ), ), pytest.param( @@ -1098,13 +1110,15 @@ return(%9) {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))}, [(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)], ( - "function you are trying to compile isn't supported for MLIR lowering\n" - "\n" - "%0 = x # EncryptedTensor, shape=(3, 2)>\n" # noqa: E501 - "%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor, shape=(2, 3)>\n" # noqa: E501 - "%2 = MatMul(%0, %1) # EncryptedTensor, shape=(3, 3)>\n" # noqa: E501 - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501 - "return(%2)\n" + """ +function you are trying to compile isn't supported for MLIR lowering + +%0 = x # EncryptedTensor +%1 = [[1 1 1] [1 1 1]] # ClearTensor +%2 = matmul(%0, %1) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being +return %2 + """.strip() # noqa: E501 ), ), pytest.param( @@ -1112,13 +1126,17 @@ return(%9) {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))}, [(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)], ( - "function you are trying to compile isn't supported for MLIR lowering\n" - "\n" - "%0 = x # EncryptedTensor, shape=(3, 2)>\n" # noqa: E501 - "%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor, shape=(2, 3)>\n" # noqa: E501 - "%2 = MatMul(%0, %1) # EncryptedTensor, shape=(3, 3)>\n" # noqa: E501 - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501 - "return(%2)\n" + """ + +function you are trying to compile isn't supported for MLIR lowering + +%0 = x # EncryptedTensor +%1 = [[1 1 1] [1 1 1]] # ClearTensor +%2 = matmul(%0, %1) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being +return %2 + + """.strip() # noqa: E501 ), ), pytest.param( @@ -1127,13 +1145,15 @@ return(%9) [(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(32)], ( """ + function you are trying to compile isn't supported for MLIR lowering -%0 = x # EncryptedTensor, shape=(3, 2)> -%1 = MultiTLU(%0) # EncryptedTensor, shape=(3, 2)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ direct multi table lookup is not supported for the time being -return(%1) -""".lstrip() # noqa: E501 +%0 = x # EncryptedTensor +%1 = MultiTLU(%0, input_shape=(3, 2), tables=[[[1, 2, 1 ... 1, 2, 0]]]) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ direct multi table lookup is not supported for the time being +return %1 + + """.strip() # noqa: E501 ), ), pytest.param( @@ -1141,13 +1161,16 @@ return(%1) {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))}, [(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)], ( - """function you are trying to compile isn't supported for MLIR lowering + """ -%0 = x # EncryptedTensor, shape=(3, 2)> -%1 = np.transpose(%0) # EncryptedTensor, shape=(2, 3)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.transpose of kind Memory is not supported for the time being -return(%1) -""" # noqa: E501 +function you are trying to compile isn't supported for MLIR lowering + +%0 = x # EncryptedTensor +%1 = transpose(%0) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ transpose is not supported for the time being +return %1 + + """.strip() # noqa: E501 ), ), pytest.param( @@ -1155,13 +1178,16 @@ return(%1) {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))}, [(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)], ( - """function you are trying to compile isn't supported for MLIR lowering + """ -%0 = x # EncryptedTensor, shape=(3, 2)> -%1 = np.ravel(%0) # EncryptedTensor, shape=(6,)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.ravel of kind Memory is not supported for the time being -return(%1) -""" # noqa: E501 +function you are trying to compile isn't supported for MLIR lowering + +%0 = x # EncryptedTensor +%1 = ravel(%0) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ravel is not supported for the time being +return %1 + + """.strip() # noqa: E501 ), ), pytest.param( @@ -1169,13 +1195,16 @@ return(%1) {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 4))}, [(numpy.random.randint(0, 2 ** 3, size=(3, 4)),) for i in range(10)], ( - """function you are trying to compile isn't supported for MLIR lowering + """ -%0 = x # EncryptedTensor, shape=(3, 4)> -%1 = np.reshape(%0) # EncryptedTensor, shape=(2, 6)> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.reshape of kind Memory is not supported for the time being -return(%1) -""" # noqa: E501 +function you are trying to compile isn't supported for MLIR lowering + +%0 = x # EncryptedTensor +%1 = reshape(%0, newshape=(2, 6)) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ reshape is not supported for the time being +return %1 + + """.strip() # noqa: E501 ), ), ], @@ -1217,19 +1246,21 @@ def test_fail_with_intermediate_signed_values(default_compilation_configuration) ) except RuntimeError as error: match = """ + function you are trying to compile isn't supported for MLIR lowering -%0 = y # EncryptedScalar> -%1 = Constant(10) # ClearScalar> -%2 = x # EncryptedScalar> -%3 = np.negative(%2) # EncryptedScalar> -%4 = Mul(%3, %1) # EncryptedScalar> -%5 = np.absolute(%4) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.absolute is not supported for the time being -%6 = astype(int32)(%5) # EncryptedScalar> -%7 = Add(%6, %0) # EncryptedScalar> -return(%7) -""".lstrip() # noqa: E501 # pylint: disable=line-too-long +%0 = y # EncryptedScalar +%1 = 10 # ClearScalar +%2 = x # EncryptedScalar +%3 = negative(%2) # EncryptedScalar +%4 = mul(%3, %1) # EncryptedScalar +%5 = absolute(%4) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ absolute is not supported for the time being +%6 = astype(%5, dtype=int32) # EncryptedScalar +%7 = add(%6, %0) # EncryptedScalar +return %7 + + """.strip() # noqa: E501 # pylint: disable=line-too-long assert str(error) == match raise @@ -1270,13 +1301,14 @@ def test_small_inputset_treat_warnings_as_errors(): (4,), # Remark that, when you do the dot of tensors of 4 values between 0 and 3, # you can get a maximal value of 4*3*3 = 36, ie something on 6 bits - "%0 = x " - "# EncryptedTensor, shape=(4,)>" - "\n%1 = y " - "# EncryptedTensor, shape=(4,)>" - "\n%2 = Dot(%0, %1) " - "# EncryptedScalar>" - "\nreturn(%2)\n", + """ + +%0 = x # EncryptedTensor +%1 = y # EncryptedTensor +%2 = dot(%0, %1) # EncryptedScalar +return %2 + + """.strip(), ), ], ) @@ -1303,7 +1335,7 @@ def test_compile_function_with_dot( data_gen(max_for_ij, repeat), default_compilation_configuration, ) - str_of_the_graph = format_operation_graph(op_graph, show_data_types=True) + str_of_the_graph = format_operation_graph(op_graph) assert str_of_the_graph == ref_graph_str, ( f"\n==================\nGot \n{str_of_the_graph}" f"==================\nExpected \n{ref_graph_str}" @@ -1373,13 +1405,16 @@ def test_compile_too_high_bitwidth(default_compilation_configuration): assert ( str(excinfo.value) == """ + max_bit_width of some nodes is too high for the current version of the compiler (maximum must be 7) which is not compatible with: -%0 = x # EncryptedScalar> -%1 = y # EncryptedScalar> -%2 = Add(%0, %1) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 8 bits is not supported for the time being -return(%2) -""".lstrip() # noqa: E501 # pylint: disable=line-too-long + +%0 = x # EncryptedScalar +%1 = y # EncryptedScalar +%2 = add(%0, %1) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 8 bits is not supported for the time being +return %2 + + """.strip() # noqa: E501 # pylint: disable=line-too-long ) # Just ok @@ -1418,14 +1453,16 @@ def test_failure_for_signed_output(default_compilation_configuration): assert ( str(excinfo.value) == """ + function you are trying to compile isn't supported for MLIR lowering -%0 = x # EncryptedScalar> -%1 = Constant(-3) # ClearScalar> -%2 = Add(%0, %1) # EncryptedScalar> -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported -return(%2) -""".lstrip() # noqa: E501 # pylint: disable=line-too-long +%0 = x # EncryptedScalar +%1 = -3 # ClearScalar +%2 = add(%0, %1) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported +return %2 + +""".strip() # noqa: E501 # pylint: disable=line-too-long ) diff --git a/tests/numpy/test_compile_constant_indexing.py b/tests/numpy/test_compile_constant_indexing.py index 78acd99ed..fed935ada 100644 --- a/tests/numpy/test_compile_constant_indexing.py +++ b/tests/numpy/test_compile_constant_indexing.py @@ -368,8 +368,7 @@ def test_constant_indexing( EncryptedScalar(UnsignedInteger(1)), lambda x: x[0], TypeError, - "Only tensors can be indexed " - "but you tried to index EncryptedScalar>", + "Only tensors can be indexed but you tried to index EncryptedScalar", ), pytest.param( EncryptedTensor(UnsignedInteger(1), shape=(3,)), diff --git a/tests/numpy/test_debugging.py b/tests/numpy/test_debugging.py deleted file mode 100644 index 5933971dc..000000000 --- a/tests/numpy/test_debugging.py +++ /dev/null @@ -1,449 +0,0 @@ -"""Test file for debugging functions""" - -import numpy -import pytest - -from concrete.common.data_types.integers import Integer -from concrete.common.debugging import draw_graph, format_operation_graph -from concrete.common.extensions.table import LookupTable -from concrete.common.values import ClearScalar, EncryptedScalar, EncryptedTensor -from concrete.numpy import tracing - -LOOKUP_TABLE_FROM_2B_TO_4B = LookupTable([9, 2, 4, 11]) -LOOKUP_TABLE_FROM_3B_TO_2B = LookupTable([0, 1, 3, 2, 2, 3, 1, 0]) - - -def issue_130_a(x, y): - """Test case derived from issue #130""" - # pylint: disable=unused-argument - intermediate = x + 1 - return (intermediate, intermediate) - # pylint: enable=unused-argument - - -def issue_130_b(x, y): - """Test case derived from issue #130""" - # pylint: disable=unused-argument - intermediate = x - 1 - return (intermediate, intermediate) - # pylint: enable=unused-argument - - -def issue_130_c(x, y): - """Test case derived from issue #130""" - # pylint: disable=unused-argument - intermediate = 1 - x - return (intermediate, intermediate) - # pylint: enable=unused-argument - - -@pytest.mark.parametrize( - "lambda_f,ref_graph_str", - [ - (lambda x, y: x + y, "%0 = x\n%1 = y\n%2 = Add(%0, %1)\nreturn(%2)\n"), - (lambda x, y: x - y, "%0 = x\n%1 = y\n%2 = Sub(%0, %1)\nreturn(%2)\n"), - (lambda x, y: x + x, "%0 = x\n%1 = Add(%0, %0)\nreturn(%1)\n"), - ( - lambda x, y: x + x - y * y * y + x, - "%0 = x\n%1 = y\n%2 = Add(%0, %0)\n%3 = Mul(%1, %1)" - "\n%4 = Mul(%3, %1)\n%5 = Sub(%2, %4)\n%6 = Add(%5, %0)\nreturn(%6)\n", - ), - (lambda x, y: x + 1, "%0 = x\n%1 = Constant(1)\n%2 = Add(%0, %1)\nreturn(%2)\n"), - (lambda x, y: 1 + x, "%0 = x\n%1 = Constant(1)\n%2 = Add(%0, %1)\nreturn(%2)\n"), - (lambda x, y: (-1) + x, "%0 = x\n%1 = Constant(-1)\n%2 = Add(%0, %1)\nreturn(%2)\n"), - (lambda x, y: 3 * x, "%0 = x\n%1 = Constant(3)\n%2 = Mul(%0, %1)\nreturn(%2)\n"), - (lambda x, y: x * 3, "%0 = x\n%1 = Constant(3)\n%2 = Mul(%0, %1)\nreturn(%2)\n"), - (lambda x, y: x * (-3), "%0 = x\n%1 = Constant(-3)\n%2 = Mul(%0, %1)\nreturn(%2)\n"), - (lambda x, y: x - 11, "%0 = x\n%1 = Constant(11)\n%2 = Sub(%0, %1)\nreturn(%2)\n"), - (lambda x, y: 11 - x, "%0 = Constant(11)\n%1 = x\n%2 = Sub(%0, %1)\nreturn(%2)\n"), - (lambda x, y: (-11) - x, "%0 = Constant(-11)\n%1 = x\n%2 = Sub(%0, %1)\nreturn(%2)\n"), - ( - lambda x, y: x + 13 - y * (-21) * y + 44, - "%0 = Constant(44)" - "\n%1 = x" - "\n%2 = Constant(13)" - "\n%3 = y" - "\n%4 = Constant(-21)" - "\n%5 = Add(%1, %2)" - "\n%6 = Mul(%3, %4)" - "\n%7 = Mul(%6, %3)" - "\n%8 = Sub(%5, %7)" - "\n%9 = Add(%8, %0)" - "\nreturn(%9)\n", - ), - # Multiple outputs - ( - lambda x, y: (x + 1, x + y + 2), - "%0 = x" - "\n%1 = Constant(1)" - "\n%2 = Constant(2)" - "\n%3 = y" - "\n%4 = Add(%0, %1)" - "\n%5 = Add(%0, %3)" - "\n%6 = Add(%5, %2)" - "\nreturn(%4, %6)\n", - ), - ( - lambda x, y: (y, x), - "%0 = y\n%1 = x\nreturn(%0, %1)\n", - ), - ( - lambda x, y: (x, x + 1), - "%0 = x\n%1 = Constant(1)\n%2 = Add(%0, %1)\nreturn(%0, %2)\n", - ), - ( - lambda x, y: (x + 1, x + 1), - "%0 = x" - "\n%1 = Constant(1)" - "\n%2 = Constant(1)" - "\n%3 = Add(%0, %1)" - "\n%4 = Add(%0, %2)" - "\nreturn(%3, %4)\n", - ), - ( - issue_130_a, - "%0 = x\n%1 = Constant(1)\n%2 = Add(%0, %1)\nreturn(%2, %2)\n", - ), - ( - issue_130_b, - "%0 = x\n%1 = Constant(1)\n%2 = Sub(%0, %1)\nreturn(%2, %2)\n", - ), - ( - issue_130_c, - "%0 = Constant(1)\n%1 = x\n%2 = Sub(%0, %1)\nreturn(%2, %2)\n", - ), - ( - lambda x, y: numpy.arctan2(x, 42) + y, - """%0 = y -%1 = x -%2 = Constant(42) -%3 = np.arctan2(%1, %2) -%4 = Add(%3, %0) -return(%4) -""", - ), - ( - lambda x, y: numpy.arctan2(43, x) + y, - """%0 = y -%1 = Constant(43) -%2 = x -%3 = np.arctan2(%1, %2) -%4 = Add(%3, %0) -return(%4) -""", - ), - ], -) -@pytest.mark.parametrize( - "x_y", - [ - pytest.param( - ( - EncryptedScalar(Integer(64, is_signed=False)), - EncryptedScalar(Integer(64, is_signed=False)), - ), - id="Encrypted uint", - ), - pytest.param( - ( - EncryptedScalar(Integer(64, is_signed=False)), - ClearScalar(Integer(64, is_signed=False)), - ), - id="Clear uint", - ), - ], -) -def test_print_and_draw_graph(lambda_f, ref_graph_str, x_y): - "Test format_operation_graph and draw_graph" - x, y = x_y - graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y}) - - draw_graph(graph, show=False) - - str_of_the_graph = format_operation_graph(graph) - - assert str_of_the_graph == ref_graph_str, ( - f"\n==================\nGot \n{str_of_the_graph}" - f"==================\nExpected \n{ref_graph_str}" - f"==================\n" - ) - - -@pytest.mark.parametrize( - "lambda_f,params,ref_graph_str", - [ - ( - lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x], - {"x": EncryptedScalar(Integer(2, is_signed=False))}, - "%0 = x\n%1 = TLU(%0)\nreturn(%1)\n", - ), - ( - lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4], - {"x": EncryptedScalar(Integer(2, is_signed=False))}, - "%0 = x\n%1 = Constant(4)\n%2 = Add(%0, %1)\n%3 = TLU(%2)\nreturn(%3)\n", - ), - ], -) -def test_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str): - "Test format_operation_graph and draw_graph on graphs with direct table lookup" - graph = tracing.trace_numpy_function(lambda_f, params) - - draw_graph(graph, show=False) - - str_of_the_graph = format_operation_graph(graph) - - assert str_of_the_graph == ref_graph_str, ( - f"\n==================\nGot \n{str_of_the_graph}" - f"==================\nExpected \n{ref_graph_str}" - f"==================\n" - ) - - -@pytest.mark.parametrize( - "lambda_f,params,ref_graph_str", - [ - ( - lambda x, y: numpy.dot(x, y), - { - "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)), - "y": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)), - }, - "%0 = x\n%1 = y\n%2 = Dot(%0, %1)\nreturn(%2)\n", - ), - ], -) -def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str): - "Test format_operation_graph and draw_graph on graphs with dot" - graph = tracing.trace_numpy_function(lambda_f, params) - - draw_graph(graph, show=False) - - str_of_the_graph = format_operation_graph(graph) - - assert str_of_the_graph == ref_graph_str, ( - f"\n==================\nGot \n{str_of_the_graph}" - f"==================\nExpected \n{ref_graph_str}" - f"==================\n" - ) - - -# pylint: disable=line-too-long -@pytest.mark.parametrize( - "lambda_f,params,ref_graph_str", - [ - ( - lambda x: numpy.transpose(x), - { - "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)), - }, - """ -%0 = x # EncryptedTensor, shape=(3, 5)> -%1 = np.transpose(%0) # EncryptedTensor, shape=(5, 3)> -return(%1) -""".lstrip(), # noqa: E501 - ), - ( - lambda x: numpy.ravel(x), - { - "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)), - }, - """ -%0 = x # EncryptedTensor, shape=(3, 5)> -%1 = np.ravel(%0) # EncryptedTensor, shape=(15,)> -return(%1) -""".lstrip(), # noqa: E501 - ), - ( - lambda x: numpy.reshape(x, (5, 3)), - { - "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)), - }, - """ -%0 = x # EncryptedTensor, shape=(3, 5)> -%1 = np.reshape(%0) # EncryptedTensor, shape=(5, 3)> -return(%1) -""".lstrip(), # noqa: E501 - ), - ( - lambda x: numpy.reshape(x, (170,)), - { - "x": EncryptedTensor(Integer(2, is_signed=False), shape=(17, 10)), - }, - """ -%0 = x # EncryptedTensor, shape=(17, 10)> -%1 = np.reshape(%0) # EncryptedTensor, shape=(170,)> -return(%1) -""".lstrip(), # noqa: E501 - ), - ( - lambda x: numpy.reshape(x, (170)), - { - "x": EncryptedTensor(Integer(2, is_signed=False), shape=(17, 10)), - }, - """ -%0 = x # EncryptedTensor, shape=(17, 10)> -%1 = np.reshape(%0) # EncryptedTensor, shape=(170,)> -return(%1) -""".lstrip(), # noqa: E501 - ), - ], -) -def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_str): - "Test format_operation_graph and draw_graph on graphs with generic function" - graph = tracing.trace_numpy_function(lambda_f, params) - - draw_graph(graph, show=False) - - str_of_the_graph = format_operation_graph(graph, show_data_types=True) - - assert str_of_the_graph == ref_graph_str, ( - f"\n==================\nGot \n{str_of_the_graph}" - f"==================\nExpected \n{ref_graph_str}" - f"==================\n" - ) - - -# pylint: enable=line-too-long - - -# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b -# returning 23b), since they are replaced later by the real bitwidths computed on the -# inputset -@pytest.mark.parametrize( - "lambda_f,x_y,ref_graph_str", - [ - ( - lambda x, y: x + y, - ( - EncryptedScalar(Integer(64, is_signed=False)), - EncryptedScalar(Integer(32, is_signed=True)), - ), - "%0 = x " - "# EncryptedScalar>" - "\n%1 = y " - " # EncryptedScalar>" - "\n%2 = Add(%0, %1) " - " # EncryptedScalar>" - "\nreturn(%2)\n", - ), - ( - lambda x, y: x * y, - ( - EncryptedScalar(Integer(17, is_signed=False)), - EncryptedScalar(Integer(23, is_signed=False)), - ), - "%0 = x " - "# EncryptedScalar>" - "\n%1 = y " - "# EncryptedScalar>" - "\n%2 = Mul(%0, %1) " - "# EncryptedScalar>" - "\nreturn(%2)\n", - ), - ], -) -def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str): - """Test format_operation_graph with show_data_types""" - x, y = x_y - graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y}) - - str_of_the_graph = format_operation_graph(graph, show_data_types=True) - - assert str_of_the_graph == ref_graph_str, ( - f"\n==================\nGot \n{str_of_the_graph}" - f"==================\nExpected \n{ref_graph_str}" - f"==================\n" - ) - - -@pytest.mark.parametrize( - "lambda_f,params,ref_graph_str", - [ - ( - lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x], - {"x": EncryptedScalar(Integer(2, is_signed=False))}, - "%0 = x " - "# EncryptedScalar>" - "\n%1 = TLU(%0) " - "# EncryptedScalar>" - "\nreturn(%1)\n", - ), - ( - lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4], - {"x": EncryptedScalar(Integer(2, is_signed=False))}, - "%0 = x " - "# EncryptedScalar>" - "\n%1 = Constant(4) " - "# ClearScalar>" - "\n%2 = Add(%0, %1) " - "# EncryptedScalar>" - "\n%3 = TLU(%2) " - "# EncryptedScalar>" - "\nreturn(%3)\n", - ), - ( - lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[LOOKUP_TABLE_FROM_3B_TO_2B[x + 4]], - {"x": EncryptedScalar(Integer(2, is_signed=False))}, - "%0 = x " - "# EncryptedScalar>" - "\n%1 = Constant(4) " - "# ClearScalar>" - "\n%2 = Add(%0, %1) " - "# EncryptedScalar>" - "\n%3 = TLU(%2) " - "# EncryptedScalar>" - "\n%4 = TLU(%3) " - "# EncryptedScalar>" - "\nreturn(%4)\n", - ), - ], -) -def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_str): - """Test format_operation_graph with show_data_types on graphs with direct table lookup""" - graph = tracing.trace_numpy_function(lambda_f, params) - - draw_graph(graph, show=False) - - str_of_the_graph = format_operation_graph(graph, show_data_types=True) - - assert str_of_the_graph == ref_graph_str, ( - f"\n==================\nGot \n{str_of_the_graph}" - f"==================\nExpected \n{ref_graph_str}" - f"==================\n" - ) - - -def test_numpy_long_constant(): - "Test format_operation_graph with long constant" - - def all_explicit_operations(x): - intermediate = numpy.add(x, numpy.arange(100).reshape(10, 10)) - intermediate = numpy.subtract(intermediate, numpy.arange(10).reshape(1, 10)) - intermediate = numpy.arctan2(numpy.arange(10, 20).reshape(1, 10), intermediate) - intermediate = numpy.arctan2(numpy.arange(100, 200).reshape(10, 10), intermediate) - return intermediate - - op_graph = tracing.trace_numpy_function( - all_explicit_operations, {"x": EncryptedTensor(Integer(32, True), shape=(10, 10))} - ) - - expected = """ -%0 = Constant([[100 101 ... 198 199]]) # ClearTensor, shape=(10, 10)> -%1 = Constant([[10 11 12 ... 17 18 19]]) # ClearTensor, shape=(1, 10)> -%2 = Constant([[0 1 2 3 4 5 6 7 8 9]]) # ClearTensor, shape=(1, 10)> -%3 = x # EncryptedTensor, shape=(10, 10)> -%4 = Constant([[ 0 1 2 ... 97 98 99]]) # ClearTensor, shape=(10, 10)> -%5 = Add(%3, %4) # EncryptedTensor, shape=(10, 10)> -%6 = Sub(%5, %2) # EncryptedTensor, shape=(10, 10)> -%7 = np.arctan2(%1, %6) # EncryptedTensor, shape=(10, 10)> -%8 = np.arctan2(%0, %7) # EncryptedTensor, shape=(10, 10)> -return(%8) -""".lstrip() # noqa: E501 - - str_of_the_graph = format_operation_graph(op_graph, show_data_types=True) - - assert str_of_the_graph == expected, ( - f"\n==================\nGot \n{str_of_the_graph}" - f"==================\nExpected \n{expected}" - f"==================\n" - ) diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 381fdffaf..c6373a405 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -188,24 +188,22 @@ def test_numpy_tracing_tensors(): all_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))} ) - expected = """ -%0 = Constant([[2 1] [1 2]]) # ClearTensor, shape=(2, 2)> -%1 = Constant([[1 2] [2 1]]) # ClearTensor, shape=(2, 2)> -%2 = Constant([[10 20] [30 40]]) # ClearTensor, shape=(2, 2)> -%3 = Constant([[100 200] [300 400]]) # ClearTensor, shape=(2, 2)> -%4 = Constant([[5 6] [7 8]]) # ClearTensor, shape=(2, 2)> -%5 = x # EncryptedTensor, shape=(2, 2)> -%6 = Constant([[1 2] [3 4]]) # ClearTensor, shape=(2, 2)> -%7 = Add(%5, %6) # EncryptedTensor, shape=(2, 2)> -%8 = Add(%4, %7) # EncryptedTensor, shape=(2, 2)> -%9 = Sub(%3, %8) # EncryptedTensor, shape=(2, 2)> -%10 = Sub(%9, %2) # EncryptedTensor, shape=(2, 2)> -%11 = Mul(%10, %1) # EncryptedTensor, shape=(2, 2)> -%12 = Mul(%0, %11) # EncryptedTensor, shape=(2, 2)> -return(%12) -""".lstrip() # noqa: E501 + expected = """ %0 = [[2 1] [1 2]] # ClearTensor + %1 = [[1 2] [2 1]] # ClearTensor + %2 = [[10 20] [30 40]] # ClearTensor + %3 = [[100 200] [300 400]] # ClearTensor + %4 = [[5 6] [7 8]] # ClearTensor + %5 = x # EncryptedTensor + %6 = [[1 2] [3 4]] # ClearTensor + %7 = add(%5, %6) # EncryptedTensor + %8 = add(%4, %7) # EncryptedTensor + %9 = sub(%3, %8) # EncryptedTensor +%10 = sub(%9, %2) # EncryptedTensor +%11 = mul(%10, %1) # EncryptedTensor +%12 = mul(%0, %11) # EncryptedTensor +return %12""" # noqa: E501 - assert format_operation_graph(op_graph, show_data_types=True) == expected + assert format_operation_graph(op_graph) == expected, format_operation_graph(op_graph) def test_numpy_explicit_tracing_tensors(): @@ -227,24 +225,22 @@ def test_numpy_explicit_tracing_tensors(): all_explicit_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))} ) - expected = """ -%0 = Constant([[2 1] [1 2]]) # ClearTensor, shape=(2, 2)> -%1 = Constant([[1 2] [2 1]]) # ClearTensor, shape=(2, 2)> -%2 = Constant([[10 20] [30 40]]) # ClearTensor, shape=(2, 2)> -%3 = Constant([[100 200] [300 400]]) # ClearTensor, shape=(2, 2)> -%4 = Constant([[5 6] [7 8]]) # ClearTensor, shape=(2, 2)> -%5 = x # EncryptedTensor, shape=(2, 2)> -%6 = Constant([[1 2] [3 4]]) # ClearTensor, shape=(2, 2)> -%7 = Add(%5, %6) # EncryptedTensor, shape=(2, 2)> -%8 = Add(%4, %7) # EncryptedTensor, shape=(2, 2)> -%9 = Sub(%3, %8) # EncryptedTensor, shape=(2, 2)> -%10 = Sub(%9, %2) # EncryptedTensor, shape=(2, 2)> -%11 = Mul(%10, %1) # EncryptedTensor, shape=(2, 2)> -%12 = Mul(%0, %11) # EncryptedTensor, shape=(2, 2)> -return(%12) -""".lstrip() # noqa: E501 + expected = """ %0 = [[2 1] [1 2]] # ClearTensor + %1 = [[1 2] [2 1]] # ClearTensor + %2 = [[10 20] [30 40]] # ClearTensor + %3 = [[100 200] [300 400]] # ClearTensor + %4 = [[5 6] [7 8]] # ClearTensor + %5 = x # EncryptedTensor + %6 = [[1 2] [3 4]] # ClearTensor + %7 = add(%5, %6) # EncryptedTensor + %8 = add(%4, %7) # EncryptedTensor + %9 = sub(%3, %8) # EncryptedTensor +%10 = sub(%9, %2) # EncryptedTensor +%11 = mul(%10, %1) # EncryptedTensor +%12 = mul(%0, %11) # EncryptedTensor +return %12""" # noqa: E501 - assert format_operation_graph(op_graph, show_data_types=True) == expected + assert format_operation_graph(op_graph) == expected @pytest.mark.parametrize(