refactor: update GenericFunction to take an iterable as inputs

- also fix some corner cases in memory operations
- some small style changes

refs #600
This commit is contained in:
Arthur Meyre
2021-11-03 12:15:41 +01:00
parent d2faa90106
commit bff367137e
9 changed files with 227 additions and 118 deletions

View File

@@ -32,24 +32,24 @@ def no_fuse_dot(x):
return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32)
def no_fuse_explicitely(f, x):
def simple_create_fuse_opportunity(f, x):
"""No fuse because the function is explicitely marked as unfusable in our code."""
return f(x.astype(numpy.float64)).astype(numpy.int32)
def no_fuse_explicitely_ravel(x):
"""No fuse ravel"""
return no_fuse_explicitely(numpy.ravel, x)
def ravel_cases(x):
"""Simple ravel cases"""
return simple_create_fuse_opportunity(numpy.ravel, x)
def no_fuse_explicitely_transpose(x):
"""No fuse transpose"""
return no_fuse_explicitely(numpy.transpose, x)
def transpose_cases(x):
"""Simple transpose cases"""
return simple_create_fuse_opportunity(numpy.transpose, x)
def no_fuse_explicitely_reshape(x):
"""No fuse reshape"""
return no_fuse_explicitely(lambda x: numpy.reshape(x, (1,)), x)
def reshape_cases(x, newshape):
"""Simple reshape cases"""
return simple_create_fuse_opportunity(lambda x: numpy.reshape(x, newshape), x)
def simple_fuse_not_output(x):
@@ -182,41 +182,41 @@ return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_dot",
),
pytest.param(
no_fuse_explicitely_ravel,
ravel_cases,
False,
get_func_params_int32(no_fuse_explicitely_ravel, scalar=False),
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
%2 = np.ravel(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
%2 = np.ravel(%1) # EncryptedTensor<Float<64 bits>, shape=(200,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(200,)>
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_explicitely_ravel",
),
pytest.param(
no_fuse_explicitely_transpose,
transpose_cases,
False,
get_func_params_int32(no_fuse_explicitely_transpose, scalar=False),
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
%2 = np.transpose(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
%2 = np.transpose(%1) # EncryptedTensor<Float<64 bits>, shape=(20, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(20, 10)>
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_explicitely_transpose",
),
pytest.param(
no_fuse_explicitely_reshape,
lambda x: reshape_cases(x, (20, 10)),
False,
get_func_params_int32(no_fuse_explicitely_reshape, scalar=False),
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
%2 = np.reshape(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
%2 = np.reshape(%1) # EncryptedTensor<Float<64 bits>, shape=(20, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(20, 10)>
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_explicitely_reshape",
),
@@ -248,6 +248,34 @@ return(%3)""", # noqa: E501 # pylint: disable=line-too-long
None,
id="mix_x_and_y_and_call_f_with_rint",
),
pytest.param(
transpose_cases,
True,
get_func_params_int32(transpose_cases),
None,
id="transpose_cases scalar",
),
pytest.param(
transpose_cases,
True,
{"x": EncryptedTensor(Integer(32, True), (10,))},
None,
id="transpose_cases ndim == 1",
),
pytest.param(
ravel_cases,
True,
{"x": EncryptedTensor(Integer(32, True), (10,))},
None,
id="ravel_cases ndim == 1",
),
pytest.param(
lambda x: reshape_cases(x, (10, 20)),
True,
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
None,
id="reshape_cases same shape",
),
],
)
def test_fuse_float_operations(