feat: append \n instead of prepending it in get_printable_graph

closes #222
This commit is contained in:
Benoit Chevallier-Mames
2021-08-31 16:45:05 +02:00
committed by Benoit Chevallier
parent 3b3714893b
commit 2badcecd0d
4 changed files with 61 additions and 63 deletions

View File

@@ -85,9 +85,7 @@ class CompilationArtifacts:
"""
drawing = draw_graph(operation_graph)
textual_representation = get_printable_graph(operation_graph, show_data_types=True)[1:]
# TODO: remove [1:] above after https://github.com/zama-ai/hdk/issues/222 is fixed
textual_representation = get_printable_graph(operation_graph, show_data_types=True)
self.drawings_of_operation_graphs[name] = drawing
self.textual_representations_of_operation_graphs[name] = textual_representation

View File

@@ -80,12 +80,12 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
if show_data_types:
new_line = f"{new_line: <40s} # {output_data_type_to_string(node)}"
returned_str += f"\n{new_line}"
returned_str += f"{new_line}\n"
map_table[node] = i
i += 1
return_part = ", ".join(["%" + str(map_table[n]) for n in list_of_nodes_which_are_outputs])
returned_str += f"\nreturn({return_part})"
returned_str += f"return({return_part})\n"
return returned_str

View File

@@ -213,10 +213,10 @@ def test_fail_compile(function, input_ranges, list_of_arg_names):
(4,),
# Remark that, when you do the dot of tensors of 4 values between 0 and 3,
# you can get a maximal value of 4*3*3 = 36, ie something on 6 bits
"\n%0 = x # Integer<unsigned, 2 bits>"
"%0 = x # Integer<unsigned, 2 bits>"
"\n%1 = y # Integer<unsigned, 2 bits>"
"\n%2 = Dot(0, 1) # Integer<unsigned, 6 bits>"
"\nreturn(%2)",
"\nreturn(%2)\n",
),
# pylint: enable=unnecessary-lambda
],
@@ -243,9 +243,9 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str):
)
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot {str_of_the_graph}"
f"\n==================\nExpected {ref_graph_str}"
f"\n==================\n"
f"\n==================\nGot \n{str_of_the_graph}"
f"==================\nExpected \n{ref_graph_str}"
f"==================\n"
)

View File

@@ -40,26 +40,26 @@ def issue_130_c(x, y):
@pytest.mark.parametrize(
"lambda_f,ref_graph_str",
[
(lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)\nreturn(%2)"),
(lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)\nreturn(%2)"),
(lambda x, y: x + x, "\n%0 = x\n%1 = Add(0, 0)\nreturn(%1)"),
(lambda x, y: x + y, "%0 = x\n%1 = y\n%2 = Add(0, 1)\nreturn(%2)\n"),
(lambda x, y: x - y, "%0 = x\n%1 = y\n%2 = Sub(0, 1)\nreturn(%2)\n"),
(lambda x, y: x + x, "%0 = x\n%1 = Add(0, 0)\nreturn(%1)\n"),
(
lambda x, y: x + x - y * y * y + x,
"\n%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)"
"\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)\nreturn(%6)",
"%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)"
"\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)\nreturn(%6)\n",
),
(lambda x, y: x + 1, "\n%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%2)"),
(lambda x, y: 1 + x, "\n%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%2)"),
(lambda x, y: (-1) + x, "\n%0 = x\n%1 = Constant(-1)\n%2 = Add(0, 1)\nreturn(%2)"),
(lambda x, y: 3 * x, "\n%0 = x\n%1 = Constant(3)\n%2 = Mul(0, 1)\nreturn(%2)"),
(lambda x, y: x * 3, "\n%0 = x\n%1 = Constant(3)\n%2 = Mul(0, 1)\nreturn(%2)"),
(lambda x, y: x * (-3), "\n%0 = x\n%1 = Constant(-3)\n%2 = Mul(0, 1)\nreturn(%2)"),
(lambda x, y: x - 11, "\n%0 = x\n%1 = Constant(11)\n%2 = Sub(0, 1)\nreturn(%2)"),
(lambda x, y: 11 - x, "\n%0 = Constant(11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)"),
(lambda x, y: (-11) - x, "\n%0 = Constant(-11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)"),
(lambda x, y: x + 1, "%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%2)\n"),
(lambda x, y: 1 + x, "%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%2)\n"),
(lambda x, y: (-1) + x, "%0 = x\n%1 = Constant(-1)\n%2 = Add(0, 1)\nreturn(%2)\n"),
(lambda x, y: 3 * x, "%0 = x\n%1 = Constant(3)\n%2 = Mul(0, 1)\nreturn(%2)\n"),
(lambda x, y: x * 3, "%0 = x\n%1 = Constant(3)\n%2 = Mul(0, 1)\nreturn(%2)\n"),
(lambda x, y: x * (-3), "%0 = x\n%1 = Constant(-3)\n%2 = Mul(0, 1)\nreturn(%2)\n"),
(lambda x, y: x - 11, "%0 = x\n%1 = Constant(11)\n%2 = Sub(0, 1)\nreturn(%2)\n"),
(lambda x, y: 11 - x, "%0 = Constant(11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)\n"),
(lambda x, y: (-11) - x, "%0 = Constant(-11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)\n"),
(
lambda x, y: x + 13 - y * (-21) * y + 44,
"\n%0 = Constant(44)"
"%0 = Constant(44)"
"\n%1 = x"
"\n%2 = Constant(13)"
"\n%3 = y"
@@ -69,48 +69,48 @@ def issue_130_c(x, y):
"\n%7 = Mul(6, 3)"
"\n%8 = Sub(5, 7)"
"\n%9 = Add(8, 0)"
"\nreturn(%9)",
"\nreturn(%9)\n",
),
# Multiple outputs
(
lambda x, y: (x + 1, x + y + 2),
"\n%0 = x"
"%0 = x"
"\n%1 = Constant(1)"
"\n%2 = Constant(2)"
"\n%3 = y"
"\n%4 = Add(0, 1)"
"\n%5 = Add(0, 3)"
"\n%6 = Add(5, 2)"
"\nreturn(%4, %6)",
"\nreturn(%4, %6)\n",
),
(
lambda x, y: (y, x),
"\n%0 = y\n%1 = x\nreturn(%0, %1)",
"%0 = y\n%1 = x\nreturn(%0, %1)\n",
),
(
lambda x, y: (x, x + 1),
"\n%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%0, %2)",
"%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%0, %2)\n",
),
(
lambda x, y: (x + 1, x + 1),
"\n%0 = x"
"%0 = x"
"\n%1 = Constant(1)"
"\n%2 = Constant(1)"
"\n%3 = Add(0, 1)"
"\n%4 = Add(0, 2)"
"\nreturn(%3, %4)",
"\nreturn(%3, %4)\n",
),
(
issue_130_a,
"\n%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%2, %2)",
"%0 = x\n%1 = Constant(1)\n%2 = Add(0, 1)\nreturn(%2, %2)\n",
),
(
issue_130_b,
"\n%0 = x\n%1 = Constant(1)\n%2 = Sub(0, 1)\nreturn(%2, %2)",
"%0 = x\n%1 = Constant(1)\n%2 = Sub(0, 1)\nreturn(%2, %2)\n",
),
(
issue_130_c,
"\n%0 = Constant(1)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2, %2)",
"%0 = Constant(1)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2, %2)\n",
),
],
)
@@ -143,9 +143,9 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
str_of_the_graph = get_printable_graph(graph)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot {str_of_the_graph}"
f"\n==================\nExpected {ref_graph_str}"
f"\n==================\n"
f"\n==================\nGot \n{str_of_the_graph}"
f"==================\nExpected \n{ref_graph_str}"
f"==================\n"
)
@@ -155,12 +155,12 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
(
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
{"x": EncryptedValue(Integer(2, is_signed=False))},
"\n%0 = x\n%1 = TLU(0)\nreturn(%1)",
"%0 = x\n%1 = TLU(0)\nreturn(%1)\n",
),
(
lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
{"x": EncryptedValue(Integer(2, is_signed=False))},
"\n%0 = x\n%1 = Constant(4)\n%2 = Add(0, 1)\n%3 = TLU(2)\nreturn(%3)",
"%0 = x\n%1 = Constant(4)\n%2 = Add(0, 1)\n%3 = TLU(2)\nreturn(%3)\n",
),
],
)
@@ -173,9 +173,9 @@ def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph
str_of_the_graph = get_printable_graph(graph)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot {str_of_the_graph}"
f"\n==================\nExpected {ref_graph_str}"
f"\n==================\n"
f"\n==================\nGot \n{str_of_the_graph}"
f"==================\nExpected \n{ref_graph_str}"
f"==================\n"
)
@@ -189,7 +189,7 @@ def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
"y": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
},
"\n%0 = x\n%1 = y\n%2 = Dot(0, 1)\nreturn(%2)",
"%0 = x\n%1 = y\n%2 = Dot(0, 1)\nreturn(%2)\n",
),
# pylint: enable=unnecessary-lambda
],
@@ -203,9 +203,9 @@ def test_hnumpy_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
str_of_the_graph = get_printable_graph(graph)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot {str_of_the_graph}"
f"\n==================\nExpected {ref_graph_str}"
f"\n==================\n"
f"\n==================\nGot \n{str_of_the_graph}"
f"==================\nExpected \n{ref_graph_str}"
f"==================\n"
)
@@ -221,10 +221,10 @@ def test_hnumpy_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
EncryptedValue(Integer(64, is_signed=False)),
EncryptedValue(Integer(32, is_signed=True)),
),
"\n%0 = x # Integer<unsigned, 64 bits>"
"%0 = x # Integer<unsigned, 64 bits>"
"\n%1 = y # Integer<signed, 32 bits>"
"\n%2 = Add(0, 1) # Integer<signed, 65 bits>"
"\nreturn(%2)",
"\nreturn(%2)\n",
),
(
lambda x, y: x * y,
@@ -232,10 +232,10 @@ def test_hnumpy_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
EncryptedValue(Integer(17, is_signed=False)),
EncryptedValue(Integer(23, is_signed=False)),
),
"\n%0 = x # Integer<unsigned, 17 bits>"
"%0 = x # Integer<unsigned, 17 bits>"
"\n%1 = y # Integer<unsigned, 23 bits>"
"\n%2 = Mul(0, 1) # Integer<unsigned, 23 bits>"
"\nreturn(%2)",
"\nreturn(%2)\n",
),
],
)
@@ -247,9 +247,9 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot {str_of_the_graph}"
f"\n==================\nExpected {ref_graph_str}"
f"\n==================\n"
f"\n==================\nGot \n{str_of_the_graph}"
f"==================\nExpected \n{ref_graph_str}"
f"==================\n"
)
@@ -259,28 +259,28 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
(
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
{"x": EncryptedValue(Integer(2, is_signed=False))},
"\n%0 = x # Integer<unsigned, 2 bits>"
"%0 = x # Integer<unsigned, 2 bits>"
"\n%1 = TLU(0) # Integer<unsigned, 4 bits>"
"\nreturn(%1)",
"\nreturn(%1)\n",
),
(
lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
{"x": EncryptedValue(Integer(2, is_signed=False))},
"\n%0 = x # Integer<unsigned, 2 bits>"
"%0 = x # Integer<unsigned, 2 bits>"
"\n%1 = Constant(4) # Integer<unsigned, 3 bits>"
"\n%2 = Add(0, 1) # Integer<unsigned, 3 bits>"
"\n%3 = TLU(2) # Integer<unsigned, 2 bits>"
"\nreturn(%3)",
"\nreturn(%3)\n",
),
(
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[LOOKUP_TABLE_FROM_3B_TO_2B[x + 4]],
{"x": EncryptedValue(Integer(2, is_signed=False))},
"\n%0 = x # Integer<unsigned, 2 bits>"
"%0 = x # Integer<unsigned, 2 bits>"
"\n%1 = Constant(4) # Integer<unsigned, 3 bits>"
"\n%2 = Add(0, 1) # Integer<unsigned, 3 bits>"
"\n%3 = TLU(2) # Integer<unsigned, 2 bits>"
"\n%4 = TLU(3) # Integer<unsigned, 4 bits>"
"\nreturn(%4)",
"\nreturn(%4)\n",
),
],
)
@@ -293,7 +293,7 @@ def test_hnumpy_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot {str_of_the_graph}"
f"\n==================\nExpected {ref_graph_str}"
f"\n==================\n"
f"\n==================\nGot \n{str_of_the_graph}"
f"==================\nExpected \n{ref_graph_str}"
f"==================\n"
)