diff --git a/hdk/common/operator_graph.py b/hdk/common/operator_graph.py
index 14bd2b539..eecc4ac85 100644
--- a/hdk/common/operator_graph.py
+++ b/hdk/common/operator_graph.py
@@ -1,7 +1,7 @@
 """Code to wrap and make manipulating networkx graphs easier."""
 
 from copy import deepcopy
-from typing import Any, Dict, Iterable, List, Mapping
+from typing import Any, Dict, Iterable, List, Set, Tuple, Union
 
 import networkx as nx
 
@@ -16,20 +16,99 @@ class OPGraph:
     """Class to make work with nx graphs easier."""
 
     graph: nx.MultiDiGraph
-    input_nodes: Mapping[int, ir.Input]
-    output_nodes: Mapping[int, ir.IntermediateNode]
+    input_nodes: Dict[int, ir.Input]
+    output_nodes: Dict[int, ir.IntermediateNode]
 
-    def __init__(self, output_tracers: Iterable[BaseTracer]) -> None:
-        self.output_nodes = {
+    def __init__(
+        self,
+        graph: nx.MultiDiGraph,
+        input_nodes: Dict[int, ir.Input],
+        output_nodes: Dict[int, ir.IntermediateNode],
+    ) -> None:
+        """Wrap an existing graph and drop the nodes no output depends on.
+
+        Args:
+            graph (nx.MultiDiGraph): The graph to wrap; it is pruned in place.
+            input_nodes (Dict[int, ir.Input]): The input nodes, keyed by input index.
+            output_nodes (Dict[int, ir.IntermediateNode]): The output nodes, keyed by output
+                index.
+        """
+        assert len(input_nodes) > 0, "Got a graph without input nodes which is not supported"
+        assert all(
+            isinstance(node, ir.Input) for node in input_nodes.values()
+        ), "Got input nodes that were not ir.Input, which is not supported"
+        assert all(
+            isinstance(node, ir.IntermediateNode) for node in output_nodes.values()
+        ), "Got output nodes which were not ir.IntermediateNode, which is not supported"
+
+        self.graph = graph
+        self.input_nodes = input_nodes
+        self.output_nodes = output_nodes
+        self.prune_nodes()
+
+    def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]:
+        """Evaluate the graph on positional inputs.
+
+        Args:
+            *args: One value per input node; the i-th value is given to evaluate under key i.
+
+        Returns:
+            Union[Any, Tuple[Any, ...]]: The single output value when the graph has one output,
+                otherwise a tuple of the output values in output index order.
+        """
+        inputs = dict(enumerate(args))
+
+        assert len(inputs) == len(
+            self.input_nodes
+        ), f"Expected {len(self.input_nodes)} arguments, got {len(inputs)} : {args}"
+
+        results = self.evaluate(inputs)
+        tuple_result = tuple(results[output_node] for output_node in self.get_ordered_outputs())
+        return tuple_result if len(tuple_result) > 1 else tuple_result[0]
+
+    @staticmethod
+    def from_output_tracers(output_tracers: Iterable[BaseTracer]) -> "OPGraph":
+        """Construct OPGraph from output tracers.
+
+        Args:
+            output_tracers (Iterable[BaseTracer]): The tracers output by the function that was
+                traced.
+
+        Returns:
+            OPGraph: The resulting OPGraph.
+        """
+        # Materialize once: the tracers are walked twice below, and a one-shot iterable
+        # (e.g. a generator) would already be exhausted on the second pass.
+        output_tracers = tuple(output_tracers)
+        graph = create_graph_from_output_tracers(output_tracers)
+        input_nodes = {
+            node.program_input_idx: node
+            for node in graph.nodes()
+            if len(graph.pred[node]) == 0 and isinstance(node, ir.Input)
+        }
+        output_nodes = {
             output_idx: tracer.traced_computation
             for output_idx, tracer in enumerate(output_tracers)
         }
-        self.graph = create_graph_from_output_tracers(output_tracers)
-        self.input_nodes = {
-            node.program_input_idx: node
-            for node in self.graph.nodes()
-            if len(self.graph.pred[node]) == 0 and isinstance(node, ir.Input)
-        }
+        return OPGraph(graph, input_nodes, output_nodes)
+
+    @staticmethod
+    def from_graph(
+        graph: nx.MultiDiGraph,
+        input_nodes: Iterable[ir.Input],
+        output_nodes: Iterable[ir.IntermediateNode],
+    ) -> "OPGraph":
+        """Construct OPGraph from an existing networkx MultiDiGraph.
+
+        Args:
+            graph (nx.MultiDiGraph): The networkx MultiDiGraph to use.
+            input_nodes (Iterable[ir.Input]): The input nodes of the MultiDiGraph.
+            output_nodes (Iterable[ir.IntermediateNode]): The output nodes of the MultiDiGraph.
+
+        Returns:
+            OPGraph: The resulting OPGraph.
+        """
+        return OPGraph(graph, dict(enumerate(input_nodes)), dict(enumerate(output_nodes)))
 
     def get_ordered_inputs(self) -> List[ir.Input]:
         """Get the input nodes of the graph, ordered by their index.
@@ -47,11 +126,11 @@ class OPGraph:
         """
         return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]
 
-    def evaluate(self, inputs: Mapping[int, Any]) -> Dict[ir.IntermediateNode, Any]:
+    def evaluate(self, inputs: Dict[int, Any]) -> Dict[ir.IntermediateNode, Any]:
         """Function to evaluate a graph and get intermediate values for all nodes.
 
         Args:
-            inputs (Mapping[int, Any]): The inputs to the program
+            inputs (Dict[int, Any]): The inputs to the program
 
         Returns:
             Dict[ir.IntermediateNode, Any]: Dictionary with node as keys and resulting values
@@ -119,3 +198,23 @@ class OPGraph:
                 for edge in edge_data.values():
                     input_idx = edge["input_idx"]
                     succ.inputs[input_idx] = deepcopy(node.outputs[0])
+
+    def prune_nodes(self) -> None:
+        """Remove from the graph every node that no output node depends on.
+
+        Walks the predecessors backwards from the output nodes; whatever is never reached is
+        deleted from self.graph. self.input_nodes and self.output_nodes are left untouched.
+        """
+        useful_nodes: Set[ir.IntermediateNode] = set()
+        current_nodes = set(self.output_nodes.values())
+        while current_nodes:
+            useful_nodes.update(current_nodes)
+            next_nodes: Set[ir.IntermediateNode] = set()
+            for node in current_nodes:
+                next_nodes.update(self.graph.pred[node])
+            # Expand only nodes not seen yet: shared ancestors are walked once and the
+            # loop terminates even if the graph were to contain a cycle.
+            current_nodes = next_nodes - useful_nodes
+
+        useless_nodes = set(self.graph.nodes()) - useful_nodes
+        self.graph.remove_nodes_from(useless_nodes)
diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py
index 245bbf641..a06125399 100644
--- a/hdk/hnumpy/tracing.py
+++ b/hdk/hnumpy/tracing.py
@@ -139,6 +139,6 @@ def trace_numpy_function(
     if isinstance(output_tracers, NPTracer):
         output_tracers = (output_tracers,)
 
-    op_graph = OPGraph(output_tracers)
+    op_graph = OPGraph.from_output_tracers(output_tracers)
 
     return op_graph