feat: adding a function to print a graph

refs #38
This commit is contained in:
Benoit Chevallier-Mames
2021-07-29 17:59:03 +02:00
committed by Benoit Chevallier
parent 1196b00c6b
commit c91ee858c5
3 changed files with 81 additions and 37 deletions

View File

@@ -1,2 +1,2 @@
"""HDK's module for debugging"""
from .draw_graph import draw_graph
from .draw_graph import draw_graph, get_printable_graph

View File

@@ -1,5 +1,5 @@
"""functions to draw the different graphs we can generate in the package, eg to debug"""
from typing import Dict, List
from typing import Any, Dict, List
import matplotlib.pyplot as plt
import networkx as nx
@@ -194,3 +194,52 @@ def draw_graph(
plt.show(block=block_until_user_closes_graph)
# pylint: enable=too-many-locals
def get_printable_graph(graph: nx.DiGraph) -> str:
"""
Return a string representing a graph
Args:
graph (nx.DiGraph): The graph that we want to draw
Returns:
a string to print or save in a file
"""
returned_str = ""
i = 0
map_table: Dict[Any, int] = {}
for node in nx.topological_sort(graph):
if not isinstance(node, ir.Input):
what_to_print = node.__class__.__name__ + "("
# Find all the names of the current predecessors of the node
list_of_arg_name = []
for pred, index_list in graph.pred[node].items():
for index in index_list.values():
# Remark that we keep the index of the predecessor and its
# name, to print sources in the right order, which is
# important for eg non commutative operations
list_of_arg_name += [(index["input_idx"], str(map_table[pred]))]
# Some checks, because the previous algorithm is not clear
assert len(list_of_arg_name) == len({x[0] for x in list_of_arg_name})
assert [x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name)))
# Then, just print the predecessors in the right order
list_of_arg_name.sort()
what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")"
else:
what_to_print = node.input_name
returned_str += f"\n%{i} = {what_to_print}"
map_table[node] = i
i += 1
return returned_str

View File

@@ -4,55 +4,50 @@ import pytest
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import ClearValue, EncryptedValue
from hdk.common.debugging import draw_graph
from hdk.common.debugging import draw_graph, get_printable_graph
from hdk.hnumpy import tracing
@pytest.mark.parametrize(
"lambda_f",
"lambda_f,ref_graph_str",
[
lambda x, y: x + y,
lambda x, y: x + x - y * y * y + x,
(lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)"),
(lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)"),
(
lambda x, y: x + x - y * y * y + x,
"\n%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)"
"\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)",
),
],
)
@pytest.mark.parametrize(
"x",
"x_y",
[
pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"),
# pytest.param(
# EncryptedValue(Integer(64, is_signed=True)),
# id="Encrypted int",
# ),
# pytest.param(
# ClearValue(Integer(64, is_signed=False)),
# id="Clear uint",
# ),
# pytest.param(
# ClearValue(Integer(64, is_signed=True)),
# id="Clear int",
# ),
],
)
@pytest.mark.parametrize(
"y",
[
pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"),
# pytest.param(
# EncryptedValue(Integer(64, is_signed=True)),
# id="Encrypted int",
# ),
pytest.param(
ClearValue(Integer(64, is_signed=False)),
(
EncryptedValue(Integer(64, is_signed=False)),
EncryptedValue(Integer(64, is_signed=False)),
),
id="Encrypted uint",
),
pytest.param(
(
EncryptedValue(Integer(64, is_signed=False)),
ClearValue(Integer(64, is_signed=False)),
),
id="Clear uint",
),
# pytest.param(
# ClearValue(Integer(64, is_signed=True)),
# id="Clear int",
# ),
],
)
def test_hnumpy_draw_graph(lambda_f, x, y):
"Test hnumpy draw_graph"
def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
"Test hnumpy get_printable_graph and draw_graph"
x, y = x_y
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
draw_graph(graph, block_until_user_closes_graph=False)
str_of_the_graph = get_printable_graph(graph)
print(f"\n{str_of_the_graph}\n")
assert str_of_the_graph == ref_graph_str