refactor: make GenericFunction accept several inputs

- remove baked constants
- manage table generation for the updated node

closes #600
closes #822
This commit is contained in:
Arthur Meyre
2021-11-03 16:14:10 +01:00
parent 7f32cf7965
commit f530a0b739
13 changed files with 206 additions and 130 deletions

View File

@@ -114,11 +114,23 @@ def issue_130_c(x, y):
),
(
lambda x, y: numpy.arctan2(x, 42) + y,
"%0 = y\n%1 = x\n%2 = np.arctan2(%1, 42)\n%3 = Add(%2, %0)\nreturn(%3)\n",
"""%0 = y
%1 = x
%2 = Constant(42)
%3 = np.arctan2(%1, %2)
%4 = Add(%3, %0)
return(%4)
""",
),
(
lambda x, y: numpy.arctan2(43, x) + y,
"%0 = y\n%1 = x\n%2 = np.arctan2(43, %1)\n%3 = Add(%2, %0)\nreturn(%3)\n",
"""%0 = y
%1 = Constant(43)
%2 = x
%3 = np.arctan2(%1, %2)
%4 = Add(%3, %0)
return(%4)
""",
),
],
)
@@ -416,14 +428,22 @@ def test_numpy_long_constant():
)
expected = """
%0 = Constant([[0 1 2 3 4 5 6 7 8 9]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(1, 10)>
%1 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%2 = Constant([[ 0 1 2 ... 97 98 99]]) # ClearTensor<Integer<unsigned, 7 bits>, shape=(10, 10)>
%3 = Add(%1, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%4 = Sub(%3, %0) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%5 = np.arctan2([[10 11 12 ... 17 18 19]], %4) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
%6 = np.arctan2([[100 101 ... 198 199]], %5) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
return(%6)
%0 = Constant([[100 101 ... 198 199]]) # ClearTensor<Integer<unsigned, 8 bits>, shape=(10, 10)>
%1 = Constant([[10 11 12 ... 17 18 19]]) # ClearTensor<Integer<unsigned, 5 bits>, shape=(1, 10)>
%2 = Constant([[0 1 2 3 4 5 6 7 8 9]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(1, 10)>
%3 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%4 = Constant([[ 0 1 2 ... 97 98 99]]) # ClearTensor<Integer<unsigned, 7 bits>, shape=(10, 10)>
%5 = Add(%3, %4) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%6 = Sub(%5, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%7 = np.arctan2(%1, %6) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
%8 = np.arctan2(%0, %7) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
return(%8)
""".lstrip() # noqa: E501
assert get_printable_graph(op_graph, show_data_types=True) == expected
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
assert str_of_the_graph == expected, (
f"\n==================\nGot \n{str_of_the_graph}"
f"==================\nExpected \n{expected}"
f"==================\n"
)