feat: get_printable_graph is more precise

get_printable_graph prints the constant if the ArbitraryFunction has a baked constant
closes #584
This commit is contained in:
Benoit Chevallier-Mames
2021-10-11 14:49:16 +02:00
committed by Benoit Chevallier
parent 286dda79b2
commit dedbde93d0
4 changed files with 126 additions and 54 deletions

View File

@@ -112,6 +112,14 @@ def issue_130_c(x, y):
issue_130_c,
"%0 = Constant(1)\n%1 = x\n%2 = Sub(%0, %1)\nreturn(%2, %2)\n",
),
(
lambda x, y: numpy.arctan2(x, 42) + y,
"%0 = y\n%1 = x\n%2 = np.arctan2(%1, 42)\n%3 = Add(%2, %0)\nreturn(%3)\n",
),
(
lambda x, y: numpy.arctan2(43, x) + y,
"%0 = y\n%1 = x\n%2 = np.arctan2(43, %1)\n%3 = Add(%2, %0)\nreturn(%3)\n",
),
],
)
@pytest.mark.parametrize(
@@ -221,9 +229,12 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
EncryptedScalar(Integer(64, is_signed=False)),
EncryptedScalar(Integer(32, is_signed=True)),
),
"%0 = x # EncryptedScalar<Integer<unsigned, 64 bits>>"
"\n%1 = y # EncryptedScalar<Integer<signed, 32 bits>>"
"\n%2 = Add(%0, %1) # EncryptedScalar<Integer<signed, 65 bits>>"
"%0 = x "
"# EncryptedScalar<Integer<unsigned, 64 bits>>"
"\n%1 = y "
" # EncryptedScalar<Integer<signed, 32 bits>>"
"\n%2 = Add(%0, %1) "
" # EncryptedScalar<Integer<signed, 65 bits>>"
"\nreturn(%2)\n",
),
(
@@ -232,11 +243,11 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
EncryptedScalar(Integer(17, is_signed=False)),
EncryptedScalar(Integer(23, is_signed=False)),
),
"%0 = x "
"%0 = x "
"# EncryptedScalar<Integer<unsigned, 17 bits>>"
"\n%1 = y "
"\n%1 = y "
"# EncryptedScalar<Integer<unsigned, 23 bits>>"
"\n%2 = Mul(%0, %1) "
"\n%2 = Mul(%0, %1) "
"# EncryptedScalar<Integer<unsigned, 23 bits>>"
"\nreturn(%2)\n",
),
@@ -262,37 +273,37 @@ def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
(
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
{"x": EncryptedScalar(Integer(2, is_signed=False))},
"%0 = x "
"%0 = x "
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
"\n%1 = TLU(%0) "
"\n%1 = TLU(%0) "
"# EncryptedScalar<Integer<unsigned, 4 bits>>"
"\nreturn(%1)\n",
),
(
lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
{"x": EncryptedScalar(Integer(2, is_signed=False))},
"%0 = x "
"%0 = x "
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
"\n%1 = Constant(4) "
"\n%1 = Constant(4) "
"# ClearScalar<Integer<unsigned, 3 bits>>"
"\n%2 = Add(%0, %1) "
"\n%2 = Add(%0, %1) "
"# EncryptedScalar<Integer<unsigned, 3 bits>>"
"\n%3 = TLU(%2) "
"\n%3 = TLU(%2) "
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
"\nreturn(%3)\n",
),
(
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[LOOKUP_TABLE_FROM_3B_TO_2B[x + 4]],
{"x": EncryptedScalar(Integer(2, is_signed=False))},
"%0 = x "
"%0 = x "
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
"\n%1 = Constant(4) "
"\n%1 = Constant(4) "
"# ClearScalar<Integer<unsigned, 3 bits>>"
"\n%2 = Add(%0, %1) "
"\n%2 = Add(%0, %1) "
"# EncryptedScalar<Integer<unsigned, 3 bits>>"
"\n%3 = TLU(%2) "
"\n%3 = TLU(%2) "
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
"\n%4 = TLU(%3) "
"\n%4 = TLU(%3) "
"# EncryptedScalar<Integer<unsigned, 4 bits>>"
"\nreturn(%4)\n",
),
@@ -311,3 +322,31 @@ def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_
f"==================\nExpected \n{ref_graph_str}"
f"==================\n"
)
def test_numpy_long_constant():
"Test get_printable_graph with long constant"
def all_explicit_operations(x):
intermediate = numpy.add(x, numpy.arange(100).reshape(10, 10))
intermediate = numpy.subtract(intermediate, numpy.arange(10).reshape(1, 10))
intermediate = numpy.arctan2(numpy.arange(10, 20).reshape(1, 10), intermediate)
intermediate = numpy.arctan2(numpy.arange(100, 200).reshape(10, 10), intermediate)
return intermediate
op_graph = tracing.trace_numpy_function(
all_explicit_operations, {"x": EncryptedTensor(Integer(32, True), shape=(10, 10))}
)
expected = """
%0 = Constant([[0 1 2 3 4 5 6 7 8 9]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(1, 10)>
%1 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%2 = Constant([[ 0 1 2 ... 97 98 99]]) # ClearTensor<Integer<unsigned, 7 bits>, shape=(10, 10)>
%3 = Add(%1, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%4 = Sub(%3, %0) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%5 = np.arctan2([[10 11 12 ... 17 18 19]], %4) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
%6 = np.arctan2([[100 101 ... 198 199]], %5) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
return(%6)
""".lstrip() # noqa: E501
assert get_printable_graph(op_graph, show_data_types=True) == expected