From 46a018fd2104140a904cf2e4b3b2479c027098ec Mon Sep 17 00:00:00 2001 From: Umut Date: Tue, 16 Nov 2021 11:30:54 +0300 Subject: [PATCH] refactor: use a consistent short name for operation graph --- concrete/common/debugging/drawing.py | 8 +++---- concrete/common/debugging/formatting.py | 12 +++++----- concrete/common/fhe_circuit.py | 10 ++++---- concrete/common/mlir/graph_converter.py | 24 +++++++++---------- concrete/common/mlir/node_converter.py | 8 +++---- .../common/representation/intermediate.py | 4 ++-- concrete/numpy/np_mlir_converter.py | 6 ++--- tests/common/debugging/test_formatting.py | 14 +++++------ tests/common/test_fhe_circuit.py | 6 ++--- tests/numpy/test_compile_constant_indexing.py | 12 +++++----- 10 files changed, 52 insertions(+), 52 deletions(-) diff --git a/concrete/common/debugging/drawing.py b/concrete/common/debugging/drawing.py index b7d46c20e..8737ba19d 100644 --- a/concrete/common/debugging/drawing.py +++ b/concrete/common/debugging/drawing.py @@ -52,7 +52,7 @@ del _missing_nodes_in_mapping def draw_graph( - opgraph: OPGraph, + op_graph: OPGraph, show: bool = False, vertical: bool = True, save_to: Optional[Path] = None, @@ -60,7 +60,7 @@ def draw_graph( """Draws operation graphs and optionally saves/shows the drawing. Args: - opgraph (OPGraph): the graph to be drawn and optionally saved/shown + op_graph (OPGraph): the operation graph to be drawn and optionally saved/shown show (bool): if set to True, the drawing will be shown using matplotlib vertical (bool): if set to True, the orientation will be vertical save_to (Optional[Path]): if specified, the drawn graph will be saved to this path; else @@ -79,8 +79,8 @@ def draw_graph( value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return) return value_to_return - graph = opgraph.graph - output_nodes = set(opgraph.output_nodes.values()) + graph = op_graph.graph + output_nodes = set(op_graph.output_nodes.values()) attributes = { node: { diff --git a/concrete/common/debugging/formatting.py b/concrete/common/debugging/formatting.py index 6e1ddcc92..d0f649e17 100644 --- a/concrete/common/debugging/formatting.py +++ b/concrete/common/debugging/formatting.py @@ -10,14 +10,14 @@ from ..representation.intermediate import IntermediateNode def format_operation_graph( - opgraph: OPGraph, + op_graph: OPGraph, maximum_constant_length: int = 25, highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None, ) -> str: """Format an operation graph. Args: - opgraph (OPGraph): + op_graph (OPGraph): the operation graph to format maximum_constant_length (int): @@ -29,7 +29,7 @@ def format_operation_graph( Returns: str: formatted operation graph """ - assert_true(isinstance(opgraph, OPGraph)) + assert_true(isinstance(op_graph, OPGraph)) # (node, output_index) -> identifier # e.g., id_map[(node1, 0)] = 2 and id_map[(node1, 1)] = 3 @@ -49,7 +49,7 @@ def format_operation_graph( # after their type information is added and we only have line numbers, not nodes highlighted_lines: Dict[int, List[str]] = {} - for node in nx.topological_sort(opgraph.graph): + for node in nx.topological_sort(op_graph.graph): # assign a unique id to outputs of node assert_true(len(node.outputs) > 0) for i in range(len(node.outputs)): @@ -61,7 +61,7 @@ def format_operation_graph( # extract predecessors and their ids predecessors = [] - for predecessor, output_idx in opgraph.get_ordered_inputs_of(node): + for predecessor, output_idx in op_graph.get_ordered_inputs_of(node): predecessors.append(f"%{id_map[(predecessor, output_idx)]}") # start the build the line for the node @@ -116,7 +116,7 @@ def format_operation_graph( # (if there is a single return, it's in the form `return %id` # (otherwise, it's in the form `return (%id1, %id2, ..., %idN)` returns: List[str] = [] - for node in opgraph.output_nodes.values(): + for node in op_graph.output_nodes.values(): outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs))) returns.append(outputs if len(node.outputs) == 1 else f"({outputs})") lines.append("return " + returns[0] if len(returns) == 1 else f"({', '.join(returns)})") diff --git a/concrete/common/fhe_circuit.py b/concrete/common/fhe_circuit.py index b96be2f24..6296d65be 100644 --- a/concrete/common/fhe_circuit.py +++ b/concrete/common/fhe_circuit.py @@ -13,15 +13,15 @@ from .operator_graph import OPGraph class FHECircuit: """Class which is the result of compilation.""" - opgraph: OPGraph + op_graph: OPGraph engine: CompilerEngine - def __init__(self, opgraph: OPGraph, engine: CompilerEngine): - self.opgraph = opgraph + def __init__(self, op_graph: OPGraph, engine: CompilerEngine): + self.op_graph = op_graph self.engine = engine def __str__(self): - return format_operation_graph(self.opgraph) + return format_operation_graph(self.op_graph) def draw( self, @@ -42,7 +42,7 @@ class FHECircuit: """ - return draw_graph(self.opgraph, show, vertical, save_to) + return draw_graph(self.op_graph, show, vertical, save_to) def run(self, *args: List[Union[int, numpy.ndarray]]) -> Union[int, numpy.ndarray]: """Encrypt, evaluate, and decrypt the inputs on the circuit. diff --git a/concrete/common/mlir/graph_converter.py b/concrete/common/mlir/graph_converter.py index 2cf51b23c..9b714999a 100644 --- a/concrete/common/mlir/graph_converter.py +++ b/concrete/common/mlir/graph_converter.py @@ -23,17 +23,17 @@ from .node_converter import IntermediateNodeConverter class OPGraphConverter(ABC): """Converter of OPGraph to MLIR.""" - def convert(self, opgraph: OPGraph) -> str: + def convert(self, op_graph: OPGraph) -> str: """Convert an operation graph to its corresponding MLIR representation. Args: - opgraph (OPGraph): the operation graph to be converted + op_graph (OPGraph): the operation graph to be converted Returns: - str: textual MLIR representation corresponding to opgraph + str: textual MLIR representation corresponding to given operation graph """ - additional_conversion_info = self._generate_additional_info_dict(opgraph) + additional_conversion_info = self._generate_additional_info_dict(op_graph) with Context() as ctx, Location.unknown(): zamalang.register_dialects(ctx) @@ -42,25 +42,25 @@ class OPGraphConverter(ABC): with InsertionPoint(module.body): parameters = [ value_to_mlir_type(ctx, input_node.outputs[0]) - for input_node in opgraph.get_ordered_inputs() + for input_node in op_graph.get_ordered_inputs() ] @builtin.FuncOp.from_py_func(*parameters) def main(*arg): ir_to_mlir = {} - for arg_num, node in opgraph.input_nodes.items(): + for arg_num, node in op_graph.input_nodes.items(): ir_to_mlir[node] = arg[arg_num] - for node in nx.topological_sort(opgraph.graph): + for node in nx.topological_sort(op_graph.graph): if isinstance(node, Input): continue - preds = [ir_to_mlir[pred] for pred in opgraph.get_ordered_preds(node)] - node_converter = IntermediateNodeConverter(ctx, opgraph, node, preds) + preds = [ir_to_mlir[pred] for pred in op_graph.get_ordered_preds(node)] + node_converter = IntermediateNodeConverter(ctx, op_graph, node, preds) ir_to_mlir[node] = node_converter.convert(additional_conversion_info) results = ( - ir_to_mlir[output_node] for output_node in opgraph.get_ordered_outputs() + ir_to_mlir[output_node] for output_node in op_graph.get_ordered_outputs() ) return results @@ -68,11 +68,11 @@ class OPGraphConverter(ABC): @staticmethod @abstractmethod - def _generate_additional_info_dict(opgraph: OPGraph) -> Dict[str, Any]: + def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]: """Generate additional conversion info dict for the MLIR converter. Args: - opgraph (OPGraph): the operation graph from which the additional info will be generated + op_graph (OPGraph): the operation graph from which the additional info will be generated Returns: Dict[str, Any]: dict of additional conversion info diff --git a/concrete/common/mlir/node_converter.py b/concrete/common/mlir/node_converter.py index ff1cc1cd5..3b10fb23b 100644 --- a/concrete/common/mlir/node_converter.py +++ b/concrete/common/mlir/node_converter.py @@ -41,7 +41,7 @@ class IntermediateNodeConverter: """Converter of IntermediateNode to MLIR.""" ctx: Context - opgraph: OPGraph + op_graph: OPGraph node: IntermediateNode preds: List[OpResult] @@ -50,10 +50,10 @@ class IntermediateNodeConverter: one_of_the_inputs_is_a_tensor: bool def __init__( - self, ctx: Context, opgraph: OPGraph, node: IntermediateNode, preds: List[OpResult] + self, ctx: Context, op_graph: OPGraph, node: IntermediateNode, preds: List[OpResult] ): self.ctx = ctx - self.opgraph = opgraph + self.op_graph = op_graph self.node = node self.preds = preds @@ -219,7 +219,7 @@ class IntermediateNodeConverter: variable_input_indices = [ idx - for idx, inp in enumerate(self.opgraph.get_ordered_preds(self.node)) + for idx, inp in enumerate(self.op_graph.get_ordered_preds(self.node)) if not isinstance(inp, Constant) ] if len(variable_input_indices) != 1: # pragma: no cover diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 8c49ee8de..ab5446f0a 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -66,7 +66,7 @@ class IntermediateNode(ABC): self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])] def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str: - """Get the formatted node (used in formatting opgraph). + """Get the formatted node (used in formatting operation graphs). Args: predecessors (List[str]): predecessor names to this node @@ -80,7 +80,7 @@ class IntermediateNode(ABC): @abstractmethod def text_for_drawing(self) -> str: - """Get the label of the node (used in drawing opgraph). + """Get the label of the node (used in drawing operation graphs). Returns: str: the label of the node diff --git a/concrete/numpy/np_mlir_converter.py b/concrete/numpy/np_mlir_converter.py index da85bf836..f937e96cb 100644 --- a/concrete/numpy/np_mlir_converter.py +++ b/concrete/numpy/np_mlir_converter.py @@ -71,14 +71,14 @@ class NPMLIRConverter(OPGraphConverter): """Numpy-specific MLIR converter.""" @staticmethod - def _generate_additional_info_dict(opgraph: OPGraph) -> Dict[str, Any]: + def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]: additional_conversion_info = {} # Disable numpy warnings during conversion to avoid issues during TLU generation with numpy.errstate(all="ignore"): additional_conversion_info["tables"] = { - node: generate_deduplicated_tables(node, opgraph.get_ordered_preds(node)) - for node in opgraph.graph.nodes() + node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node)) + for node in op_graph.graph.nodes() if isinstance(node, GenericFunction) } diff --git a/tests/common/debugging/test_formatting.py b/tests/common/debugging/test_formatting.py index 340478017..9ef7bf105 100644 --- a/tests/common/debugging/test_formatting.py +++ b/tests/common/debugging/test_formatting.py @@ -12,14 +12,14 @@ def test_format_operation_graph_with_multiple_edges(default_compilation_configur def function(x): return x + x - opgraph = compile_numpy_function_into_op_graph( + op_graph = compile_numpy_function_into_op_graph( function, {"x": EncryptedScalar(Integer(4, True))}, [(i,) for i in range(0, 10)], default_compilation_configuration, ) - formatted_graph = format_operation_graph(opgraph) + formatted_graph = format_operation_graph(op_graph) assert ( formatted_graph == """ @@ -38,15 +38,15 @@ def test_format_operation_graph_with_offending_nodes(default_compilation_configu def function(x): return x + 42 - opgraph = compile_numpy_function_into_op_graph( + op_graph = compile_numpy_function_into_op_graph( function, {"x": EncryptedScalar(Integer(7, True))}, [(i,) for i in range(-5, 5)], default_compilation_configuration, ) - highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]} - formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip() + highlighted_nodes = {op_graph.input_nodes[0]: ["foo"]} + formatted_graph = format_operation_graph(op_graph, highlighted_nodes=highlighted_nodes).strip() assert ( formatted_graph == """ @@ -60,8 +60,8 @@ return %2 """.strip() ) - highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]} - formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip() + highlighted_nodes = {op_graph.input_nodes[0]: ["foo"], op_graph.output_nodes[0]: ["bar", "baz"]} + formatted_graph = format_operation_graph(op_graph, highlighted_nodes=highlighted_nodes).strip() assert ( formatted_graph == """ diff --git a/tests/common/test_fhe_circuit.py b/tests/common/test_fhe_circuit.py index 348667885..8355b0fea 100644 --- a/tests/common/test_fhe_circuit.py +++ b/tests/common/test_fhe_circuit.py @@ -17,7 +17,7 @@ def test_circuit_str(default_compilation_configuration): inputset = [(i,) for i in range(2 ** 3)] circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration) - assert str(circuit) == format_operation_graph(circuit.opgraph) + assert str(circuit) == format_operation_graph(circuit.op_graph) def test_circuit_draw(default_compilation_configuration): @@ -31,8 +31,8 @@ def test_circuit_draw(default_compilation_configuration): inputset = [(i,) for i in range(2 ** 3)] circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration) - assert filecmp.cmp(circuit.draw(), draw_graph(circuit.opgraph)) - assert filecmp.cmp(circuit.draw(vertical=False), draw_graph(circuit.opgraph, vertical=False)) + assert filecmp.cmp(circuit.draw(), draw_graph(circuit.op_graph)) + assert filecmp.cmp(circuit.draw(vertical=False), draw_graph(circuit.op_graph, vertical=False)) def test_circuit_run(default_compilation_configuration): diff --git a/tests/numpy/test_compile_constant_indexing.py b/tests/numpy/test_compile_constant_indexing.py index fed935ada..197a0b1f0 100644 --- a/tests/numpy/test_compile_constant_indexing.py +++ b/tests/numpy/test_compile_constant_indexing.py @@ -347,15 +347,15 @@ def test_constant_indexing( for _ in range(10) ] - opgraph = compile_numpy_function_into_op_graph( + op_graph = compile_numpy_function_into_op_graph( function_with_indexing, {"x": input_value}, inputset, default_compilation_configuration, ) - assert len(opgraph.output_nodes) == 1 - output_node = opgraph.output_nodes[0] + assert len(op_graph.output_nodes) == 1 + output_node = op_graph.output_nodes[0] assert len(output_node.outputs) == 1 assert output_value == output_node.outputs[0] @@ -525,15 +525,15 @@ def test_constant_indexing_with_numpy_integers( for _ in range(10) ] - opgraph = compile_numpy_function_into_op_graph( + op_graph = compile_numpy_function_into_op_graph( function_with_indexing, {"x": input_value}, inputset, default_compilation_configuration, ) - assert len(opgraph.output_nodes) == 1 - output_node = opgraph.output_nodes[0] + assert len(op_graph.output_nodes) == 1 + output_node = op_graph.output_nodes[0] assert len(output_node.outputs) == 1 assert output_value == output_node.outputs[0]