dev(opgraph): add facilities to OPGraph

- allow constructing an OPGraph from an existing networkx MultiDiGraph
- add a function to remove the nodes that no output of the graph depends on
- return the evaluated output when calling the OPGraph
This commit is contained in:
Arthur Meyre
2021-08-16 12:31:07 +02:00
parent 4976855c1d
commit 0eebbfcd26
2 changed files with 88 additions and 14 deletions

View File

@@ -1,7 +1,7 @@
"""Code to wrap and make manipulating networkx graphs easier."""
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Mapping
from typing import Any, Dict, Iterable, List, Set, Tuple, Union
import networkx as nx
@@ -16,20 +16,79 @@ class OPGraph:
"""Class to make work with nx graphs easier."""
graph: nx.MultiDiGraph
input_nodes: Mapping[int, ir.Input]
output_nodes: Mapping[int, ir.IntermediateNode]
input_nodes: Dict[int, ir.Input]
output_nodes: Dict[int, ir.IntermediateNode]
def __init__(self, output_tracers: Iterable[BaseTracer]) -> None:
self.output_nodes = {
def __init__(
    self,
    graph: nx.MultiDiGraph,
    input_nodes: Dict[int, ir.Input],
    output_nodes: Dict[int, ir.IntermediateNode],
) -> None:
    """Wrap an existing graph together with its indexed input and output nodes.

    The graph is pruned right after the attributes are stored, so nodes that no
    output depends on are removed from the ``graph`` object that was passed in.

    Args:
        graph (nx.MultiDiGraph): The networkx MultiDiGraph to wrap.
        input_nodes (Dict[int, ir.Input]): The input nodes, keyed by index.
        output_nodes (Dict[int, ir.IntermediateNode]): The output nodes, keyed by index.
    """
    # Each check is evaluated immediately before its own assert so that the first
    # failing condition is the one that gets reported.
    has_inputs = len(input_nodes) >= 1
    assert has_inputs, "Got a graph without input nodes which is not supported"

    inputs_well_typed = not any(
        not isinstance(candidate, ir.Input) for candidate in input_nodes.values()
    )
    assert inputs_well_typed, "Got input nodes that were not ir.Input, which is not supported"

    outputs_well_typed = not any(
        not isinstance(candidate, ir.IntermediateNode) for candidate in output_nodes.values()
    )
    assert (
        outputs_well_typed
    ), "Got output nodes which were not ir.IntermediateNode, which is not supported"

    self.graph, self.input_nodes, self.output_nodes = graph, input_nodes, output_nodes
    self.prune_nodes()
def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]:
    """Evaluate the graph on positional arguments and return its output value(s).

    The i-th positional argument is used as the value of the input with index i.

    Args:
        *args: One value per input node of the graph.

    Returns:
        Union[Any, Tuple[Any, ...]]: The bare value when the graph has exactly one
            output, otherwise a tuple of the output values in output-index order
            (an empty tuple when the graph has no output).
    """
    inputs = dict(enumerate(args))

    assert len(inputs) == len(
        self.input_nodes
    ), f"Expected {len(self.input_nodes)} arguments, got {len(inputs)} : {args}"

    results = self.evaluate(inputs)
    tuple_result = tuple(results[output_node] for output_node in self.get_ordered_outputs())

    # Unwrap only a single output; indexing [0] unconditionally would raise
    # IndexError for a graph without outputs.
    return tuple_result[0] if len(tuple_result) == 1 else tuple_result
@staticmethod
def from_output_tracers(output_tracers: Iterable[BaseTracer]) -> "OPGraph":
    """Construct OPGraph from output tracers.

    Args:
        output_tracers (Iterable[BaseTracer]): The tracers output by the function that was
            traced.

    Returns:
        OPGraph: The resulting OPGraph.
    """
    # The tracers are handed to the graph builder and then enumerated again below.
    # Materialize them once so a single-pass iterable (e.g. a generator) is not
    # already exhausted when the output nodes are collected.
    output_tracers = tuple(output_tracers)

    graph = create_graph_from_output_tracers(output_tracers)

    # Inputs are the ir.Input nodes without any predecessor, keyed by the index
    # they carry themselves.
    input_nodes = {
        node.program_input_idx: node
        for node in graph.nodes()
        if len(graph.pred[node]) == 0 and isinstance(node, ir.Input)
    }

    # Outputs are keyed by the position of their tracer in output_tracers.
    output_nodes = {
        output_idx: tracer.traced_computation
        for output_idx, tracer in enumerate(output_tracers)
    }

    return OPGraph(graph, input_nodes, output_nodes)
@staticmethod
def from_graph(
    graph: nx.MultiDiGraph,
    input_nodes: Iterable[ir.Input],
    output_nodes: Iterable[ir.IntermediateNode],
) -> "OPGraph":
    """Build an OPGraph around an already existing networkx MultiDiGraph.

    Input and output nodes are indexed by their position in the given iterables.

    Args:
        graph (nx.MultiDiGraph): The networkx MultiDiGraph to use.
        input_nodes (Iterable[ir.Input]): The input nodes of the MultiDiGraph.
        output_nodes (Iterable[ir.IntermediateNode]): The output nodes of the MultiDiGraph.

    Returns:
        OPGraph: The resulting OPGraph.
    """
    indexed_inputs = {position: node for position, node in enumerate(input_nodes)}
    indexed_outputs = {position: node for position, node in enumerate(output_nodes)}
    return OPGraph(graph, indexed_inputs, indexed_outputs)
def get_ordered_inputs(self) -> List[ir.Input]:
"""Get the input nodes of the graph, ordered by their index.
@@ -47,11 +106,11 @@ class OPGraph:
"""
return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]
def evaluate(self, inputs: Mapping[int, Any]) -> Dict[ir.IntermediateNode, Any]:
def evaluate(self, inputs: Dict[int, Any]) -> Dict[ir.IntermediateNode, Any]:
"""Function to evaluate a graph and get intermediate values for all nodes.
Args:
inputs (Mapping[int, Any]): The inputs to the program
inputs (Dict[int, Any]): The inputs to the program
Returns:
Dict[ir.IntermediateNode, Any]: Dictionary with node as keys and resulting values
@@ -119,3 +178,18 @@ class OPGraph:
for edge in edge_data.values():
input_idx = edge["input_idx"]
succ.inputs[input_idx] = deepcopy(node.outputs[0])
def prune_nodes(self):
    """Remove from the graph every node that no output node depends on.

    Starting from the output nodes, the graph is walked backwards through the
    predecessor relation. Every node met during that walk is kept; all other
    nodes are removed from ``self.graph`` in place.
    """
    frontier = set(self.output_nodes.values())
    useful_nodes: Set[ir.IntermediateNode] = set()

    while frontier:
        useful_nodes.update(frontier)

        predecessors: Set[ir.IntermediateNode] = set()
        for node in frontier:
            predecessors.update(self.graph.pred[node])

        # Expand each node only once: without this, nodes reachable at several
        # depths are re-expanded every time and a cycle would loop forever.
        frontier = predecessors - useful_nodes

    # Materialize the set before removal so the graph is not mutated while its
    # nodes are being iterated.
    useless_nodes = set(self.graph.nodes()) - useful_nodes
    self.graph.remove_nodes_from(useless_nodes)

View File

@@ -139,6 +139,6 @@ def trace_numpy_function(
if isinstance(output_tracers, NPTracer):
output_tracers = (output_tracers,)
op_graph = OPGraph(output_tracers)
op_graph = OPGraph.from_output_tracers(output_tracers)
return op_graph