mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor: update GenericFunction to take an iterable as inputs
- also fix some corner cases in memory operations - some small style changes refs #600
This commit is contained in:
@@ -32,24 +32,24 @@ def no_fuse_dot(x):
|
||||
return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32)
|
||||
|
||||
|
||||
def no_fuse_explicitely(f, x):
|
||||
def simple_create_fuse_opportunity(f, x):
|
||||
"""No fuse because the function is explicitely marked as unfusable in our code."""
|
||||
return f(x.astype(numpy.float64)).astype(numpy.int32)
|
||||
|
||||
|
||||
def no_fuse_explicitely_ravel(x):
|
||||
"""No fuse ravel"""
|
||||
return no_fuse_explicitely(numpy.ravel, x)
|
||||
def ravel_cases(x):
|
||||
"""Simple ravel cases"""
|
||||
return simple_create_fuse_opportunity(numpy.ravel, x)
|
||||
|
||||
|
||||
def no_fuse_explicitely_transpose(x):
|
||||
"""No fuse transpose"""
|
||||
return no_fuse_explicitely(numpy.transpose, x)
|
||||
def transpose_cases(x):
|
||||
"""Simple transpose cases"""
|
||||
return simple_create_fuse_opportunity(numpy.transpose, x)
|
||||
|
||||
|
||||
def no_fuse_explicitely_reshape(x):
|
||||
"""No fuse reshape"""
|
||||
return no_fuse_explicitely(lambda x: numpy.reshape(x, (1,)), x)
|
||||
def reshape_cases(x, newshape):
|
||||
"""Simple reshape cases"""
|
||||
return simple_create_fuse_opportunity(lambda x: numpy.reshape(x, newshape), x)
|
||||
|
||||
|
||||
def simple_fuse_not_output(x):
|
||||
@@ -182,41 +182,41 @@ return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_dot",
|
||||
),
|
||||
pytest.param(
|
||||
no_fuse_explicitely_ravel,
|
||||
ravel_cases,
|
||||
False,
|
||||
get_func_params_int32(no_fuse_explicitely_ravel, scalar=False),
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
|
||||
"""The following subgraph is not fusable:
|
||||
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
|
||||
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
|
||||
%2 = np.ravel(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
|
||||
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
|
||||
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
|
||||
%2 = np.ravel(%1) # EncryptedTensor<Float<64 bits>, shape=(200,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(200,)>
|
||||
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_explicitely_ravel",
|
||||
),
|
||||
pytest.param(
|
||||
no_fuse_explicitely_transpose,
|
||||
transpose_cases,
|
||||
False,
|
||||
get_func_params_int32(no_fuse_explicitely_transpose, scalar=False),
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
|
||||
"""The following subgraph is not fusable:
|
||||
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
|
||||
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
|
||||
%2 = np.transpose(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
|
||||
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
|
||||
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
|
||||
%2 = np.transpose(%1) # EncryptedTensor<Float<64 bits>, shape=(20, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(20, 10)>
|
||||
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_explicitely_transpose",
|
||||
),
|
||||
pytest.param(
|
||||
no_fuse_explicitely_reshape,
|
||||
lambda x: reshape_cases(x, (20, 10)),
|
||||
False,
|
||||
get_func_params_int32(no_fuse_explicitely_reshape, scalar=False),
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
|
||||
"""The following subgraph is not fusable:
|
||||
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
|
||||
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
|
||||
%2 = np.reshape(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
|
||||
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
|
||||
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
|
||||
%2 = np.reshape(%1) # EncryptedTensor<Float<64 bits>, shape=(20, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(20, 10)>
|
||||
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_explicitely_reshape",
|
||||
),
|
||||
@@ -248,6 +248,34 @@ return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
None,
|
||||
id="mix_x_and_y_and_call_f_with_rint",
|
||||
),
|
||||
pytest.param(
|
||||
transpose_cases,
|
||||
True,
|
||||
get_func_params_int32(transpose_cases),
|
||||
None,
|
||||
id="transpose_cases scalar",
|
||||
),
|
||||
pytest.param(
|
||||
transpose_cases,
|
||||
True,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10,))},
|
||||
None,
|
||||
id="transpose_cases ndim == 1",
|
||||
),
|
||||
pytest.param(
|
||||
ravel_cases,
|
||||
True,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10,))},
|
||||
None,
|
||||
id="ravel_cases ndim == 1",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: reshape_cases(x, (10, 20)),
|
||||
True,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
|
||||
None,
|
||||
id="reshape_cases same shape",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_fuse_float_operations(
|
||||
|
||||
Reference in New Issue
Block a user