mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
dev(opgraph): add a class to ease manipulating an operator graph
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
"""functions to draw the different graphs we can generate in the package, eg to debug"""
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
|
||||
from hdk.common.operator_graph import OPGraph
|
||||
from hdk.common.representation import intermediate as ir
|
||||
|
||||
IR_NODE_COLOR_MAPPING = {ir.Input: "blue", ir.Add: "red", ir.Sub: "yellow", ir.Mul: "green"}
|
||||
@@ -80,13 +81,15 @@ def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float
|
||||
|
||||
|
||||
def draw_graph(
|
||||
graph: nx.DiGraph, block_until_user_closes_graph: bool = True, draw_edge_numbers: bool = True
|
||||
graph: Union[OPGraph, nx.MultiDiGraph],
|
||||
block_until_user_closes_graph: bool = True,
|
||||
draw_edge_numbers: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Draw a graph
|
||||
|
||||
Args:
|
||||
graph (nx.DiGraph): The graph that we want to draw
|
||||
graph (Union[OPGraph, nx.MultiDiGraph]): The graph that we want to draw
|
||||
block_until_user_closes_graph (bool): if True, will wait the user to
|
||||
close the figure before continuing; False is useful for the CI tests
|
||||
draw_edge_numbers (bool): if True, add the edge number on the arrow
|
||||
@@ -102,6 +105,9 @@ def draw_graph(
|
||||
# FIXME: less variables
|
||||
# pylint: disable=too-many-locals
|
||||
|
||||
# Allow to pass either OPGraph or an nx graph, manage this here
|
||||
graph = graph.graph if isinstance(graph, OPGraph) else graph
|
||||
|
||||
# Positions of the node
|
||||
pos = human_readable_layout(graph)
|
||||
|
||||
@@ -196,17 +202,19 @@ def draw_graph(
|
||||
# pylint: enable=too-many-locals
|
||||
|
||||
|
||||
def get_printable_graph(graph: nx.DiGraph) -> str:
|
||||
"""
|
||||
Return a string representing a graph
|
||||
def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str:
|
||||
"""Return a string representing a graph
|
||||
|
||||
Args:
|
||||
graph (nx.DiGraph): The graph that we want to draw
|
||||
graph (Union[OPGraph, nx.MultiDiGraph]): The graph that we want to draw
|
||||
|
||||
Returns:
|
||||
a string to print or save in a file
|
||||
|
||||
str: a string to print or save in a file
|
||||
"""
|
||||
|
||||
# Allow to pass either OPGraph or an nx graph, manage this here
|
||||
graph = graph.graph if isinstance(graph, OPGraph) else graph
|
||||
|
||||
returned_str = ""
|
||||
|
||||
i = 0
|
||||
|
||||
59
hdk/common/operator_graph.py
Normal file
59
hdk/common/operator_graph.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Code to wrap and make manipulating networkx graphs easier"""
|
||||
|
||||
from typing import Any, Dict, Iterable, Mapping
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from .representation import intermediate as ir
|
||||
from .tracing import BaseTracer
|
||||
from .tracing.tracing_helpers import create_graph_from_output_tracers
|
||||
|
||||
|
||||
class OPGraph:
|
||||
"""Class to make work with nx graphs easier"""
|
||||
|
||||
graph: nx.MultiDiGraph
|
||||
input_nodes: Mapping[int, ir.Input]
|
||||
output_nodes: Mapping[int, ir.IntermediateNode]
|
||||
|
||||
def __init__(self, output_tracers: Iterable[BaseTracer]) -> None:
|
||||
self.output_nodes = {
|
||||
output_idx: tracer.traced_computation
|
||||
for output_idx, tracer in enumerate(output_tracers)
|
||||
}
|
||||
self.graph = create_graph_from_output_tracers(output_tracers)
|
||||
self.input_nodes = {
|
||||
node.program_input_idx: node
|
||||
for node in self.graph.nodes()
|
||||
if len(self.graph.pred[node]) == 0
|
||||
}
|
||||
|
||||
assert all(map(lambda x: isinstance(x, ir.Input), self.input_nodes.values()))
|
||||
|
||||
graph_outputs = set(node for node in self.graph.nodes() if len(self.graph.succ[node]) == 0)
|
||||
|
||||
assert set(self.output_nodes.values()) == graph_outputs
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Dict[ir.IntermediateNode, Any]:
|
||||
"""Function to evaluate a graph and get intermediate values for all nodes
|
||||
|
||||
Args:
|
||||
inputs (Mapping[int, Any]): The inputs to the program
|
||||
|
||||
Returns:
|
||||
Dict[ir.IntermediateNode, Any]: Dictionary with node as keys and resulting values
|
||||
"""
|
||||
node_results: Dict[ir.IntermediateNode, Any] = {}
|
||||
|
||||
for node in nx.topological_sort(self.graph):
|
||||
if not isinstance(node, ir.Input):
|
||||
curr_inputs = {}
|
||||
for pred_node in self.graph.pred[node]:
|
||||
edges = self.graph.get_edge_data(pred_node, node)
|
||||
for edge in edges.values():
|
||||
curr_inputs[edge["input_idx"]] = node_results[pred_node]
|
||||
node_results[node] = node.evaluate(curr_inputs)
|
||||
else:
|
||||
node_results[node] = node.evaluate({0: inputs[node.program_input_idx]})
|
||||
|
||||
return node_results
|
||||
@@ -1,15 +1,9 @@
|
||||
"""hnumpy tracing utilities"""
|
||||
from typing import Callable, Dict
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..common.data_types import BaseValue
|
||||
from ..common.tracing import (
|
||||
BaseTracer,
|
||||
create_graph_from_output_tracers,
|
||||
make_input_tracers,
|
||||
prepare_function_parameters,
|
||||
)
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
|
||||
|
||||
class NPTracer(BaseTracer):
|
||||
@@ -18,7 +12,7 @@ class NPTracer(BaseTracer):
|
||||
|
||||
def trace_numpy_function(
|
||||
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
|
||||
) -> nx.MultiDiGraph:
|
||||
) -> OPGraph:
|
||||
"""Function used to trace a numpy function
|
||||
|
||||
Args:
|
||||
@@ -27,8 +21,8 @@ def trace_numpy_function(
|
||||
function is e.g. an EncryptedValue holding a 7bits unsigned Integer
|
||||
|
||||
Returns:
|
||||
nx.MultiDiGraph: The graph containing the ir nodes representing the computation done in the
|
||||
input function
|
||||
OPGraph: The graph containing the ir nodes representing the computation done in the input
|
||||
function
|
||||
"""
|
||||
function_parameters = prepare_function_parameters(function_to_trace, function_parameters)
|
||||
|
||||
@@ -40,6 +34,6 @@ def trace_numpy_function(
|
||||
if isinstance(output_tracers, NPTracer):
|
||||
output_tracers = (output_tracers,)
|
||||
|
||||
graph = create_graph_from_output_tracers(output_tracers)
|
||||
op_graph = OPGraph(output_tracers)
|
||||
|
||||
return graph
|
||||
return op_graph
|
||||
|
||||
@@ -76,7 +76,7 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
|
||||
else:
|
||||
assert False, f"unknown operation {operation}"
|
||||
|
||||
graph = tracing.trace_numpy_function(function_to_compile, {"x": x, "y": y})
|
||||
op_graph = tracing.trace_numpy_function(function_to_compile, {"x": x, "y": y})
|
||||
|
||||
ref_graph = nx.MultiDiGraph()
|
||||
|
||||
@@ -108,4 +108,4 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
|
||||
ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0)
|
||||
ref_graph.add_edge(input_y, returned_final_node, input_idx=1)
|
||||
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, graph)
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
||||
|
||||
Reference in New Issue
Block a user