From c91ee858c50af42fa79358ad490f62f3b525c440 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Thu, 29 Jul 2021 17:59:03 +0200 Subject: [PATCH] feat: adding a function to print a graph refs #38 --- hdk/common/debugging/__init__.py | 2 +- hdk/common/debugging/draw_graph.py | 51 ++++++++++++++++++++++- tests/hnumpy/test_debugging.py | 65 ++++++++++++++---------------- 3 files changed, 81 insertions(+), 37 deletions(-) diff --git a/hdk/common/debugging/__init__.py b/hdk/common/debugging/__init__.py index b18d821bf..5be00bb60 100644 --- a/hdk/common/debugging/__init__.py +++ b/hdk/common/debugging/__init__.py @@ -1,2 +1,2 @@ """HDK's module for debugging""" -from .draw_graph import draw_graph +from .draw_graph import draw_graph, get_printable_graph diff --git a/hdk/common/debugging/draw_graph.py b/hdk/common/debugging/draw_graph.py index 46474d0b7..b012da781 100644 --- a/hdk/common/debugging/draw_graph.py +++ b/hdk/common/debugging/draw_graph.py @@ -1,5 +1,5 @@ """functions to draw the different graphs we can generate in the package, eg to debug""" -from typing import Dict, List +from typing import Any, Dict, List import matplotlib.pyplot as plt import networkx as nx @@ -194,3 +194,52 @@ def draw_graph( plt.show(block=block_until_user_closes_graph) # pylint: enable=too-many-locals + + +def get_printable_graph(graph: nx.DiGraph) -> str: + """ + Return a string representing a graph + + Args: + graph (nx.DiGraph): The graph that we want to draw + + Returns: + a string to print or save in a file + + """ + returned_str = "" + + i = 0 + map_table: Dict[Any, int] = {} + + for node in nx.topological_sort(graph): + + if not isinstance(node, ir.Input): + what_to_print = node.__class__.__name__ + "(" + + # Find all the names of the current predecessors of the node + list_of_arg_name = [] + + for pred, index_list in graph.pred[node].items(): + for index in index_list.values(): + # Remark that we keep the index of the predecessor and its + # name, to print sources in the right order, which is + # important for eg non commutative operations + list_of_arg_name += [(index["input_idx"], str(map_table[pred]))] + + # Some checks, because the previous algorithm is not clear + assert len(list_of_arg_name) == len({x[0] for x in list_of_arg_name}) + assert [x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name))) + + # Then, just print the predecessors in the right order + list_of_arg_name.sort() + what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")" + + else: + what_to_print = node.input_name + + returned_str += f"\n%{i} = {what_to_print}" + map_table[node] = i + i += 1 + + return returned_str diff --git a/tests/hnumpy/test_debugging.py b/tests/hnumpy/test_debugging.py index 2ade31562..1fe0d547b 100644 --- a/tests/hnumpy/test_debugging.py +++ b/tests/hnumpy/test_debugging.py @@ -4,55 +4,50 @@ import pytest from hdk.common.data_types.integers import Integer from hdk.common.data_types.values import ClearValue, EncryptedValue -from hdk.common.debugging import draw_graph +from hdk.common.debugging import draw_graph, get_printable_graph from hdk.hnumpy import tracing @pytest.mark.parametrize( - "lambda_f", + "lambda_f,ref_graph_str", [ - lambda x, y: x + y, - lambda x, y: x + x - y * y * y + x, + (lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)"), + (lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)"), + ( + lambda x, y: x + x - y * y * y + x, + "\n%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)" + "\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)", + ), ], ) @pytest.mark.parametrize( - "x", + "x_y", [ - pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"), - # pytest.param( - # EncryptedValue(Integer(64, is_signed=True)), - # id="Encrypted int", - # ), - # pytest.param( - # ClearValue(Integer(64, is_signed=False)), - # id="Clear uint", - # ), - # pytest.param( - # ClearValue(Integer(64, is_signed=True)), - # id="Clear int", - # ), - ], -) -@pytest.mark.parametrize( - "y", - [ - pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"), - # pytest.param( - # EncryptedValue(Integer(64, is_signed=True)), - # id="Encrypted int", - # ), pytest.param( - ClearValue(Integer(64, is_signed=False)), + ( + EncryptedValue(Integer(64, is_signed=False)), + EncryptedValue(Integer(64, is_signed=False)), + ), + id="Encrypted uint", + ), + pytest.param( + ( + EncryptedValue(Integer(64, is_signed=False)), + ClearValue(Integer(64, is_signed=False)), + ), id="Clear uint", ), - # pytest.param( - # ClearValue(Integer(64, is_signed=True)), - # id="Clear int", - # ), ], ) -def test_hnumpy_draw_graph(lambda_f, x, y): - "Test hnumpy draw_graph" +def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y): + "Test hnumpy get_printable_graph and draw_graph" + x, y = x_y graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y}) draw_graph(graph, block_until_user_closes_graph=False) + + str_of_the_graph = get_printable_graph(graph) + + print(f"\n{str_of_the_graph}\n") + + assert str_of_the_graph == ref_graph_str