refactor: update GenericFunction to take an iterable as inputs
- also fix some corner cases in memory operations
- some small style changes

refs #600
@@ -8,7 +8,7 @@ from ..data_types.base import BaseDataType
 from ..data_types.dtypes_helpers import find_type_to_hold_both_lossy
 from ..representation.intermediate import GenericFunction
 from ..tracing.base_tracer import BaseTracer
-from ..values import ClearTensor, EncryptedTensor
+from ..values import TensorValue
 from .table import LookupTable
@@ -97,14 +97,14 @@ class MultiLookupTable:
         out_dtype = deepcopy(key.output.dtype)
         out_shape = deepcopy(self.input_shape)

-        generic_function_output_value = (
-            EncryptedTensor(out_dtype, out_shape)
-            if key.output.is_encrypted
-            else ClearTensor(out_dtype, out_shape)
+        generic_function_output_value = TensorValue(
+            out_dtype,
+            key.output.is_encrypted,
+            out_shape,
         )

         traced_computation = GenericFunction(
-            input_base_value=key.output,
+            inputs=[deepcopy(key.output)],
             arbitrary_func=MultiLookupTable._checked_indexing,
             output_value=generic_function_output_value,
             op_kind="TLU",
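Note: the pattern replaced throughout this commit swaps the EncryptedTensor/ClearTensor
conditional for a direct TensorValue construction. A minimal sketch of the equivalence,
using only names that appear in this diff:

    # before: pick a value class based on the encryption status
    value = (
        EncryptedTensor(out_dtype, out_shape)
        if key.output.is_encrypted
        else ClearTensor(out_dtype, out_shape)
    )
    # after: one constructor, with the encryption status passed as a flag
    value = TensorValue(out_dtype, key.output.is_encrypted, out_shape)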
@@ -39,7 +39,7 @@ class LookupTable:
         generic_function_output_value.dtype = self.output_dtype

         traced_computation = GenericFunction(
-            input_base_value=key.output,
+            inputs=[deepcopy(key.output)],
             arbitrary_func=LookupTable._checked_indexing,
             output_value=generic_function_output_value,
             op_kind="TLU",
@@ -194,11 +194,11 @@ def convert_float_subgraph_to_fused_node(

     # Create fused_node
     fused_node = GenericFunction(
-        deepcopy(new_subgraph_variable_input.inputs[0]),
-        lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate({0: x})[
-            terminal_node
-        ],
-        terminal_node.outputs[0],
+        inputs=[deepcopy(new_subgraph_variable_input.inputs[0])],
+        arbitrary_func=lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate(
+            {0: x}
+        )[terminal_node],
+        output_value=terminal_node.outputs[0],
         op_kind="TLU",
         op_kwargs={
             "float_op_subgraph": float_op_subgraph,
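Note: the stored lambda relies on the evaluation contract visible in this diff: the
subgraph's evaluate takes a dict mapping input index to value and returns a dict mapping
each node to its computed value, so indexing with terminal_node extracts the fused
subgraph's result. A sketch:

    results = float_op_subgraph.evaluate({0: x})  # {node: computed value, ...}
    output = results[terminal_node]               # value produced by the fused subgraph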
@@ -290,7 +290,7 @@ class GenericFunctionKind(str, Enum):


 class GenericFunction(IntermediateNode):
-    """Node representing an univariate arbitrary function, e.g. sin(x)."""
+    """Node representing an arbitrary function with a single output, e.g. sin(x)."""

     # The arbitrary_func is not optional but mypy has a long standing bug and is not able to
     # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
@@ -303,7 +303,7 @@ class GenericFunction(IntermediateNode):
     op_args: Tuple[Any, ...]
     op_kwargs: Dict[str, Any]
     op_attributes: Dict[str, Any]
-    _n_in: int = 1
+    _n_in: int

     # TODO: https://github.com/zama-ai/concretefhe-internal/issues/798 have a proper attribute
     # system
@@ -311,7 +311,7 @@ class GenericFunction(IntermediateNode):

     def __init__(
         self,
-        input_base_value: BaseValue,
+        inputs: Iterable[BaseValue],
         arbitrary_func: Callable,
         output_value: BaseValue,
         op_kind: Union[str, GenericFunctionKind],
@@ -320,8 +320,9 @@ class GenericFunction(IntermediateNode):
         op_kwargs: Optional[Dict[str, Any]] = None,
         op_attributes: Optional[Dict[str, Any]] = None,
     ) -> None:
-        super().__init__([input_base_value])
-        assert_true(len(self.inputs) == 1)
+        super().__init__(inputs)
+        self._n_in = len(self.inputs)
+        assert_true(self._n_in == 1)  # TODO: remove in later parts of refactoring of #600
         self.arbitrary_func = arbitrary_func
         self.op_kind = GenericFunctionKind(op_kind)
         self.op_args = op_args if op_args is not None else ()
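Note: a minimal sketch of the call-site change implied by the new signature (f, x and out
below are placeholders, not names from this commit):

    # before: a single BaseValue, passed as input_base_value
    # node = GenericFunction(input_base_value=x, arbitrary_func=f, output_value=out, op_kind="TLU")

    # after: an iterable of BaseValues; still asserted to have length 1 until #600 lands fully
    node = GenericFunction(inputs=[x], arbitrary_func=f, output_value=out, op_kind="TLU")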
@@ -330,14 +331,15 @@ class GenericFunction(IntermediateNode):
         if op_attributes is not None:
             self.op_attributes.update(op_attributes)

-        self.outputs = [deepcopy(output_value)]
+        self.outputs = [output_value]

         self.op_name = op_name if op_name is not None else self.__class__.__name__

     def evaluate(self, inputs: Dict[int, Any]) -> Any:
         # This is the continuation of the mypy bug workaround
         assert self.arbitrary_func is not None
-        return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)
+        ordered_inputs = [inputs[idx] for idx in range(len(inputs))]
+        return self.arbitrary_func(*ordered_inputs, *self.op_args, **self.op_kwargs)

     def label(self) -> str:
         return self.op_name
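Note: evaluate now assumes the inputs dict is keyed by contiguous indices starting at 0
and passes the values positionally. A toy illustration in plain Python:

    inputs = {0: 10, 1: 20}
    ordered_inputs = [inputs[idx] for idx in range(len(inputs))]  # [10, 20]
    # arbitrary_func receives one positional argument per node input,
    # followed by op_args and op_kwargs:
    # arbitrary_func(10, 20, *op_args, **op_kwargs)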
@@ -350,12 +352,14 @@ class GenericFunction(IntermediateNode):
         Returns:
             List[Any]: The table.
         """

         input_dtype = self.inputs[0].dtype
         # Check the input is an unsigned integer to be able to build a table
-        assert isinstance(
-            input_dtype, Integer
-        ), "get_table only works for an unsigned Integer input"
-        assert not input_dtype.is_signed, "get_table only works for an unsigned Integer input"
+        assert_true(
+            isinstance(input_dtype, Integer), "get_table only works for an unsigned Integer input"
+        )
+        input_dtype = cast(Integer, input_dtype)
+        assert_true(not input_dtype.is_signed, "get_table only works for an unsigned Integer input")

         input_value_constructor = self.inputs[0].underlying_constructor
         if input_value_constructor is None:
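Note: conceptually, get_table tabulates the node's function over the whole unsigned input
domain, which is why the dtype must be an unsigned Integer. A sketch (the bit-width
attribute name below is an assumption, not taken from this diff):

    # hypothetical: assuming the Integer dtype exposes its width as `bit_width`
    table = [self.evaluate({0: value}) for value in range(2 ** input_dtype.bit_width)]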
@@ -7,11 +7,11 @@ import numpy
 from numpy.typing import DTypeLike

 from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
-from ..common.debugging.custom_assert import assert_false, assert_true
+from ..common.debugging.custom_assert import assert_true
 from ..common.operator_graph import OPGraph
 from ..common.representation.intermediate import Constant, Dot, GenericFunction, MatMul
 from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
-from ..common.values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
+from ..common.values import BaseValue, TensorValue
 from .np_dtypes_helpers import (
     SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
     convert_numpy_dtype_to_base_data_type,
@@ -99,7 +99,7 @@ class NPTracer(BaseTracer):
         generic_function_output_value = deepcopy(self.output)
         generic_function_output_value.dtype = output_dtype
         traced_computation = GenericFunction(
-            input_base_value=self.output,
+            inputs=[deepcopy(self.output)],
             arbitrary_func=lambda x, dtype: x.astype(dtype),
             output_value=generic_function_output_value,
             op_kind="TLU",
@@ -171,7 +171,7 @@ class NPTracer(BaseTracer):
         generic_function_output_value.dtype = common_output_dtypes[0]

         traced_computation = GenericFunction(
-            input_base_value=input_tracers[0].output,
+            inputs=[deepcopy(input_tracers[0].output)],
             arbitrary_func=unary_operator,
             output_value=generic_function_output_value,
             op_kind="TLU",
@@ -241,8 +241,9 @@ class NPTracer(BaseTracer):
         generic_function_output_value = deepcopy(input_tracers[in_which_input_is_variable].output)
         generic_function_output_value.dtype = common_output_dtypes[0]

+        # TODO: update inputs for #600 refactor
         traced_computation = GenericFunction(
-            input_base_value=input_tracers[in_which_input_is_variable].output,
+            inputs=[deepcopy(input_tracers[in_which_input_is_variable].output)],
             arbitrary_func=arbitrary_func,
             output_value=generic_function_output_value,
             op_kind="TLU",
@@ -312,25 +313,26 @@ class NPTracer(BaseTracer):
         first_arg_output = args[0].output
         assert_true(isinstance(first_arg_output, TensorValue))
         first_arg_output = cast(TensorValue, first_arg_output)
-        assert_false(first_arg_output.is_scalar)
+
+        transpose_is_fusable = first_arg_output.is_scalar or first_arg_output.ndim == 1

         out_dtype = first_arg_output.dtype
         out_shape = first_arg_output.shape[::-1]

-        generic_function_output_value = (
-            EncryptedTensor(out_dtype, out_shape)
-            if first_arg_output.is_encrypted
-            else ClearTensor(out_dtype, out_shape)
+        generic_function_output_value = TensorValue(
+            out_dtype,
+            first_arg_output.is_encrypted,
+            out_shape,
         )

         traced_computation = GenericFunction(
-            input_base_value=first_arg_output,
+            inputs=[deepcopy(first_arg_output)],
             arbitrary_func=numpy.transpose,
             output_value=generic_function_output_value,
             op_kind="Memory",
             op_kwargs=deepcopy(kwargs),
             op_name="np.transpose",
-            op_attributes={"fusable": False},
+            op_attributes={"fusable": transpose_is_fusable},
         )
         output_tracer = self.__class__(
             args,
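Note: the relaxed fusability condition matches numpy semantics: transposing a scalar or a
1-D array is the identity, so the shape is preserved and the node can fuse. A quick check
with plain numpy:

    import numpy
    numpy.transpose(numpy.arange(4)).shape                # (4,): unchanged, fusable
    numpy.transpose(numpy.arange(6).reshape(2, 3)).shape  # (3, 2): shape changes, not fusable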
@@ -358,25 +360,26 @@ class NPTracer(BaseTracer):
         first_arg_output = args[0].output
         assert_true(isinstance(first_arg_output, TensorValue))
         first_arg_output = cast(TensorValue, first_arg_output)
-        assert_false(first_arg_output.is_scalar)
+
+        ravel_is_fusable = first_arg_output.ndim == 1

         out_dtype = first_arg_output.dtype
-        out_shape = (numpy.product(first_arg_output.shape),)
+        out_shape = (1,) if first_arg_output.is_scalar else (numpy.product(first_arg_output.shape),)

-        generic_function_output_value = (
-            EncryptedTensor(out_dtype, out_shape)
-            if first_arg_output.is_encrypted
-            else ClearTensor(out_dtype, out_shape)
+        generic_function_output_value = TensorValue(
+            out_dtype,
+            first_arg_output.is_encrypted,
+            out_shape,
         )

         traced_computation = GenericFunction(
-            input_base_value=first_arg_output,
+            inputs=[deepcopy(first_arg_output)],
             arbitrary_func=numpy.ravel,
             output_value=generic_function_output_value,
             op_kind="Memory",
             op_kwargs=deepcopy(kwargs),
             op_name="np.ravel",
-            op_attributes={"fusable": False},
+            op_attributes={"fusable": ravel_is_fusable},
         )
         output_tracer = self.__class__(
             args,
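Note: the out_shape fix covers the scalar corner case, where numpy.ravel promotes a 0-d
input to shape (1,); ravel of an already 1-D array is the identity, hence the new
fusability condition. With plain numpy:

    import numpy
    numpy.ravel(numpy.int64(42))        # array([42]), shape (1,)
    numpy.ravel(numpy.arange(4)).shape  # (4,): already flat, fusable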
@@ -422,27 +425,29 @@ class NPTracer(BaseTracer):

         # Check shape compatibility
         assert_true(
-            numpy.product(newshape) == numpy.product(first_arg_output.shape),
+            numpy.product(newshape) == first_arg_output.size,
             f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {newshape})",
         )

+        reshape_is_fusable = newshape == first_arg_output.shape
+
         out_dtype = first_arg_output.dtype
         out_shape = newshape

-        generic_function_output_value = (
-            EncryptedTensor(out_dtype, out_shape)
-            if first_arg_output.is_encrypted
-            else ClearTensor(out_dtype, out_shape)
+        generic_function_output_value = TensorValue(
+            out_dtype,
+            first_arg_output.is_encrypted,
+            out_shape,
         )

         traced_computation = GenericFunction(
-            input_base_value=first_arg_output,
+            inputs=[first_arg_output],
             arbitrary_func=numpy.reshape,
             output_value=generic_function_output_value,
             op_kind="Memory",
             op_kwargs={"newshape": newshape},
             op_name="np.reshape",
-            op_attributes={"fusable": False},
+            op_attributes={"fusable": reshape_is_fusable},
         )
         output_tracer = self.__class__(
             [arg0],
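Note: a reshape to the tensor's current shape is the identity, hence the new fusability
condition; the compatibility assert also now compares against the input's size instead of
recomputing the product of its shape. With plain numpy:

    import numpy
    x = numpy.arange(200).reshape(10, 20)
    numpy.reshape(x, (10, 20)).shape  # (10, 20): identity, fusable
    numpy.reshape(x, (20, 10)).shape  # (20, 10): layout changes, not fusable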
@@ -60,7 +60,7 @@ def test_lookup_table_encrypted_lookup(test_helpers):
     # pylint: disable=protected-access
     # Need access to _checked_indexing to have is_equivalent_to work for ir.GenericFunction
     output_arbitrary_function = ir.GenericFunction(
-        input_base_value=x,
+        inputs=[x],
         arbitrary_func=LookupTable._checked_indexing,
         output_value=generic_function_output_value,
         op_kind="TLU",
@@ -104,7 +104,7 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
     # pylint: disable=protected-access
     # Need access to _checked_indexing to have is_equivalent_to work for ir.GenericFunction
     intermediate_arbitrary_function = ir.GenericFunction(
-        input_base_value=x,
+        inputs=[x],
         arbitrary_func=LookupTable._checked_indexing,
         output_value=generic_function_output_value,
         op_kind="TLU",
@@ -32,24 +32,24 @@ def no_fuse_dot(x):
     return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32)


-def no_fuse_explicitely(f, x):
+def simple_create_fuse_opportunity(f, x):
     """No fuse because the function is explicitely marked as unfusable in our code."""
     return f(x.astype(numpy.float64)).astype(numpy.int32)


-def no_fuse_explicitely_ravel(x):
-    """No fuse ravel"""
-    return no_fuse_explicitely(numpy.ravel, x)
+def ravel_cases(x):
+    """Simple ravel cases"""
+    return simple_create_fuse_opportunity(numpy.ravel, x)


-def no_fuse_explicitely_transpose(x):
-    """No fuse transpose"""
-    return no_fuse_explicitely(numpy.transpose, x)
+def transpose_cases(x):
+    """Simple transpose cases"""
+    return simple_create_fuse_opportunity(numpy.transpose, x)


-def no_fuse_explicitely_reshape(x):
-    """No fuse reshape"""
-    return no_fuse_explicitely(lambda x: numpy.reshape(x, (1,)), x)
+def reshape_cases(x, newshape):
+    """Simple reshape cases"""
+    return simple_create_fuse_opportunity(lambda x: numpy.reshape(x, newshape), x)


 def simple_fuse_not_output(x):
@@ -182,41 +182,41 @@ return(%3)""", # noqa: E501 # pylint: disable=line-too-long
             id="no_fuse_dot",
         ),
         pytest.param(
-            no_fuse_explicitely_ravel,
+            ravel_cases,
             False,
-            get_func_params_int32(no_fuse_explicitely_ravel, scalar=False),
+            {"x": EncryptedTensor(Integer(32, True), (10, 20))},
             """The following subgraph is not fusable:
-%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
-%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
-%2 = np.ravel(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
-%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
+%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
+%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
+%2 = np.ravel(%1) # EncryptedTensor<Float<64 bits>, shape=(200,)>
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
+%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(200,)>
 return(%3)""", # noqa: E501 # pylint: disable=line-too-long
             id="no_fuse_explicitely_ravel",
         ),
         pytest.param(
-            no_fuse_explicitely_transpose,
+            transpose_cases,
             False,
-            get_func_params_int32(no_fuse_explicitely_transpose, scalar=False),
+            {"x": EncryptedTensor(Integer(32, True), (10, 20))},
             """The following subgraph is not fusable:
-%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
-%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
-%2 = np.transpose(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
-%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
+%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
+%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
+%2 = np.transpose(%1) # EncryptedTensor<Float<64 bits>, shape=(20, 10)>
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
+%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(20, 10)>
 return(%3)""", # noqa: E501 # pylint: disable=line-too-long
             id="no_fuse_explicitely_transpose",
         ),
         pytest.param(
-            no_fuse_explicitely_reshape,
+            lambda x: reshape_cases(x, (20, 10)),
             False,
-            get_func_params_int32(no_fuse_explicitely_reshape, scalar=False),
+            {"x": EncryptedTensor(Integer(32, True), (10, 20))},
             """The following subgraph is not fusable:
-%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
-%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
-%2 = np.reshape(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
-%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
+%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
+%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
+%2 = np.reshape(%1) # EncryptedTensor<Float<64 bits>, shape=(20, 10)>
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
+%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(20, 10)>
 return(%3)""", # noqa: E501 # pylint: disable=line-too-long
             id="no_fuse_explicitely_reshape",
         ),
@@ -248,6 +248,34 @@ return(%3)""", # noqa: E501 # pylint: disable=line-too-long
             None,
             id="mix_x_and_y_and_call_f_with_rint",
         ),
+        pytest.param(
+            transpose_cases,
+            True,
+            get_func_params_int32(transpose_cases),
+            None,
+            id="transpose_cases scalar",
+        ),
+        pytest.param(
+            transpose_cases,
+            True,
+            {"x": EncryptedTensor(Integer(32, True), (10,))},
+            None,
+            id="transpose_cases ndim == 1",
+        ),
+        pytest.param(
+            ravel_cases,
+            True,
+            {"x": EncryptedTensor(Integer(32, True), (10,))},
+            None,
+            id="ravel_cases ndim == 1",
+        ),
+        pytest.param(
+            lambda x: reshape_cases(x, (10, 20)),
+            True,
+            {"x": EncryptedTensor(Integer(32, True), (10, 20))},
+            None,
+            id="reshape_cases same shape",
+        ),
     ],
 )
 def test_fuse_float_operations(
@@ -36,7 +36,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
         pytest.param(ir.Constant(-42), None, -42, id="Constant"),
         pytest.param(
             ir.GenericFunction(
-                EncryptedScalar(Integer(7, False)),
+                [EncryptedScalar(Integer(7, False))],
                 lambda x: x + 3,
                 EncryptedScalar(Integer(7, False)),
                 op_kind="TLU",
@@ -47,7 +47,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
         ),
         pytest.param(
             ir.GenericFunction(
-                EncryptedScalar(Integer(7, False)),
+                [EncryptedScalar(Integer(7, False))],
                 lambda x, y: x + y,
                 EncryptedScalar(Integer(7, False)),
                 op_kind="TLU",
@@ -59,7 +59,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
         ),
         pytest.param(
             ir.GenericFunction(
-                EncryptedScalar(Integer(7, False)),
+                [EncryptedScalar(Integer(7, False))],
                 lambda x, y: y[x],
                 EncryptedScalar(Integer(7, False)),
                 op_kind="TLU",
@@ -71,7 +71,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
         ),
         pytest.param(
             ir.GenericFunction(
-                EncryptedScalar(Integer(7, False)),
+                [EncryptedScalar(Integer(7, False))],
                 lambda x, y: y[3],
                 EncryptedScalar(Integer(7, False)),
                 op_kind="TLU",
@@ -183,7 +183,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
         ),
         pytest.param(
             ir.GenericFunction(
-                EncryptedTensor(Integer(32, False), shape=(3, 5)),
+                [EncryptedTensor(Integer(32, False), shape=(3, 5))],
                 lambda x: numpy.transpose(x),
                 EncryptedTensor(Integer(32, False), shape=(5, 3)),
                 op_kind="Memory",
@@ -194,7 +194,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
         ),
         pytest.param(
             ir.GenericFunction(
-                EncryptedTensor(Integer(32, False), shape=(3, 5)),
+                [EncryptedTensor(Integer(32, False), shape=(3, 5))],
                 lambda x: numpy.ravel(x),
                 EncryptedTensor(Integer(32, False), shape=(5, 3)),
                 op_kind="Memory",
@@ -205,7 +205,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
         ),
         pytest.param(
             ir.GenericFunction(
-                EncryptedTensor(Integer(32, False), shape=(3, 5)),
+                [EncryptedTensor(Integer(32, False), shape=(3, 5))],
                 lambda x: numpy.reshape(x, (5, 3)),
                 output_value=EncryptedTensor(Integer(32, False), shape=(5, 3)),
                 op_kind="Memory",
@@ -313,13 +313,13 @@ def test_evaluate(
         ),
         (
             ir.GenericFunction(
-                EncryptedScalar(Integer(8, False)),
+                [EncryptedScalar(Integer(8, False))],
                 lambda x: x,
                 EncryptedScalar(Integer(8, False)),
                 op_kind="TLU",
             ),
             ir.GenericFunction(
-                EncryptedScalar(Integer(8, False)),
+                [EncryptedScalar(Integer(8, False))],
                 lambda x: x,
                 EncryptedScalar(Integer(8, False)),
                 op_kind="TLU",
@@ -328,14 +328,14 @@ def test_evaluate(
         ),
         (
             ir.GenericFunction(
-                EncryptedScalar(Integer(8, False)),
+                [EncryptedScalar(Integer(8, False))],
                 lambda x: x,
                 EncryptedScalar(Integer(8, False)),
                 op_kind="TLU",
                 op_args=(1, 2, 3),
             ),
             ir.GenericFunction(
-                EncryptedScalar(Integer(8, False)),
+                [EncryptedScalar(Integer(8, False))],
                 lambda x: x,
                 EncryptedScalar(Integer(8, False)),
                 op_kind="TLU",
@@ -344,14 +344,14 @@ def test_evaluate(
         ),
         (
             ir.GenericFunction(
-                EncryptedScalar(Integer(8, False)),
+                [EncryptedScalar(Integer(8, False))],
                 lambda x: x,
                 EncryptedScalar(Integer(8, False)),
                 op_kind="TLU",
                 op_kwargs={"tuple": (1, 2, 3)},
             ),
             ir.GenericFunction(
-                EncryptedScalar(Integer(8, False)),
+                [EncryptedScalar(Integer(8, False))],
                 lambda x: x,
                 EncryptedScalar(Integer(8, False)),
                 op_kind="TLU",
@@ -631,74 +631,146 @@ def test_nptracer_unsupported_operands(operation, tracer):


 @pytest.mark.parametrize(
-    "function_to_trace,input_value,input_and_expected_output_tuples",
+    "function_to_trace,input_value_input_and_expected_output_tuples",
     [
         # Indirect calls, like numpy.function(x, ...)
         (
             lambda x: numpy.transpose(x),
-            EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
             [
-                (numpy.arange(4).reshape(2, 2), numpy.array([[0, 2], [1, 3]])),
-                (numpy.arange(4, 8).reshape(2, 2), numpy.array([[4, 6], [5, 7]])),
+                (
+                    EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
+                    numpy.arange(4).reshape(2, 2),
+                    numpy.array([[0, 2], [1, 3]]),
+                ),
+                (
+                    EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
+                    numpy.arange(4, 8).reshape(2, 2),
+                    numpy.array([[4, 6], [5, 7]]),
+                ),
+                (
+                    EncryptedTensor(Integer(6, is_signed=False), shape=()),
+                    numpy.int64(42),
+                    numpy.int64(42),
+                ),
             ],
         ),
         (
             lambda x: numpy.transpose(x) + 42,
-            EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
             [
-                (numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(3, 5).transpose()),
+                (
+                    EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
+                    numpy.arange(15).reshape(3, 5),
+                    numpy.arange(42, 57).reshape(3, 5).transpose(),
+                ),
+                (
+                    EncryptedTensor(Integer(6, is_signed=False), shape=()),
+                    numpy.int64(42),
+                    numpy.int64(84),
+                ),
             ],
         ),
         (
             lambda x: numpy.ravel(x),
-            EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
             [
-                (numpy.arange(4), numpy.array([0, 1, 2, 3])),
-                (numpy.arange(4).reshape(2, 2), numpy.array([0, 1, 2, 3])),
+                (
+                    EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
+                    numpy.arange(4),
+                    numpy.array([0, 1, 2, 3]),
+                ),
+                (
+                    EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
+                    numpy.arange(4).reshape(2, 2),
+                    numpy.array([0, 1, 2, 3]),
+                ),
+                (
+                    EncryptedTensor(Integer(6, is_signed=False), shape=()),
+                    numpy.int64(42),
+                    numpy.array([42], dtype=numpy.int64),
+                ),
             ],
         ),
         (
             lambda x: numpy.reshape(x, (5, 3)) + 42,
-            EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
             [
-                (numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(5, 3)),
+                (
+                    EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
+                    numpy.arange(15).reshape(3, 5),
+                    numpy.arange(42, 57).reshape(5, 3),
+                ),
             ],
         ),
         # Direct calls, like x.function(...)
         (
             lambda x: x.transpose() + 42,
-            EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
             [
-                (numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(3, 5).transpose()),
+                (
+                    EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
+                    numpy.arange(15).reshape(3, 5),
+                    numpy.arange(42, 57).reshape(3, 5).transpose(),
+                ),
+                (
+                    EncryptedTensor(Integer(6, is_signed=False), shape=()),
+                    numpy.int64(42),
+                    numpy.int64(84),
+                ),
             ],
         ),
         (
             lambda x: x.ravel(),
-            EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
             [
-                (numpy.arange(4), numpy.array([0, 1, 2, 3])),
-                (numpy.arange(4).reshape(2, 2), numpy.array([0, 1, 2, 3])),
+                (
+                    EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
+                    numpy.arange(4),
+                    numpy.array([0, 1, 2, 3]),
+                ),
+                (
+                    EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
+                    numpy.arange(4).reshape(2, 2),
+                    numpy.array([0, 1, 2, 3]),
+                ),
+                (
+                    EncryptedTensor(Integer(6, is_signed=False), shape=()),
+                    numpy.int64(42),
+                    numpy.array([42], dtype=numpy.int64),
+                ),
             ],
         ),
         (
             lambda x: x.reshape((5, 3)) + 42,
-            EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
             [
-                (numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(5, 3)),
+                (
+                    EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
+                    numpy.arange(15).reshape(3, 5),
+                    numpy.arange(42, 57).reshape(5, 3),
+                ),
             ],
         ),
+        pytest.param(
+            lambda x: x.reshape((5, 3)),
+            [
+                (
+                    EncryptedTensor(Integer(6, is_signed=False), shape=()),
+                    numpy.int64(42),
+                    None,
+                )
+            ],
+            marks=pytest.mark.xfail(strict=True, raises=AssertionError),
+        ),
     ],
 )
-def test_tracing_generic_function(function_to_trace, input_value, input_and_expected_output_tuples):
-    """Test function for managed by GenericFunction node"""
-    for input_, expected_output in input_and_expected_output_tuples:
+def test_tracing_generic_function_memory_ops(
+    function_to_trace,
+    input_value_input_and_expected_output_tuples,
+):
+    """Test memory function managed by GenericFunction node"""
+    for input_value, input_, expected_output in input_value_input_and_expected_output_tuples:

         op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
         output_node = op_graph.output_nodes[0]

         node_results = op_graph.evaluate({0: input_})
         evaluated_output = node_results[output_node]
-        assert isinstance(evaluated_output, type(expected_output))
+        assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
         assert numpy.array_equal(expected_output, evaluated_output)
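Note: a condensed sketch of what the rewritten test drives per 3-tuple (names as in the
test above; each case traces with its own input_value, then evaluates on input_):

    op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
    result = op_graph.evaluate({0: input_})[op_graph.output_nodes[0]]
    assert numpy.array_equal(result, expected_output)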