From 9b52ea94fb4c945cba69a6eea6b6c480b9448920 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Fri, 30 Jul 2021 15:03:59 +0200 Subject: [PATCH] dev(opgraph): add a class to ease manipulating an operator graph --- hdk/common/debugging/draw_graph.py | 26 ++++++++----- hdk/common/operator_graph.py | 59 ++++++++++++++++++++++++++++++ hdk/hnumpy/tracing.py | 20 ++++------ tests/hnumpy/test_tracing.py | 4 +- 4 files changed, 85 insertions(+), 24 deletions(-) create mode 100644 hdk/common/operator_graph.py diff --git a/hdk/common/debugging/draw_graph.py b/hdk/common/debugging/draw_graph.py index b012da781..e91623a04 100644 --- a/hdk/common/debugging/draw_graph.py +++ b/hdk/common/debugging/draw_graph.py @@ -1,9 +1,10 @@ """functions to draw the different graphs we can generate in the package, eg to debug""" -from typing import Any, Dict, List +from typing import Any, Dict, List, Union import matplotlib.pyplot as plt import networkx as nx +from hdk.common.operator_graph import OPGraph from hdk.common.representation import intermediate as ir IR_NODE_COLOR_MAPPING = {ir.Input: "blue", ir.Add: "red", ir.Sub: "yellow", ir.Mul: "green"} @@ -80,13 +81,15 @@ def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float def draw_graph( - graph: nx.DiGraph, block_until_user_closes_graph: bool = True, draw_edge_numbers: bool = True + graph: Union[OPGraph, nx.MultiDiGraph], + block_until_user_closes_graph: bool = True, + draw_edge_numbers: bool = True, ) -> None: """ Draw a graph Args: - graph (nx.DiGraph): The graph that we want to draw + graph (Union[OPGraph, nx.MultiDiGraph]): The graph that we want to draw block_until_user_closes_graph (bool): if True, will wait the user to close the figure before continuing; False is useful for the CI tests draw_edge_numbers (bool): if True, add the edge number on the arrow @@ -102,6 +105,9 @@ def draw_graph( # FIXME: less variables # pylint: disable=too-many-locals + # Allow to pass either OPGraph or an nx graph, manage this here + graph = graph.graph if isinstance(graph, OPGraph) else graph + # Positions of the node pos = human_readable_layout(graph) @@ -196,17 +202,19 @@ def draw_graph( # pylint: enable=too-many-locals -def get_printable_graph(graph: nx.DiGraph) -> str: - """ - Return a string representing a graph +def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str: + """Return a string representing a graph Args: - graph (nx.DiGraph): The graph that we want to draw + graph (Union[OPGraph, nx.MultiDiGraph]): The graph that we want to draw Returns: - a string to print or save in a file - + str: a string to print or save in a file """ + + # Allow to pass either OPGraph or an nx graph, manage this here + graph = graph.graph if isinstance(graph, OPGraph) else graph + returned_str = "" i = 0 diff --git a/hdk/common/operator_graph.py b/hdk/common/operator_graph.py new file mode 100644 index 000000000..ccb7500a4 --- /dev/null +++ b/hdk/common/operator_graph.py @@ -0,0 +1,59 @@ +"""Code to wrap and make manipulating networkx graphs easier""" + +from typing import Any, Dict, Iterable, Mapping + +import networkx as nx + +from .representation import intermediate as ir +from .tracing import BaseTracer +from .tracing.tracing_helpers import create_graph_from_output_tracers + + +class OPGraph: + """Class to make work with nx graphs easier""" + + graph: nx.MultiDiGraph + input_nodes: Mapping[int, ir.Input] + output_nodes: Mapping[int, ir.IntermediateNode] + + def __init__(self, output_tracers: Iterable[BaseTracer]) -> None: + self.output_nodes = { + output_idx: tracer.traced_computation + for output_idx, tracer in enumerate(output_tracers) + } + self.graph = create_graph_from_output_tracers(output_tracers) + self.input_nodes = { + node.program_input_idx: node + for node in self.graph.nodes() + if len(self.graph.pred[node]) == 0 + } + + assert all(map(lambda x: isinstance(x, ir.Input), self.input_nodes.values())) + + graph_outputs = set(node for node in self.graph.nodes() if len(self.graph.succ[node]) == 0) + + assert set(self.output_nodes.values()) == graph_outputs + + def evaluate(self, inputs: Mapping[int, Any]) -> Dict[ir.IntermediateNode, Any]: + """Function to evaluate a graph and get intermediate values for all nodes + + Args: + inputs (Mapping[int, Any]): The inputs to the program + + Returns: + Dict[ir.IntermediateNode, Any]: Dictionary with node as keys and resulting values + """ + node_results: Dict[ir.IntermediateNode, Any] = {} + + for node in nx.topological_sort(self.graph): + if not isinstance(node, ir.Input): + curr_inputs = {} + for pred_node in self.graph.pred[node]: + edges = self.graph.get_edge_data(pred_node, node) + for edge in edges.values(): + curr_inputs[edge["input_idx"]] = node_results[pred_node] + node_results[node] = node.evaluate(curr_inputs) + else: + node_results[node] = node.evaluate({0: inputs[node.program_input_idx]}) + + return node_results diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index 8b10835e1..8072fdc18 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -1,15 +1,9 @@ """hnumpy tracing utilities""" from typing import Callable, Dict -import networkx as nx - from ..common.data_types import BaseValue -from ..common.tracing import ( - BaseTracer, - create_graph_from_output_tracers, - make_input_tracers, - prepare_function_parameters, -) +from ..common.operator_graph import OPGraph +from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters class NPTracer(BaseTracer): @@ -18,7 +12,7 @@ class NPTracer(BaseTracer): def trace_numpy_function( function_to_trace: Callable, function_parameters: Dict[str, BaseValue] -) -> nx.MultiDiGraph: +) -> OPGraph: """Function used to trace a numpy function Args: @@ -27,8 +21,8 @@ def trace_numpy_function( function is e.g. an EncryptedValue holding a 7bits unsigned Integer Returns: - nx.MultiDiGraph: The graph containing the ir nodes representing the computation done in the - input function + OPGraph: The graph containing the ir nodes representing the computation done in the input + function """ function_parameters = prepare_function_parameters(function_to_trace, function_parameters) @@ -40,6 +34,6 @@ def trace_numpy_function( if isinstance(output_tracers, NPTracer): output_tracers = (output_tracers,) - graph = create_graph_from_output_tracers(output_tracers) + op_graph = OPGraph(output_tracers) - return graph + return op_graph diff --git a/tests/hnumpy/test_tracing.py b/tests/hnumpy/test_tracing.py index 5383bb3b3..4dfada2d6 100644 --- a/tests/hnumpy/test_tracing.py +++ b/tests/hnumpy/test_tracing.py @@ -76,7 +76,7 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers): else: assert False, f"unknown operation {operation}" - graph = tracing.trace_numpy_function(function_to_compile, {"x": x, "y": y}) + op_graph = tracing.trace_numpy_function(function_to_compile, {"x": x, "y": y}) ref_graph = nx.MultiDiGraph() @@ -108,4 +108,4 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers): ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0) ref_graph.add_edge(input_y, returned_final_node, input_idx=1) - assert test_helpers.digraphs_are_equivalent(ref_graph, graph) + assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)