From 2badcecd0d206516114a9a2e479b7a760d87ee14 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Tue, 31 Aug 2021 16:45:05 +0200 Subject: [PATCH] feat: append \n instead of prepending it in get_printable_graph closes #222 --- hdk/common/compilation/artifacts.py | 4 +- hdk/common/debugging/printing.py | 4 +- tests/hnumpy/test_compile.py | 10 +-- tests/hnumpy/test_debugging.py | 106 ++++++++++++++-------------- 4 files changed, 61 insertions(+), 63 deletions(-) diff --git a/hdk/common/compilation/artifacts.py b/hdk/common/compilation/artifacts.py index eb6a43bfb..28b16d6fa 100644 --- a/hdk/common/compilation/artifacts.py +++ b/hdk/common/compilation/artifacts.py @@ -85,9 +85,7 @@ class CompilationArtifacts: """ drawing = draw_graph(operation_graph) - textual_representation = get_printable_graph(operation_graph, show_data_types=True)[1:] - - # TODO: remove [1:] above after https://github.com/zama-ai/hdk/issues/222 is fixed + textual_representation = get_printable_graph(operation_graph, show_data_types=True) self.drawings_of_operation_graphs[name] = drawing self.textual_representations_of_operation_graphs[name] = textual_representation diff --git a/hdk/common/debugging/printing.py b/hdk/common/debugging/printing.py index 1cea9033e..1c721472a 100644 --- a/hdk/common/debugging/printing.py +++ b/hdk/common/debugging/printing.py @@ -80,12 +80,12 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: if show_data_types: new_line = f"{new_line: <40s} # {output_data_type_to_string(node)}" - returned_str += f"\n{new_line}" + returned_str += f"{new_line}\n" map_table[node] = i i += 1 return_part = ", ".join(["%" + str(map_table[n]) for n in list_of_nodes_which_are_outputs]) - returned_str += f"\nreturn({return_part})" + returned_str += f"return({return_part})\n" return returned_str diff --git a/tests/hnumpy/test_compile.py b/tests/hnumpy/test_compile.py index 06608bfda..e1d98ed30 100644 --- a/tests/hnumpy/test_compile.py +++ b/tests/hnumpy/test_compile.py @@ -213,10 +213,10 @@ def test_fail_compile(function, input_ranges, list_of_arg_names): (4,), # Remark that, when you do the dot of tensors of 4 values between 0 and 3, # you can get a maximal value of 4*3*3 = 36, ie something on 6 bits - "\n%0 = x # Integer" + "%0 = x # Integer" "\n%1 = y # Integer" "\n%2 = Dot(0, 1) # Integer" - "\nreturn(%2)", + "\nreturn(%2)\n", ), # pylint: enable=unnecessary-lambda ], @@ -243,9 +243,9 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str): ) str_of_the_graph = get_printable_graph(op_graph, show_data_types=True) assert str_of_the_graph == ref_graph_str, ( - f"\n==================\nGot {str_of_the_graph}" - f"\n==================\nExpected {ref_graph_str}" - f"\n==================\n" + f"\n==================\nGot \n{str_of_the_graph}" + f"==================\nExpected \n{ref_graph_str}" + f"==================\n" ) diff --git a/tests/hnumpy/test_debugging.py b/tests/hnumpy/test_debugging.py index 5ef0aaed0..e69bd63fb 100644 --- a/tests/hnumpy/test_debugging.py +++ b/tests/hnumpy/test_debugging.py @@ -40,26 +40,26 @@ def issue_130_c(x, y): @pytest.mark.parametrize( "lambda_f,ref_graph_str", [ - (lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)\nreturn(%2)"), - (lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)\nreturn(%2)"), - (lambda x, y: x + x, "\n%0 = x\n%1 = Add(0, 0)\nreturn(%1)"), + (lambda x, y: x + y, "%0 = x\n%1 = y\n%2 = Add(0, 1)\nreturn(%2)\n"), + (lambda x, y: x - y, "%0 = x\n%1 = y\n%2 = Sub(0, 1)\nreturn(%2)\n"), + (lambda x, y: x + x, "%0 = x\n%1 = Add(0, 0)\nreturn(%1)\n"), ( lambda x, y: x + x - y * y * y + x, - "\n%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)" - "\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)\nreturn(%6)", + "%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)" + "\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)\nreturn(%6)\n", ), - (lambda x, y: x + 1, "\n%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%2)"), - (lambda x, y: 1 + x, "\n%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%2)"), - (lambda x, y: (-1) + x, "\n%0 = x\n%1 = Constant(-1)\n%2 = Add(0, 1)\nreturn(%2)"), - (lambda x, y: 3 * x, "\n%0 = x\n%1 = Constant(3)\n%2 = Mul(0, 1)\nreturn(%2)"), - (lambda x, y: x * 3, "\n%0 = x\n%1 = Constant(3)\n%2 = Mul(0, 1)\nreturn(%2)"), - (lambda x, y: x * (-3), "\n%0 = x\n%1 = Constant(-3)\n%2 = Mul(0, 1)\nreturn(%2)"), - (lambda x, y: x - 11, "\n%0 = x\n%1 = Constant(11)\n%2 = Sub(0, 1)\nreturn(%2)"), - (lambda x, y: 11 - x, "\n%0 = Constant(11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)"), - (lambda x, y: (-11) - x, "\n%0 = Constant(-11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)"), + (lambda x, y: x + 1, "%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%2)\n"), + (lambda x, y: 1 + x, "%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%2)\n"), + (lambda x, y: (-1) + x, "%0 = x\n%1 = Constant(-1)\n%2 = Add(0, 1)\nreturn(%2)\n"), + (lambda x, y: 3 * x, "%0 = x\n%1 = Constant(3)\n%2 = Mul(0, 1)\nreturn(%2)\n"), + (lambda x, y: x * 3, "%0 = x\n%1 = Constant(3)\n%2 = Mul(0, 1)\nreturn(%2)\n"), + (lambda x, y: x * (-3), "%0 = x\n%1 = Constant(-3)\n%2 = Mul(0, 1)\nreturn(%2)\n"), + (lambda x, y: x - 11, "%0 = x\n%1 = Constant(11)\n%2 = Sub(0, 1)\nreturn(%2)\n"), + (lambda x, y: 11 - x, "%0 = Constant(11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)\n"), + (lambda x, y: (-11) - x, "%0 = Constant(-11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)\n"), ( lambda x, y: x + 13 - y * (-21) * y + 44, - "\n%0 = Constant(44)" + "%0 = Constant(44)" "\n%1 = x" "\n%2 = Constant(13)" "\n%3 = y" @@ -69,48 +69,48 @@ def issue_130_c(x, y): "\n%7 = Mul(6, 3)" "\n%8 = Sub(5, 7)" "\n%9 = Add(8, 0)" - "\nreturn(%9)", + "\nreturn(%9)\n", ), # Multiple outputs ( lambda x, y: (x + 1, x + y + 2), - "\n%0 = x" + "%0 = x" "\n%1 = Constant(1)" "\n%2 = Constant(2)" "\n%3 = y" "\n%4 = Add(0, 1)" "\n%5 = Add(0, 3)" "\n%6 = Add(5, 2)" - "\nreturn(%4, %6)", + "\nreturn(%4, %6)\n", ), ( lambda x, y: (y, x), - "\n%0 = y\n%1 = x\nreturn(%0, %1)", + "%0 = y\n%1 = x\nreturn(%0, %1)\n", ), ( lambda x, y: (x, x + 1), - "\n%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%0, %2)", + "%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%0, %2)\n", ), ( lambda x, y: (x + 1, x + 1), - "\n%0 = x" + "%0 = x" "\n%1 = Constant(1)" "\n%2 = Constant(1)" "\n%3 = Add(0, 1)" "\n%4 = Add(0, 2)" - "\nreturn(%3, %4)", + "\nreturn(%3, %4)\n", ), ( issue_130_a, - "\n%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%2, %2)", + "%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%2, %2)\n", ), ( issue_130_b, - "\n%0 = x\n%1 = Constant(1)\n%2 = Sub(0, 1)\nreturn(%2, %2)", + "%0 = x\n%1 = Constant(1)\n%2 = Sub(0, 1)\nreturn(%2, %2)\n", ), ( issue_130_c, - "\n%0 = Constant(1)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2, %2)", + "%0 = Constant(1)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2, %2)\n", ), ], ) @@ -143,9 +143,9 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y): str_of_the_graph = get_printable_graph(graph) assert str_of_the_graph == ref_graph_str, ( - f"\n==================\nGot {str_of_the_graph}" - f"\n==================\nExpected {ref_graph_str}" - f"\n==================\n" + f"\n==================\nGot \n{str_of_the_graph}" + f"==================\nExpected \n{ref_graph_str}" + f"==================\n" ) @@ -155,12 +155,12 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y): ( lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x], {"x": EncryptedValue(Integer(2, is_signed=False))}, - "\n%0 = x\n%1 = TLU(0)\nreturn(%1)", + "%0 = x\n%1 = TLU(0)\nreturn(%1)\n", ), ( lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4], {"x": EncryptedValue(Integer(2, is_signed=False))}, - "\n%0 = x\n%1 = Constant(4)\n%2 = Add(0, 1)\n%3 = TLU(2)\nreturn(%3)", + "%0 = x\n%1 = Constant(4)\n%2 = Add(0, 1)\n%3 = TLU(2)\nreturn(%3)\n", ), ], ) @@ -173,9 +173,9 @@ def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph str_of_the_graph = get_printable_graph(graph) assert str_of_the_graph == ref_graph_str, ( - f"\n==================\nGot {str_of_the_graph}" - f"\n==================\nExpected {ref_graph_str}" - f"\n==================\n" + f"\n==================\nGot \n{str_of_the_graph}" + f"==================\nExpected \n{ref_graph_str}" + f"==================\n" ) @@ -189,7 +189,7 @@ def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)), "y": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)), }, - "\n%0 = x\n%1 = y\n%2 = Dot(0, 1)\nreturn(%2)", + "%0 = x\n%1 = y\n%2 = Dot(0, 1)\nreturn(%2)\n", ), # pylint: enable=unnecessary-lambda ], @@ -203,9 +203,9 @@ def test_hnumpy_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str): str_of_the_graph = get_printable_graph(graph) assert str_of_the_graph == ref_graph_str, ( - f"\n==================\nGot {str_of_the_graph}" - f"\n==================\nExpected {ref_graph_str}" - f"\n==================\n" + f"\n==================\nGot \n{str_of_the_graph}" + f"==================\nExpected \n{ref_graph_str}" + f"==================\n" ) @@ -221,10 +221,10 @@ def test_hnumpy_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str): EncryptedValue(Integer(64, is_signed=False)), EncryptedValue(Integer(32, is_signed=True)), ), - "\n%0 = x # Integer" + "%0 = x # Integer" "\n%1 = y # Integer" "\n%2 = Add(0, 1) # Integer" - "\nreturn(%2)", + "\nreturn(%2)\n", ), ( lambda x, y: x * y, @@ -232,10 +232,10 @@ def test_hnumpy_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str): EncryptedValue(Integer(17, is_signed=False)), EncryptedValue(Integer(23, is_signed=False)), ), - "\n%0 = x # Integer" + "%0 = x # Integer" "\n%1 = y # Integer" "\n%2 = Mul(0, 1) # Integer" - "\nreturn(%2)", + "\nreturn(%2)\n", ), ], ) @@ -247,9 +247,9 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str): str_of_the_graph = get_printable_graph(graph, show_data_types=True) assert str_of_the_graph == ref_graph_str, ( - f"\n==================\nGot {str_of_the_graph}" - f"\n==================\nExpected {ref_graph_str}" - f"\n==================\n" + f"\n==================\nGot \n{str_of_the_graph}" + f"==================\nExpected \n{ref_graph_str}" + f"==================\n" ) @@ -259,28 +259,28 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str): ( lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x], {"x": EncryptedValue(Integer(2, is_signed=False))}, - "\n%0 = x # Integer" + "%0 = x # Integer" "\n%1 = TLU(0) # Integer" - "\nreturn(%1)", + "\nreturn(%1)\n", ), ( lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4], {"x": EncryptedValue(Integer(2, is_signed=False))}, - "\n%0 = x # Integer" + "%0 = x # Integer" "\n%1 = Constant(4) # Integer" "\n%2 = Add(0, 1) # Integer" "\n%3 = TLU(2) # Integer" - "\nreturn(%3)", + "\nreturn(%3)\n", ), ( lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[LOOKUP_TABLE_FROM_3B_TO_2B[x + 4]], {"x": EncryptedValue(Integer(2, is_signed=False))}, - "\n%0 = x # Integer" + "%0 = x # Integer" "\n%1 = Constant(4) # Integer" "\n%2 = Add(0, 1) # Integer" "\n%3 = TLU(2) # Integer" "\n%4 = TLU(3) # Integer" - "\nreturn(%4)", + "\nreturn(%4)\n", ), ], ) @@ -293,7 +293,7 @@ def test_hnumpy_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref str_of_the_graph = get_printable_graph(graph, show_data_types=True) assert str_of_the_graph == ref_graph_str, ( - f"\n==================\nGot {str_of_the_graph}" - f"\n==================\nExpected {ref_graph_str}" - f"\n==================\n" + f"\n==================\nGot \n{str_of_the_graph}" + f"==================\nExpected \n{ref_graph_str}" + f"==================\n" )