Files
concrete/hdk/common/operator_graph.py
Arthur Meyre 0eebbfcd26 dev(opgraph): add facilities to OPGraph
- allow constructing the graph from an existing networkx MultiDiGraph
- add a function to remove nodes from which none of the graph's outputs can be reached
- return the evaluated output when calling the OPGraph
2021-08-17 18:27:31 +02:00

196 lines
7.6 KiB
Python

"""Code to wrap and make manipulating networkx graphs easier."""
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Set, Tuple, Union
import networkx as nx
from .data_types.floats import Float
from .data_types.integers import make_integer_to_hold_ints
from .representation import intermediate as ir
from .tracing import BaseTracer
from .tracing.tracing_helpers import create_graph_from_output_tracers
class OPGraph:
    """Wrapper around a networkx MultiDiGraph of IR nodes with its input and output nodes.

    Each edge of the graph carries an ``input_idx`` attribute: the input slot of the destination
    node that the source node's value is fed into.
    """

    # The wrapped graph; its nodes are ir.IntermediateNode instances.
    graph: nx.MultiDiGraph
    # Input position -> ir.Input node.
    input_nodes: Dict[int, ir.Input]
    # Output position -> node whose value is that output.
    output_nodes: Dict[int, ir.IntermediateNode]

    def __init__(
        self,
        graph: nx.MultiDiGraph,
        input_nodes: Dict[int, ir.Input],
        output_nodes: Dict[int, ir.IntermediateNode],
    ) -> None:
        """Wrap ``graph`` and prune the nodes that do not contribute to any output.

        Note that the pruning mutates ``graph`` in place (it is not copied).

        Args:
            graph (nx.MultiDiGraph): The graph to wrap.
            input_nodes (Dict[int, ir.Input]): Input position -> input node, must not be empty.
            output_nodes (Dict[int, ir.IntermediateNode]): Output position -> output node.
        """
        assert len(input_nodes) > 0, "Got a graph without input nodes which is not supported"
        assert all(
            isinstance(node, ir.Input) for node in input_nodes.values()
        ), "Got input nodes that were not ir.Input, which is not supported"
        assert all(
            isinstance(node, ir.IntermediateNode) for node in output_nodes.values()
        ), "Got output nodes which were not ir.IntermediateNode, which is not supported"

        self.graph = graph
        self.input_nodes = input_nodes
        self.output_nodes = output_nodes
        self.prune_nodes()

    def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]:
        """Evaluate the graph on positional arguments and return its output value(s).

        Args:
            *args: The program inputs; ``args[i]`` feeds the input node whose
                ``program_input_idx`` is ``i``. Exactly one argument per input node is expected.

        Returns:
            Union[Any, Tuple[Any, ...]]: The value of the output when the graph has a single
                output, otherwise a tuple of the output values ordered by output index.
        """
        inputs = dict(enumerate(args))

        assert len(inputs) == len(
            self.input_nodes
        ), f"Expected {len(self.input_nodes)} arguments, got {len(inputs)} : {args}"

        results = self.evaluate(inputs)
        tuple_result = tuple(results[output_node] for output_node in self.get_ordered_outputs())
        return tuple_result if len(tuple_result) > 1 else tuple_result[0]

    @staticmethod
    def from_output_tracers(output_tracers: Iterable[BaseTracer]) -> "OPGraph":
        """Construct OPGraph from output tracers.

        Args:
            output_tracers (Iterable[BaseTracer]): The tracers output by the function that was
                traced. Any iterable is accepted, including one-shot iterators.

        Returns:
            OPGraph: The resulting OPGraph.
        """
        # Materialize once: the tracers are consumed twice below (to build the graph and to
        # index the outputs). A generator would otherwise be exhausted by the first pass,
        # leaving no output nodes and causing prune_nodes to empty the whole graph.
        tracers = tuple(output_tracers)
        graph = create_graph_from_output_tracers(tracers)

        # Inputs are the predecessor-less ir.Input nodes, keyed by their program argument index.
        input_nodes = {
            node.program_input_idx: node
            for node in graph.nodes()
            if len(graph.pred[node]) == 0 and isinstance(node, ir.Input)
        }

        output_nodes = {
            output_idx: tracer.traced_computation for output_idx, tracer in enumerate(tracers)
        }

        return OPGraph(graph, input_nodes, output_nodes)

    @staticmethod
    def from_graph(
        graph: nx.MultiDiGraph,
        input_nodes: Iterable[ir.Input],
        output_nodes: Iterable[ir.IntermediateNode],
    ) -> "OPGraph":
        """Construct OPGraph from an existing networkx MultiDiGraph.

        Input and output nodes are indexed by their iteration order.
        NOTE(review): evaluate() feeds ir.Input nodes by their ``program_input_idx``, so
        presumably callers pass ``input_nodes`` sorted by ``program_input_idx`` — confirm.

        Args:
            graph (nx.MultiDiGraph): The networkx MultiDiGraph to use (pruned in place).
            input_nodes (Iterable[ir.Input]): The input nodes of the MultiDiGraph.
            output_nodes (Iterable[ir.IntermediateNode]): The output nodes of the MultiDiGraph.

        Returns:
            OPGraph: The resulting OPGraph.
        """
        return OPGraph(graph, dict(enumerate(input_nodes)), dict(enumerate(output_nodes)))

    def get_ordered_inputs(self) -> List[ir.Input]:
        """Get the input nodes of the graph, ordered by their index.

        Returns:
            List[ir.Input]: ordered input nodes
        """
        return [self.input_nodes[idx] for idx in range(len(self.input_nodes))]

    def get_ordered_outputs(self) -> List[ir.IntermediateNode]:
        """Get the output nodes of the graph, ordered by their index.

        Returns:
            List[ir.IntermediateNode]: ordered output nodes
        """
        return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]

    def evaluate(self, inputs: Dict[int, Any]) -> Dict[ir.IntermediateNode, Any]:
        """Evaluate every node of the graph, in topological order.

        Args:
            inputs (Dict[int, Any]): The inputs to the program, keyed by the
                ``program_input_idx`` of the ir.Input node they feed.

        Returns:
            Dict[ir.IntermediateNode, Any]: Dictionary with node as keys and resulting values

        Raises:
            KeyError: If an ir.Input node's ``program_input_idx`` is missing from ``inputs``.
        """
        node_results: Dict[ir.IntermediateNode, Any] = {}

        # Topological order guarantees every predecessor is evaluated before its successors.
        for node in nx.topological_sort(self.graph):
            if not isinstance(node, ir.Input):
                curr_inputs = {}
                for pred_node in self.graph.pred[node]:
                    # A MultiDiGraph may hold several edges between the same pair of nodes
                    # (the same value used in several input slots); each edge names its slot.
                    edges = self.graph.get_edge_data(pred_node, node)
                    for edge in edges.values():
                        curr_inputs[edge["input_idx"]] = node_results[pred_node]
                node_results[node] = node.evaluate(curr_inputs)
            else:
                node_results[node] = node.evaluate({0: inputs[node.program_input_idx]})

        return node_results

    def update_values_with_bounds(self, node_bounds: dict):
        """Update values with bounds.

        Update nodes inputs and outputs values with data types able to hold data ranges measured
        and passed in node_bounds. For each node, the output value types are derived from its
        bounds, and a copy of its (single) output value is propagated to the matching input slot
        of each of its successors.

        Args:
            node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max'
                keys. Those bounds will be taken as the data range to be represented, per node.
                Every node of the graph must be present.
        """
        node: ir.IntermediateNode
        for node in self.graph.nodes():
            current_node_bounds = node_bounds[node]
            min_bound, max_bound = (
                current_node_bounds["min"],
                current_node_bounds["max"],
            )

            if not isinstance(node, ir.Input):
                for output_value in node.outputs:
                    if isinstance(min_bound, int) and isinstance(max_bound, int):
                        output_value.data_type = make_integer_to_hold_ints(
                            (min_bound, max_bound), force_signed=False
                        )
                    else:
                        output_value.data_type = Float(64)
            else:
                # Currently variable inputs are only allowed to be integers
                assert isinstance(min_bound, int) and isinstance(max_bound, int), (
                    "Inputs to a graph should be integers, got bounds that were not integers, \n"
                    f"min: {min_bound} ({type(min_bound)}), max: {max_bound} ({type(max_bound)})"
                )
                node.inputs[0].data_type = make_integer_to_hold_ints(
                    (min_bound, max_bound), force_signed=False
                )
                # An input node's output mirrors its input value.
                node.outputs[0] = deepcopy(node.inputs[0])

            # TODO: #57 manage multiple outputs from a node, probably requires an output_idx when
            # adding an edge
            assert len(node.outputs) == 1

            successors = self.graph.succ[node]
            for succ in successors:
                edge_data = self.graph.get_edge_data(node, succ)
                for edge in edge_data.values():
                    input_idx = edge["input_idx"]
                    # Copy so successors never share (and accidentally co-mutate) value objects.
                    succ.inputs[input_idx] = deepcopy(node.outputs[0])

    def prune_nodes(self):
        """Remove from the graph every node that none of the outputs depends on.

        Walks the graph backwards (through predecessors) from the output nodes; nodes never
        reached this way cannot influence any output and are removed from ``self.graph`` in place.
        ``input_nodes`` and ``output_nodes`` are left unchanged.
        """
        current_nodes = set(self.output_nodes.values())
        useful_nodes: Set[ir.IntermediateNode] = set()
        while current_nodes:
            useful_nodes.update(current_nodes)
            next_nodes: Set[ir.IntermediateNode] = set()
            for node in current_nodes:
                next_nodes.update(self.graph.pred[node])
            # Only expand nodes discovered for the first time: shared ancestors are not re-walked
            # on every layer, and the walk terminates even if the graph contained a cycle.
            current_nodes = next_nodes - useful_nodes

        useless_nodes = set(self.graph.nodes()) - useful_nodes
        self.graph.remove_nodes_from(useless_nodes)