mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
committed by
Benoit Chevallier
parent
078c8dc8f1
commit
1ee8195af0
@@ -1,5 +1,5 @@
|
||||
"""functions to draw the different graphs we can generate in the package, eg to debug"""
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
@@ -13,6 +13,7 @@ IR_NODE_COLOR_MAPPING = {
|
||||
ir.Add: "red",
|
||||
ir.Sub: "yellow",
|
||||
ir.Mul: "green",
|
||||
"output": "magenta",
|
||||
}
|
||||
|
||||
|
||||
@@ -87,7 +88,7 @@ def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float
|
||||
|
||||
|
||||
def draw_graph(
|
||||
graph: Union[OPGraph, nx.MultiDiGraph],
|
||||
opgraph: OPGraph,
|
||||
block_until_user_closes_graph: bool = True,
|
||||
draw_edge_numbers: bool = True,
|
||||
) -> None:
|
||||
@@ -95,7 +96,7 @@ def draw_graph(
|
||||
Draw a graph
|
||||
|
||||
Args:
|
||||
graph (Union[OPGraph, nx.MultiDiGraph]): The graph that we want to draw
|
||||
graph (OPGraph): The graph that we want to draw
|
||||
block_until_user_closes_graph (bool): if True, will wait the user to
|
||||
close the figure before continuing; False is useful for the CI tests
|
||||
draw_edge_numbers (bool): if True, add the edge number on the arrow
|
||||
@@ -111,14 +112,20 @@ def draw_graph(
|
||||
# FIXME: less variables
|
||||
# pylint: disable=too-many-locals
|
||||
|
||||
# Allow to pass either OPGraph or an nx graph, manage this here
|
||||
graph = graph.graph if isinstance(graph, OPGraph) else graph
|
||||
assert isinstance(opgraph, OPGraph)
|
||||
set_of_nodes_which_are_outputs = set(opgraph.output_nodes.values())
|
||||
graph = opgraph.graph
|
||||
|
||||
# Positions of the node
|
||||
pos = human_readable_layout(graph)
|
||||
|
||||
# Colors and labels
|
||||
color_map = [IR_NODE_COLOR_MAPPING[type(node)] for node in graph.nodes()]
|
||||
def get_color(node):
|
||||
if node in set_of_nodes_which_are_outputs:
|
||||
return IR_NODE_COLOR_MAPPING["output"]
|
||||
return IR_NODE_COLOR_MAPPING[type(node)]
|
||||
|
||||
color_map = [get_color(node) for node in graph.nodes()]
|
||||
|
||||
# For most types, we just pick the operation as the label, but for Input,
|
||||
# we take the name of the variable, ie the argument name of the function
|
||||
@@ -212,18 +219,19 @@ def draw_graph(
|
||||
# pylint: enable=too-many-locals
|
||||
|
||||
|
||||
def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str:
|
||||
def get_printable_graph(opgraph: OPGraph) -> str:
|
||||
"""Return a string representing a graph
|
||||
|
||||
Args:
|
||||
graph (Union[OPGraph, nx.MultiDiGraph]): The graph that we want to draw
|
||||
graph (OPGraph): The graph that we want to draw
|
||||
|
||||
Returns:
|
||||
str: a string to print or save in a file
|
||||
"""
|
||||
|
||||
# Allow to pass either OPGraph or an nx graph, manage this here
|
||||
graph = graph.graph if isinstance(graph, OPGraph) else graph
|
||||
assert isinstance(opgraph, OPGraph)
|
||||
list_of_nodes_which_are_outputs = list(opgraph.output_nodes.values())
|
||||
graph = opgraph.graph
|
||||
|
||||
returned_str = ""
|
||||
|
||||
@@ -261,4 +269,7 @@ def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str:
|
||||
map_table[node] = i
|
||||
i += 1
|
||||
|
||||
return_part = ", ".join(["%" + str(map_table[n]) for n in list_of_nodes_which_are_outputs])
|
||||
returned_str += f"\nreturn({return_part})"
|
||||
|
||||
return returned_str
|
||||
|
||||
@@ -11,23 +11,23 @@ from hdk.hnumpy import tracing
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,ref_graph_str",
|
||||
[
|
||||
(lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)"),
|
||||
(lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)"),
|
||||
(lambda x, y: x + x, "\n%0 = x\n%1 = Add(0, 0)"),
|
||||
(lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)\nreturn(%2)"),
|
||||
(lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)\nreturn(%2)"),
|
||||
(lambda x, y: x + x, "\n%0 = x\n%1 = Add(0, 0)\nreturn(%1)"),
|
||||
(
|
||||
lambda x, y: x + x - y * y * y + x,
|
||||
"\n%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)"
|
||||
"\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)",
|
||||
"\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)\nreturn(%6)",
|
||||
),
|
||||
(lambda x, y: x + 1, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)"),
|
||||
(lambda x, y: 1 + x, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)"),
|
||||
(lambda x, y: (-1) + x, "\n%0 = x\n%1 = ConstantInput(-1)\n%2 = Add(0, 1)"),
|
||||
(lambda x, y: 3 * x, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)"),
|
||||
(lambda x, y: x * 3, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)"),
|
||||
(lambda x, y: x * (-3), "\n%0 = x\n%1 = ConstantInput(-3)\n%2 = Mul(0, 1)"),
|
||||
(lambda x, y: x - 11, "\n%0 = x\n%1 = ConstantInput(11)\n%2 = Sub(0, 1)"),
|
||||
(lambda x, y: 11 - x, "\n%0 = ConstantInput(11)\n%1 = x\n%2 = Sub(0, 1)"),
|
||||
(lambda x, y: (-11) - x, "\n%0 = ConstantInput(-11)\n%1 = x\n%2 = Sub(0, 1)"),
|
||||
(lambda x, y: x + 1, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)\nreturn(%2)"),
|
||||
(lambda x, y: 1 + x, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)\nreturn(%2)"),
|
||||
(lambda x, y: (-1) + x, "\n%0 = x\n%1 = ConstantInput(-1)\n%2 = Add(0, 1)\nreturn(%2)"),
|
||||
(lambda x, y: 3 * x, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)\nreturn(%2)"),
|
||||
(lambda x, y: x * 3, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)\nreturn(%2)"),
|
||||
(lambda x, y: x * (-3), "\n%0 = x\n%1 = ConstantInput(-3)\n%2 = Mul(0, 1)\nreturn(%2)"),
|
||||
(lambda x, y: x - 11, "\n%0 = x\n%1 = ConstantInput(11)\n%2 = Sub(0, 1)\nreturn(%2)"),
|
||||
(lambda x, y: 11 - x, "\n%0 = ConstantInput(11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)"),
|
||||
(lambda x, y: (-11) - x, "\n%0 = ConstantInput(-11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)"),
|
||||
(
|
||||
lambda x, y: x + 13 - y * (-21) * y + 44,
|
||||
"\n%0 = ConstantInput(44)"
|
||||
@@ -39,7 +39,24 @@ from hdk.hnumpy import tracing
|
||||
"\n%6 = Mul(3, 4)"
|
||||
"\n%7 = Mul(6, 3)"
|
||||
"\n%8 = Sub(5, 7)"
|
||||
"\n%9 = Add(8, 0)",
|
||||
"\n%9 = Add(8, 0)"
|
||||
"\nreturn(%9)",
|
||||
),
|
||||
# Multiple outputs
|
||||
(
|
||||
lambda x, y: (x + 1, x + y + 2),
|
||||
"\n%0 = x"
|
||||
"\n%1 = ConstantInput(1)"
|
||||
"\n%2 = ConstantInput(2)"
|
||||
"\n%3 = y"
|
||||
"\n%4 = Add(0, 1)"
|
||||
"\n%5 = Add(0, 3)"
|
||||
"\n%6 = Add(5, 2)"
|
||||
"\nreturn(%4, %6)",
|
||||
),
|
||||
(
|
||||
lambda x, y: (y, x),
|
||||
"\n%0 = y\n%1 = x\nreturn(%0, %1)",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user