mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
dev(opgraph): add facilities to OPGraph
- allow to construct graph from an existing networkx MultiDiGraph - add a function to remove nodes unreachable from the outputs of the graph - return the evaluated output when calling the OPGraph
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Code to wrap and make manipulating networkx graphs easier."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, List, Mapping
|
||||
from typing import Any, Dict, Iterable, List, Set, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
|
||||
@@ -16,20 +16,79 @@ class OPGraph:
|
||||
"""Class to make work with nx graphs easier."""
|
||||
|
||||
graph: nx.MultiDiGraph
|
||||
input_nodes: Mapping[int, ir.Input]
|
||||
output_nodes: Mapping[int, ir.IntermediateNode]
|
||||
input_nodes: Dict[int, ir.Input]
|
||||
output_nodes: Dict[int, ir.IntermediateNode]
|
||||
|
||||
def __init__(self, output_tracers: Iterable[BaseTracer]) -> None:
|
||||
self.output_nodes = {
|
||||
def __init__(
|
||||
self,
|
||||
graph: nx.MultiDiGraph,
|
||||
input_nodes: Dict[int, ir.Input],
|
||||
output_nodes: Dict[int, ir.IntermediateNode],
|
||||
) -> None:
|
||||
assert len(input_nodes) > 0, "Got a graph without input nodes which is not supported"
|
||||
assert all(
|
||||
isinstance(node, ir.Input) for node in input_nodes.values()
|
||||
), "Got input nodes that were not ir.Input, which is not supported"
|
||||
assert all(
|
||||
isinstance(node, ir.IntermediateNode) for node in output_nodes.values()
|
||||
), "Got output nodes which were not ir.IntermediateNode, which is not supported"
|
||||
|
||||
self.graph = graph
|
||||
self.input_nodes = input_nodes
|
||||
self.output_nodes = output_nodes
|
||||
self.prune_nodes()
|
||||
|
||||
def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]:
|
||||
inputs = dict(enumerate(args))
|
||||
|
||||
assert len(inputs) == len(
|
||||
self.input_nodes
|
||||
), f"Expected {len(self.input_nodes)} arguments, got {len(inputs)} : {args}"
|
||||
|
||||
results = self.evaluate(inputs)
|
||||
tuple_result = tuple(results[output_node] for output_node in self.get_ordered_outputs())
|
||||
return tuple_result if len(tuple_result) > 1 else tuple_result[0]
|
||||
|
||||
@staticmethod
|
||||
def from_output_tracers(output_tracers: Iterable[BaseTracer]) -> "OPGraph":
|
||||
"""Construct OPGraph from output tracers.
|
||||
|
||||
Args:
|
||||
output_tracers (Iterable[BaseTracer]): The tracers output by the function that was
|
||||
traced.
|
||||
|
||||
Returns:
|
||||
OPGraph: The resulting OPGraph.
|
||||
"""
|
||||
graph = create_graph_from_output_tracers(output_tracers)
|
||||
input_nodes = {
|
||||
node.program_input_idx: node
|
||||
for node in graph.nodes()
|
||||
if len(graph.pred[node]) == 0 and isinstance(node, ir.Input)
|
||||
}
|
||||
output_nodes = {
|
||||
output_idx: tracer.traced_computation
|
||||
for output_idx, tracer in enumerate(output_tracers)
|
||||
}
|
||||
self.graph = create_graph_from_output_tracers(output_tracers)
|
||||
self.input_nodes = {
|
||||
node.program_input_idx: node
|
||||
for node in self.graph.nodes()
|
||||
if len(self.graph.pred[node]) == 0 and isinstance(node, ir.Input)
|
||||
}
|
||||
return OPGraph(graph, input_nodes, output_nodes)
|
||||
|
||||
@staticmethod
|
||||
def from_graph(
|
||||
graph: nx.MultiDiGraph,
|
||||
input_nodes: Iterable[ir.Input],
|
||||
output_nodes: Iterable[ir.IntermediateNode],
|
||||
) -> "OPGraph":
|
||||
"""Construct OPGraph from an existing networkx MultiDiGraph.
|
||||
|
||||
Args:
|
||||
graph (nx.MultiDiGraph): The networkx MultiDiGraph to use.
|
||||
input_nodes (Iterable[ir.Input]): The input nodes of the MultiDiGraph.
|
||||
output_nodes (Iterable[ir.IntermediateNode]): The output nodes of the MultiDiGraph.
|
||||
|
||||
Returns:
|
||||
OPGraph: The resulting OPGraph.
|
||||
"""
|
||||
return OPGraph(graph, dict(enumerate(input_nodes)), dict(enumerate(output_nodes)))
|
||||
|
||||
def get_ordered_inputs(self) -> List[ir.Input]:
|
||||
"""Get the input nodes of the graph, ordered by their index.
|
||||
@@ -47,11 +106,11 @@ class OPGraph:
|
||||
"""
|
||||
return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Dict[ir.IntermediateNode, Any]:
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Dict[ir.IntermediateNode, Any]:
|
||||
"""Function to evaluate a graph and get intermediate values for all nodes.
|
||||
|
||||
Args:
|
||||
inputs (Mapping[int, Any]): The inputs to the program
|
||||
inputs (Dict[int, Any]): The inputs to the program
|
||||
|
||||
Returns:
|
||||
Dict[ir.IntermediateNode, Any]: Dictionary with node as keys and resulting values
|
||||
@@ -119,3 +178,18 @@ class OPGraph:
|
||||
for edge in edge_data.values():
|
||||
input_idx = edge["input_idx"]
|
||||
succ.inputs[input_idx] = deepcopy(node.outputs[0])
|
||||
|
||||
def prune_nodes(self):
|
||||
"""Function to remove unreachable nodes from outputs."""
|
||||
|
||||
current_nodes = set(self.output_nodes.values())
|
||||
useful_nodes: Set[ir.IntermediateNode] = set()
|
||||
while current_nodes:
|
||||
next_nodes: Set[ir.IntermediateNode] = set()
|
||||
useful_nodes.update(current_nodes)
|
||||
for node in current_nodes:
|
||||
next_nodes.update(self.graph.pred[node])
|
||||
current_nodes = next_nodes
|
||||
|
||||
useless_nodes = set(self.graph.nodes()) - useful_nodes
|
||||
self.graph.remove_nodes_from(useless_nodes)
|
||||
|
||||
@@ -139,6 +139,6 @@ def trace_numpy_function(
|
||||
if isinstance(output_tracers, NPTracer):
|
||||
output_tracers = (output_tracers,)
|
||||
|
||||
op_graph = OPGraph(output_tracers)
|
||||
op_graph = OPGraph.from_output_tracers(output_tracers)
|
||||
|
||||
return op_graph
|
||||
|
||||
Reference in New Issue
Block a user