refactor: update GenericFunction to take an iterable as inputs

- also fix some corner cases in memory operations
- some small style changes

refs #600
This commit is contained in:
Arthur Meyre
2021-11-03 12:15:41 +01:00
parent d2faa90106
commit bff367137e
9 changed files with 227 additions and 118 deletions

View File

@@ -8,7 +8,7 @@ from ..data_types.base import BaseDataType
from ..data_types.dtypes_helpers import find_type_to_hold_both_lossy
from ..representation.intermediate import GenericFunction
from ..tracing.base_tracer import BaseTracer
from ..values import ClearTensor, EncryptedTensor
from ..values import TensorValue
from .table import LookupTable
@@ -97,14 +97,14 @@ class MultiLookupTable:
out_dtype = deepcopy(key.output.dtype)
out_shape = deepcopy(self.input_shape)
generic_function_output_value = (
EncryptedTensor(out_dtype, out_shape)
if key.output.is_encrypted
else ClearTensor(out_dtype, out_shape)
generic_function_output_value = TensorValue(
out_dtype,
key.output.is_encrypted,
out_shape,
)
traced_computation = GenericFunction(
input_base_value=key.output,
inputs=[deepcopy(key.output)],
arbitrary_func=MultiLookupTable._checked_indexing,
output_value=generic_function_output_value,
op_kind="TLU",

View File

@@ -39,7 +39,7 @@ class LookupTable:
generic_function_output_value.dtype = self.output_dtype
traced_computation = GenericFunction(
input_base_value=key.output,
inputs=[deepcopy(key.output)],
arbitrary_func=LookupTable._checked_indexing,
output_value=generic_function_output_value,
op_kind="TLU",

View File

@@ -194,11 +194,11 @@ def convert_float_subgraph_to_fused_node(
# Create fused_node
fused_node = GenericFunction(
deepcopy(new_subgraph_variable_input.inputs[0]),
lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate({0: x})[
terminal_node
],
terminal_node.outputs[0],
inputs=[deepcopy(new_subgraph_variable_input.inputs[0])],
arbitrary_func=lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate(
{0: x}
)[terminal_node],
output_value=terminal_node.outputs[0],
op_kind="TLU",
op_kwargs={
"float_op_subgraph": float_op_subgraph,

View File

@@ -290,7 +290,7 @@ class GenericFunctionKind(str, Enum):
class GenericFunction(IntermediateNode):
"""Node representing an univariate arbitrary function, e.g. sin(x)."""
"""Node representing an arbitrary function with a single output, e.g. sin(x)."""
# The arbitrary_func is not optional but mypy has a long standing bug and is not able to
# understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
@@ -303,7 +303,7 @@ class GenericFunction(IntermediateNode):
op_args: Tuple[Any, ...]
op_kwargs: Dict[str, Any]
op_attributes: Dict[str, Any]
_n_in: int = 1
_n_in: int
# TODO: https://github.com/zama-ai/concretefhe-internal/issues/798 have a proper attribute
# system
@@ -311,7 +311,7 @@ class GenericFunction(IntermediateNode):
def __init__(
self,
input_base_value: BaseValue,
inputs: Iterable[BaseValue],
arbitrary_func: Callable,
output_value: BaseValue,
op_kind: Union[str, GenericFunctionKind],
@@ -320,8 +320,9 @@ class GenericFunction(IntermediateNode):
op_kwargs: Optional[Dict[str, Any]] = None,
op_attributes: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__([input_base_value])
assert_true(len(self.inputs) == 1)
super().__init__(inputs)
self._n_in = len(self.inputs)
assert_true(self._n_in == 1) # TODO: remove in later parts of refactoring of #600
self.arbitrary_func = arbitrary_func
self.op_kind = GenericFunctionKind(op_kind)
self.op_args = op_args if op_args is not None else ()
@@ -330,14 +331,15 @@ class GenericFunction(IntermediateNode):
if op_attributes is not None:
self.op_attributes.update(op_attributes)
self.outputs = [deepcopy(output_value)]
self.outputs = [output_value]
self.op_name = op_name if op_name is not None else self.__class__.__name__
def evaluate(self, inputs: Dict[int, Any]) -> Any:
# This is the continuation of the mypy bug workaround
assert self.arbitrary_func is not None
return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)
ordered_inputs = [inputs[idx] for idx in range(len(inputs))]
return self.arbitrary_func(*ordered_inputs, *self.op_args, **self.op_kwargs)
def label(self) -> str:
return self.op_name
@@ -350,12 +352,14 @@ class GenericFunction(IntermediateNode):
Returns:
List[Any]: The table.
"""
input_dtype = self.inputs[0].dtype
# Check the input is an unsigned integer to be able to build a table
assert isinstance(
input_dtype, Integer
), "get_table only works for an unsigned Integer input"
assert not input_dtype.is_signed, "get_table only works for an unsigned Integer input"
assert_true(
isinstance(input_dtype, Integer), "get_table only works for an unsigned Integer input"
)
input_dtype = cast(Integer, input_dtype)
assert_true(not input_dtype.is_signed, "get_table only works for an unsigned Integer input")
input_value_constructor = self.inputs[0].underlying_constructor
if input_value_constructor is None:

View File

@@ -7,11 +7,11 @@ import numpy
from numpy.typing import DTypeLike
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
from ..common.debugging.custom_assert import assert_false, assert_true
from ..common.debugging.custom_assert import assert_true
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import Constant, Dot, GenericFunction, MatMul
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
from ..common.values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
from ..common.values import BaseValue, TensorValue
from .np_dtypes_helpers import (
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
convert_numpy_dtype_to_base_data_type,
@@ -99,7 +99,7 @@ class NPTracer(BaseTracer):
generic_function_output_value = deepcopy(self.output)
generic_function_output_value.dtype = output_dtype
traced_computation = GenericFunction(
input_base_value=self.output,
inputs=[deepcopy(self.output)],
arbitrary_func=lambda x, dtype: x.astype(dtype),
output_value=generic_function_output_value,
op_kind="TLU",
@@ -171,7 +171,7 @@ class NPTracer(BaseTracer):
generic_function_output_value.dtype = common_output_dtypes[0]
traced_computation = GenericFunction(
input_base_value=input_tracers[0].output,
inputs=[deepcopy(input_tracers[0].output)],
arbitrary_func=unary_operator,
output_value=generic_function_output_value,
op_kind="TLU",
@@ -241,8 +241,9 @@ class NPTracer(BaseTracer):
generic_function_output_value = deepcopy(input_tracers[in_which_input_is_variable].output)
generic_function_output_value.dtype = common_output_dtypes[0]
# TODO: update inputs for #600 refactor
traced_computation = GenericFunction(
input_base_value=input_tracers[in_which_input_is_variable].output,
inputs=[deepcopy(input_tracers[in_which_input_is_variable].output)],
arbitrary_func=arbitrary_func,
output_value=generic_function_output_value,
op_kind="TLU",
@@ -312,25 +313,26 @@ class NPTracer(BaseTracer):
first_arg_output = args[0].output
assert_true(isinstance(first_arg_output, TensorValue))
first_arg_output = cast(TensorValue, first_arg_output)
assert_false(first_arg_output.is_scalar)
transpose_is_fusable = first_arg_output.is_scalar or first_arg_output.ndim == 1
out_dtype = first_arg_output.dtype
out_shape = first_arg_output.shape[::-1]
generic_function_output_value = (
EncryptedTensor(out_dtype, out_shape)
if first_arg_output.is_encrypted
else ClearTensor(out_dtype, out_shape)
generic_function_output_value = TensorValue(
out_dtype,
first_arg_output.is_encrypted,
out_shape,
)
traced_computation = GenericFunction(
input_base_value=first_arg_output,
inputs=[deepcopy(first_arg_output)],
arbitrary_func=numpy.transpose,
output_value=generic_function_output_value,
op_kind="Memory",
op_kwargs=deepcopy(kwargs),
op_name="np.transpose",
op_attributes={"fusable": False},
op_attributes={"fusable": transpose_is_fusable},
)
output_tracer = self.__class__(
args,
@@ -358,25 +360,26 @@ class NPTracer(BaseTracer):
first_arg_output = args[0].output
assert_true(isinstance(first_arg_output, TensorValue))
first_arg_output = cast(TensorValue, first_arg_output)
assert_false(first_arg_output.is_scalar)
ravel_is_fusable = first_arg_output.ndim == 1
out_dtype = first_arg_output.dtype
out_shape = (numpy.product(first_arg_output.shape),)
out_shape = (1,) if first_arg_output.is_scalar else (numpy.product(first_arg_output.shape),)
generic_function_output_value = (
EncryptedTensor(out_dtype, out_shape)
if first_arg_output.is_encrypted
else ClearTensor(out_dtype, out_shape)
generic_function_output_value = TensorValue(
out_dtype,
first_arg_output.is_encrypted,
out_shape,
)
traced_computation = GenericFunction(
input_base_value=first_arg_output,
inputs=[deepcopy(first_arg_output)],
arbitrary_func=numpy.ravel,
output_value=generic_function_output_value,
op_kind="Memory",
op_kwargs=deepcopy(kwargs),
op_name="np.ravel",
op_attributes={"fusable": False},
op_attributes={"fusable": ravel_is_fusable},
)
output_tracer = self.__class__(
args,
@@ -422,27 +425,29 @@ class NPTracer(BaseTracer):
# Check shape compatibility
assert_true(
numpy.product(newshape) == numpy.product(first_arg_output.shape),
numpy.product(newshape) == first_arg_output.size,
f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {newshape})",
)
reshape_is_fusable = newshape == first_arg_output.shape
out_dtype = first_arg_output.dtype
out_shape = newshape
generic_function_output_value = (
EncryptedTensor(out_dtype, out_shape)
if first_arg_output.is_encrypted
else ClearTensor(out_dtype, out_shape)
generic_function_output_value = TensorValue(
out_dtype,
first_arg_output.is_encrypted,
out_shape,
)
traced_computation = GenericFunction(
input_base_value=first_arg_output,
inputs=[first_arg_output],
arbitrary_func=numpy.reshape,
output_value=generic_function_output_value,
op_kind="Memory",
op_kwargs={"newshape": newshape},
op_name="np.reshape",
op_attributes={"fusable": False},
op_attributes={"fusable": reshape_is_fusable},
)
output_tracer = self.__class__(
[arg0],

View File

@@ -60,7 +60,7 @@ def test_lookup_table_encrypted_lookup(test_helpers):
# pylint: disable=protected-access
# Need access to _checked_indexing to have is_equivalent_to work for ir.GenericFunction
output_arbitrary_function = ir.GenericFunction(
input_base_value=x,
inputs=[x],
arbitrary_func=LookupTable._checked_indexing,
output_value=generic_function_output_value,
op_kind="TLU",
@@ -104,7 +104,7 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
# pylint: disable=protected-access
# Need access to _checked_indexing to have is_equivalent_to work for ir.GenericFunction
intermediate_arbitrary_function = ir.GenericFunction(
input_base_value=x,
inputs=[x],
arbitrary_func=LookupTable._checked_indexing,
output_value=generic_function_output_value,
op_kind="TLU",

View File

@@ -32,24 +32,24 @@ def no_fuse_dot(x):
return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32)
def no_fuse_explicitely(f, x):
def simple_create_fuse_opportunity(f, x):
"""No fuse because the function is explicitely marked as unfusable in our code."""
return f(x.astype(numpy.float64)).astype(numpy.int32)
def no_fuse_explicitely_ravel(x):
"""No fuse ravel"""
return no_fuse_explicitely(numpy.ravel, x)
def ravel_cases(x):
"""Simple ravel cases"""
return simple_create_fuse_opportunity(numpy.ravel, x)
def no_fuse_explicitely_transpose(x):
"""No fuse transpose"""
return no_fuse_explicitely(numpy.transpose, x)
def transpose_cases(x):
"""Simple transpose cases"""
return simple_create_fuse_opportunity(numpy.transpose, x)
def no_fuse_explicitely_reshape(x):
"""No fuse reshape"""
return no_fuse_explicitely(lambda x: numpy.reshape(x, (1,)), x)
def reshape_cases(x, newshape):
"""Simple reshape cases"""
return simple_create_fuse_opportunity(lambda x: numpy.reshape(x, newshape), x)
def simple_fuse_not_output(x):
@@ -182,41 +182,41 @@ return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_dot",
),
pytest.param(
no_fuse_explicitely_ravel,
ravel_cases,
False,
get_func_params_int32(no_fuse_explicitely_ravel, scalar=False),
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
%2 = np.ravel(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
%2 = np.ravel(%1) # EncryptedTensor<Float<64 bits>, shape=(200,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(200,)>
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_explicitely_ravel",
),
pytest.param(
no_fuse_explicitely_transpose,
transpose_cases,
False,
get_func_params_int32(no_fuse_explicitely_transpose, scalar=False),
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
%2 = np.transpose(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
%2 = np.transpose(%1) # EncryptedTensor<Float<64 bits>, shape=(20, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(20, 10)>
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_explicitely_transpose",
),
pytest.param(
no_fuse_explicitely_reshape,
lambda x: reshape_cases(x, (20, 10)),
False,
get_func_params_int32(no_fuse_explicitely_reshape, scalar=False),
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
%2 = np.reshape(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
%2 = np.reshape(%1) # EncryptedTensor<Float<64 bits>, shape=(20, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(20, 10)>
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_explicitely_reshape",
),
@@ -248,6 +248,34 @@ return(%3)""", # noqa: E501 # pylint: disable=line-too-long
None,
id="mix_x_and_y_and_call_f_with_rint",
),
pytest.param(
transpose_cases,
True,
get_func_params_int32(transpose_cases),
None,
id="transpose_cases scalar",
),
pytest.param(
transpose_cases,
True,
{"x": EncryptedTensor(Integer(32, True), (10,))},
None,
id="transpose_cases ndim == 1",
),
pytest.param(
ravel_cases,
True,
{"x": EncryptedTensor(Integer(32, True), (10,))},
None,
id="ravel_cases ndim == 1",
),
pytest.param(
lambda x: reshape_cases(x, (10, 20)),
True,
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
None,
id="reshape_cases same shape",
),
],
)
def test_fuse_float_operations(

View File

@@ -36,7 +36,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
pytest.param(ir.Constant(-42), None, -42, id="Constant"),
pytest.param(
ir.GenericFunction(
EncryptedScalar(Integer(7, False)),
[EncryptedScalar(Integer(7, False))],
lambda x: x + 3,
EncryptedScalar(Integer(7, False)),
op_kind="TLU",
@@ -47,7 +47,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
),
pytest.param(
ir.GenericFunction(
EncryptedScalar(Integer(7, False)),
[EncryptedScalar(Integer(7, False))],
lambda x, y: x + y,
EncryptedScalar(Integer(7, False)),
op_kind="TLU",
@@ -59,7 +59,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
),
pytest.param(
ir.GenericFunction(
EncryptedScalar(Integer(7, False)),
[EncryptedScalar(Integer(7, False))],
lambda x, y: y[x],
EncryptedScalar(Integer(7, False)),
op_kind="TLU",
@@ -71,7 +71,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
),
pytest.param(
ir.GenericFunction(
EncryptedScalar(Integer(7, False)),
[EncryptedScalar(Integer(7, False))],
lambda x, y: y[3],
EncryptedScalar(Integer(7, False)),
op_kind="TLU",
@@ -183,7 +183,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
),
pytest.param(
ir.GenericFunction(
EncryptedTensor(Integer(32, False), shape=(3, 5)),
[EncryptedTensor(Integer(32, False), shape=(3, 5))],
lambda x: numpy.transpose(x),
EncryptedTensor(Integer(32, False), shape=(5, 3)),
op_kind="Memory",
@@ -194,7 +194,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
),
pytest.param(
ir.GenericFunction(
EncryptedTensor(Integer(32, False), shape=(3, 5)),
[EncryptedTensor(Integer(32, False), shape=(3, 5))],
lambda x: numpy.ravel(x),
EncryptedTensor(Integer(32, False), shape=(5, 3)),
op_kind="Memory",
@@ -205,7 +205,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
),
pytest.param(
ir.GenericFunction(
EncryptedTensor(Integer(32, False), shape=(3, 5)),
[EncryptedTensor(Integer(32, False), shape=(3, 5))],
lambda x: numpy.reshape(x, (5, 3)),
output_value=EncryptedTensor(Integer(32, False), shape=(5, 3)),
op_kind="Memory",
@@ -313,13 +313,13 @@ def test_evaluate(
),
(
ir.GenericFunction(
EncryptedScalar(Integer(8, False)),
[EncryptedScalar(Integer(8, False))],
lambda x: x,
EncryptedScalar(Integer(8, False)),
op_kind="TLU",
),
ir.GenericFunction(
EncryptedScalar(Integer(8, False)),
[EncryptedScalar(Integer(8, False))],
lambda x: x,
EncryptedScalar(Integer(8, False)),
op_kind="TLU",
@@ -328,14 +328,14 @@ def test_evaluate(
),
(
ir.GenericFunction(
EncryptedScalar(Integer(8, False)),
[EncryptedScalar(Integer(8, False))],
lambda x: x,
EncryptedScalar(Integer(8, False)),
op_kind="TLU",
op_args=(1, 2, 3),
),
ir.GenericFunction(
EncryptedScalar(Integer(8, False)),
[EncryptedScalar(Integer(8, False))],
lambda x: x,
EncryptedScalar(Integer(8, False)),
op_kind="TLU",
@@ -344,14 +344,14 @@ def test_evaluate(
),
(
ir.GenericFunction(
EncryptedScalar(Integer(8, False)),
[EncryptedScalar(Integer(8, False))],
lambda x: x,
EncryptedScalar(Integer(8, False)),
op_kind="TLU",
op_kwargs={"tuple": (1, 2, 3)},
),
ir.GenericFunction(
EncryptedScalar(Integer(8, False)),
[EncryptedScalar(Integer(8, False))],
lambda x: x,
EncryptedScalar(Integer(8, False)),
op_kind="TLU",

View File

@@ -631,74 +631,146 @@ def test_nptracer_unsupported_operands(operation, tracer):
@pytest.mark.parametrize(
"function_to_trace,input_value,input_and_expected_output_tuples",
"function_to_trace,input_value_input_and_expected_output_tuples",
[
# Indirect calls, like numpy.function(x, ...)
(
lambda x: numpy.transpose(x),
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
[
(numpy.arange(4).reshape(2, 2), numpy.array([[0, 2], [1, 3]])),
(numpy.arange(4, 8).reshape(2, 2), numpy.array([[4, 6], [5, 7]])),
(
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
numpy.arange(4).reshape(2, 2),
numpy.array([[0, 2], [1, 3]]),
),
(
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
numpy.arange(4, 8).reshape(2, 2),
numpy.array([[4, 6], [5, 7]]),
),
(
EncryptedTensor(Integer(6, is_signed=False), shape=()),
numpy.int64(42),
numpy.int64(42),
),
],
),
(
lambda x: numpy.transpose(x) + 42,
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
[
(numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(3, 5).transpose()),
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(42, 57).reshape(3, 5).transpose(),
),
(
EncryptedTensor(Integer(6, is_signed=False), shape=()),
numpy.int64(42),
numpy.int64(84),
),
],
),
(
lambda x: numpy.ravel(x),
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
[
(numpy.arange(4), numpy.array([0, 1, 2, 3])),
(numpy.arange(4).reshape(2, 2), numpy.array([0, 1, 2, 3])),
(
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
numpy.arange(4),
numpy.array([0, 1, 2, 3]),
),
(
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
numpy.arange(4).reshape(2, 2),
numpy.array([0, 1, 2, 3]),
),
(
EncryptedTensor(Integer(6, is_signed=False), shape=()),
numpy.int64(42),
numpy.array([42], dtype=numpy.int64),
),
],
),
(
lambda x: numpy.reshape(x, (5, 3)) + 42,
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
[
(numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(5, 3)),
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(42, 57).reshape(5, 3),
),
],
),
# Direct calls, like x.function(...)
(
lambda x: x.transpose() + 42,
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
[
(numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(3, 5).transpose()),
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(42, 57).reshape(3, 5).transpose(),
),
(
EncryptedTensor(Integer(6, is_signed=False), shape=()),
numpy.int64(42),
numpy.int64(84),
),
],
),
(
lambda x: x.ravel(),
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
[
(numpy.arange(4), numpy.array([0, 1, 2, 3])),
(numpy.arange(4).reshape(2, 2), numpy.array([0, 1, 2, 3])),
(
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
numpy.arange(4),
numpy.array([0, 1, 2, 3]),
),
(
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
numpy.arange(4).reshape(2, 2),
numpy.array([0, 1, 2, 3]),
),
(
EncryptedTensor(Integer(6, is_signed=False), shape=()),
numpy.int64(42),
numpy.array([42], dtype=numpy.int64),
),
],
),
(
lambda x: x.reshape((5, 3)) + 42,
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
[
(numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(5, 3)),
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(42, 57).reshape(5, 3),
),
],
),
pytest.param(
lambda x: x.reshape((5, 3)),
[
(
EncryptedTensor(Integer(6, is_signed=False), shape=()),
numpy.int64(42),
None,
)
],
marks=pytest.mark.xfail(strict=True, raises=AssertionError),
),
],
)
def test_tracing_generic_function(function_to_trace, input_value, input_and_expected_output_tuples):
"""Test function for managed by GenericFunction node"""
for input_, expected_output in input_and_expected_output_tuples:
def test_tracing_generic_function_memory_ops(
function_to_trace,
input_value_input_and_expected_output_tuples,
):
"""Test memory function managed by GenericFunction node"""
for input_value, input_, expected_output in input_value_input_and_expected_output_tuples:
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
output_node = op_graph.output_nodes[0]
node_results = op_graph.evaluate({0: input_})
evaluated_output = node_results[output_node]
assert isinstance(evaluated_output, type(expected_output))
assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
assert numpy.array_equal(expected_output, evaluated_output)