dev(opgraph): add a class to ease manipulating an operator graph

This commit is contained in:
Arthur Meyre
2021-07-30 15:03:59 +02:00
parent be391ca388
commit 9b52ea94fb
4 changed files with 85 additions and 24 deletions

View File

@@ -1,9 +1,10 @@
"""functions to draw the different graphs we can generate in the package, eg to debug"""
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
import matplotlib.pyplot as plt
import networkx as nx
from hdk.common.operator_graph import OPGraph
from hdk.common.representation import intermediate as ir
IR_NODE_COLOR_MAPPING = {ir.Input: "blue", ir.Add: "red", ir.Sub: "yellow", ir.Mul: "green"}
@@ -80,13 +81,15 @@ def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float
def draw_graph(
graph: nx.DiGraph, block_until_user_closes_graph: bool = True, draw_edge_numbers: bool = True
graph: Union[OPGraph, nx.MultiDiGraph],
block_until_user_closes_graph: bool = True,
draw_edge_numbers: bool = True,
) -> None:
"""
Draw a graph
Args:
graph (nx.DiGraph): The graph that we want to draw
graph (Union[OPGraph, nx.MultiDiGraph]): The graph that we want to draw
block_until_user_closes_graph (bool): if True, will wait the user to
close the figure before continuing; False is useful for the CI tests
draw_edge_numbers (bool): if True, add the edge number on the arrow
@@ -102,6 +105,9 @@ def draw_graph(
# FIXME: less variables
# pylint: disable=too-many-locals
# Allow to pass either OPGraph or an nx graph, manage this here
graph = graph.graph if isinstance(graph, OPGraph) else graph
# Positions of the node
pos = human_readable_layout(graph)
@@ -196,17 +202,19 @@ def draw_graph(
# pylint: enable=too-many-locals
def get_printable_graph(graph: nx.DiGraph) -> str:
"""
Return a string representing a graph
def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str:
"""Return a string representing a graph
Args:
graph (nx.DiGraph): The graph that we want to draw
graph (Union[OPGraph, nx.MultiDiGraph]): The graph that we want to draw
Returns:
a string to print or save in a file
str: a string to print or save in a file
"""
# Allow to pass either OPGraph or an nx graph, manage this here
graph = graph.graph if isinstance(graph, OPGraph) else graph
returned_str = ""
i = 0

View File

@@ -0,0 +1,59 @@
"""Code to wrap and make manipulating networkx graphs easier"""
from typing import Any, Dict, Iterable, Mapping
import networkx as nx
from .representation import intermediate as ir
from .tracing import BaseTracer
from .tracing.tracing_helpers import create_graph_from_output_tracers
class OPGraph:
"""Class to make work with nx graphs easier"""
graph: nx.MultiDiGraph
input_nodes: Mapping[int, ir.Input]
output_nodes: Mapping[int, ir.IntermediateNode]
def __init__(self, output_tracers: Iterable[BaseTracer]) -> None:
self.output_nodes = {
output_idx: tracer.traced_computation
for output_idx, tracer in enumerate(output_tracers)
}
self.graph = create_graph_from_output_tracers(output_tracers)
self.input_nodes = {
node.program_input_idx: node
for node in self.graph.nodes()
if len(self.graph.pred[node]) == 0
}
assert all(map(lambda x: isinstance(x, ir.Input), self.input_nodes.values()))
graph_outputs = set(node for node in self.graph.nodes() if len(self.graph.succ[node]) == 0)
assert set(self.output_nodes.values()) == graph_outputs
def evaluate(self, inputs: Mapping[int, Any]) -> Dict[ir.IntermediateNode, Any]:
"""Function to evaluate a graph and get intermediate values for all nodes
Args:
inputs (Mapping[int, Any]): The inputs to the program
Returns:
Dict[ir.IntermediateNode, Any]: Dictionary with node as keys and resulting values
"""
node_results: Dict[ir.IntermediateNode, Any] = {}
for node in nx.topological_sort(self.graph):
if not isinstance(node, ir.Input):
curr_inputs = {}
for pred_node in self.graph.pred[node]:
edges = self.graph.get_edge_data(pred_node, node)
for edge in edges.values():
curr_inputs[edge["input_idx"]] = node_results[pred_node]
node_results[node] = node.evaluate(curr_inputs)
else:
node_results[node] = node.evaluate({0: inputs[node.program_input_idx]})
return node_results

View File

@@ -1,15 +1,9 @@
"""hnumpy tracing utilities"""
from typing import Callable, Dict
import networkx as nx
from ..common.data_types import BaseValue
from ..common.tracing import (
BaseTracer,
create_graph_from_output_tracers,
make_input_tracers,
prepare_function_parameters,
)
from ..common.operator_graph import OPGraph
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
class NPTracer(BaseTracer):
@@ -18,7 +12,7 @@ class NPTracer(BaseTracer):
def trace_numpy_function(
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
) -> nx.MultiDiGraph:
) -> OPGraph:
"""Function used to trace a numpy function
Args:
@@ -27,8 +21,8 @@ def trace_numpy_function(
function is e.g. an EncryptedValue holding a 7bits unsigned Integer
Returns:
nx.MultiDiGraph: The graph containing the ir nodes representing the computation done in the
input function
OPGraph: The graph containing the ir nodes representing the computation done in the input
function
"""
function_parameters = prepare_function_parameters(function_to_trace, function_parameters)
@@ -40,6 +34,6 @@ def trace_numpy_function(
if isinstance(output_tracers, NPTracer):
output_tracers = (output_tracers,)
graph = create_graph_from_output_tracers(output_tracers)
op_graph = OPGraph(output_tracers)
return graph
return op_graph

View File

@@ -76,7 +76,7 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
else:
assert False, f"unknown operation {operation}"
graph = tracing.trace_numpy_function(function_to_compile, {"x": x, "y": y})
op_graph = tracing.trace_numpy_function(function_to_compile, {"x": x, "y": y})
ref_graph = nx.MultiDiGraph()
@@ -108,4 +108,4 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0)
ref_graph.add_edge(input_y, returned_final_node, input_idx=1)
assert test_helpers.digraphs_are_equivalent(ref_graph, graph)
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)