Files
concrete/hdk/common/operator_graph.py
Benoit Chevallier-Mames 1771bc6e52 fix: fix #80
closes #80
2021-08-05 11:30:45 +02:00

93 lines
3.6 KiB
Python

"""Code to wrap and make manipulating networkx graphs easier"""
from copy import deepcopy
from typing import Any, Dict, Iterable, Mapping
import networkx as nx
from .data_types.integers import make_integer_to_hold_ints
from .representation import intermediate as ir
from .tracing import BaseTracer
from .tracing.tracing_helpers import create_graph_from_output_tracers
class OPGraph:
    """Class to make work with nx graphs easier

    Builds the graph from a collection of output tracers and keeps track of the
    program input nodes and output nodes alongside it.
    """

    # Graph returned by create_graph_from_output_tracers for the output tracers
    graph: nx.MultiDiGraph
    # ir.Input nodes without predecessors, keyed by their program_input_idx
    input_nodes: Mapping[int, ir.Input]
    # traced_computation of each output tracer, keyed by the tracer's position
    output_nodes: Mapping[int, ir.IntermediateNode]

    def __init__(self, output_tracers: Iterable[BaseTracer]) -> None:
        """Build the graph and its input/output node mappings from output tracers

        Args:
            output_tracers (Iterable[BaseTracer]): The tracers of the program outputs, their
                position in the iterable is used as the output index
        """
        # The tracers are walked twice below (output mapping, then graph creation), so
        # materialize them first: a one-shot iterable such as a generator would otherwise
        # already be exhausted when the graph gets built.
        tracers = tuple(output_tracers)

        self.output_nodes = {
            output_idx: tracer.traced_computation for output_idx, tracer in enumerate(tracers)
        }
        self.graph = create_graph_from_output_tracers(tracers)
        self.input_nodes = {
            node.program_input_idx: node
            for node in self.graph.nodes()
            if len(self.graph.pred[node]) == 0 and isinstance(node, ir.Input)
        }

    def evaluate(self, inputs: Mapping[int, Any]) -> Dict[ir.IntermediateNode, Any]:
        """Function to evaluate a graph and get intermediate values for all nodes

        Args:
            inputs (Mapping[int, Any]): The inputs to the program

        Returns:
            Dict[ir.IntermediateNode, Any]: Dictionary with node as keys and resulting values
        """
        node_results: Dict[ir.IntermediateNode, Any] = {}

        # Topological order guarantees every predecessor has a result before it is needed
        for node in nx.topological_sort(self.graph):
            if isinstance(node, ir.Input):
                # Input nodes take their single value from the program inputs
                node_results[node] = node.evaluate({0: inputs[node.program_input_idx]})
                continue

            curr_inputs = {}
            for pred_node in self.graph.pred[node]:
                pred_result = node_results[pred_node]
                # MultiDiGraph: several edges may link the same pair of nodes, each one
                # feeding a different input_idx of the current node
                for edge in self.graph.get_edge_data(pred_node, node).values():
                    curr_inputs[edge["input_idx"]] = pred_result
            node_results[node] = node.evaluate(curr_inputs)

        return node_results

    def update_values_with_bounds(self, node_bounds: dict) -> None:
        """Update nodes inputs and outputs values with data types able to hold data ranges measured
        and passed in nodes_bounds

        Args:
            node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max'
                keys. Those bounds will be taken as the data range to be represented, per node.
        """
        node: ir.IntermediateNode

        for node in self.graph.nodes():
            current_node_bounds = node_bounds[node]
            min_bound, max_bound = current_node_bounds["min"], current_node_bounds["max"]

            if not isinstance(node, ir.Input):
                for output_value in node.outputs:
                    output_value.data_type = make_integer_to_hold_ints(
                        (min_bound, max_bound), force_signed=False
                    )
            else:
                # For an input node the bounds apply to its input value, and the output is
                # made an independent copy of that updated input
                node.inputs[0].data_type = make_integer_to_hold_ints(
                    (min_bound, max_bound), force_signed=False
                )
                node.outputs[0] = deepcopy(node.inputs[0])

            # TODO: #57 manage multiple outputs from a node, probably requires an output_idx when
            # adding an edge
            assert len(node.outputs) == 1

            # Propagate the updated output value to every successor input fed by this node
            successors = self.graph.succ[node]
            for succ in successors:
                edge_data = self.graph.get_edge_data(node, succ)
                for edge in edge_data.values():
                    input_idx = edge["input_idx"]
                    # Each successor input gets its own copy of the value
                    succ.inputs[input_idx] = deepcopy(node.outputs[0])