diff --git a/concrete/common/debugging/drawing.py b/concrete/common/debugging/drawing.py index f9f1e02f3..c348e5c5d 100644 --- a/concrete/common/debugging/drawing.py +++ b/concrete/common/debugging/drawing.py @@ -16,6 +16,7 @@ from ..representation.intermediate import ( Add, Constant, Dot, + IndexConstant, Input, Mul, Sub, @@ -29,6 +30,7 @@ IR_NODE_COLOR_MAPPING = { Sub: "yellow", Mul: "green", UnivariateFunction: "orange", + IndexConstant: "black", Dot: "purple", "UnivariateFunction": "orange", "TLU": "grey", diff --git a/concrete/common/debugging/printing.py b/concrete/common/debugging/printing.py index 0077ce794..cfca8696b 100644 --- a/concrete/common/debugging/printing.py +++ b/concrete/common/debugging/printing.py @@ -6,7 +6,13 @@ import networkx as nx from ..debugging.custom_assert import assert_true from ..operator_graph import OPGraph -from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction +from ..representation.intermediate import ( + Constant, + IndexConstant, + Input, + IntermediateNode, + UnivariateFunction, +) def output_data_type_to_string(node): @@ -124,6 +130,9 @@ def get_printable_graph( what_to_print += prefix_to_add_to_what_to_print what_to_print += ", ".join(["%" + x[1] for x in list_of_arg_name]) what_to_print += suffix_to_add_to_what_to_print + what_to_print += ( + f"{node.label().replace('value', '')}" if isinstance(node, IndexConstant) else "" + ) what_to_print += ")" # This code doesn't work with more than a single output diff --git a/concrete/common/helpers/__init__.py b/concrete/common/helpers/__init__.py new file mode 100644 index 000000000..908c72ca0 --- /dev/null +++ b/concrete/common/helpers/__init__.py @@ -0,0 +1,3 @@ +"""Helpers for all kinds of tasks.""" + +from . import indexing_helpers diff --git a/concrete/common/helpers/indexing_helpers.py b/concrete/common/helpers/indexing_helpers.py new file mode 100644 index 000000000..80b77fede --- /dev/null +++ b/concrete/common/helpers/indexing_helpers.py @@ -0,0 +1,277 @@ +"""Helpers for indexing functionality.""" + +from typing import Tuple, Union + + +def format_indexing_element(indexing_element: Union[int, slice]) -> str: + """Format an indexing element. + + This is required mainly for slices. The reason is that string representation of slices + are very long and verbose. To give an example, `x[:, 2:]` will have the following index + `[slice(None, None, None), slice(2, None, None)]` if printed naively. With this helper, + it will be formatted as `[:, 2:]`. + + Args: + indexing_element (Union[int, slice]): indexing element to be formatted + + Returns: + str: formatted element + """ + + result = "" + if isinstance(indexing_element, slice): + if indexing_element.start is not None: + result += str(indexing_element.start) + result += ":" + if indexing_element.stop is not None: + result += str(indexing_element.stop) + if indexing_element.step is not None: + result += ":" + result += str(indexing_element.step) + else: + result += str(indexing_element) + return result.replace("\n", " ") + + +def validate_index( + index: Union[int, slice, Tuple[Union[int, slice], ...]], +) -> Tuple[Union[int, slice], ...]: + """Make sure index is valid and convert it to the tuple form. + + For example in `x[2]`, `index` is passed as `2`. + To make it easier to work with, this function converts index to `(2,)`. + + Args: + index (Union[int, slice, Tuple[Union[int, slice], ...]]): index to validate, improve + and return + + Returns: + Tuple[Union[int, slice], ...]: validated and improved index + """ + + if not isinstance(index, tuple): + index = (index,) + + for indexing_element in index: + valid = isinstance(indexing_element, (int, slice)) + + if isinstance(indexing_element, slice): + if ( + not (indexing_element.start is None or isinstance(indexing_element.start, int)) + or not (indexing_element.stop is None or isinstance(indexing_element.stop, int)) + or not (indexing_element.step is None or isinstance(indexing_element.step, int)) + ): + valid = False + + if not valid: + raise TypeError( + f"Only integers and integer slices can be used for indexing " + f"but you tried to use {format_indexing_element(indexing_element)} for indexing" + ) + + return index + + +def determine_output_shape( + input_shape: Tuple[int, ...], + index: Tuple[Union[int, slice], ...], +) -> Tuple[int, ...]: + """Determine the output shape from the input shape and the index. + + e.g., for `input_shape=(3, 2)` and `index=(:, 0)`, returns `(3,)` + for `input_shape=(4, 3, 2)` and `index=(2:,)`, returns `(2, 3, 2)` + + Args: + input_shape (Tuple[int, ...]): shape of the input tensor that is indexed + index (Tuple[Union[int, slice], ...]): desired and validated index + + Returns: + Tuple[int, ...]: shape of the result of indexing + """ + + indexing_elements = [format_indexing_element(indexing_element) for indexing_element in index] + index_str = f"[{', '.join(indexing_elements)}]" + + if len(index) > len(input_shape): + raise ValueError( + f"Tensor of shape {input_shape} cannot be indexed with {index_str} " + f"as the index has more elements than the number of dimensions of the tensor" + ) + + # indexing (3, 4, 5) with [1] is the same as indexing it with [1, :, :] + # indexing (3, 4, 5) with [1, 2] is the same as indexing it with [1, 2, :] + + # so let's replicate that behavior to make the rest of the code generic + index += (slice(None, None, None),) * (len(input_shape) - len(index)) + + output_shape = [] + for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)): + if isinstance(indexing_element, int): # indexing removes the dimension + indexing_element = ( + indexing_element if indexing_element >= 0 else indexing_element + dimension_size + ) + if not 0 <= indexing_element < dimension_size: + raise ValueError( + f"Tensor of shape {input_shape} cannot be indexed with {index_str} " + f"because index is out of range for dimension {dimension}" + ) + elif isinstance(indexing_element, slice): # indexing possibly shrinks the dimension + output_shape.append( + determine_new_dimension_size( + indexing_element, + dimension_size, + dimension, + input_shape, + index_str, + ) + ) + + return tuple(output_shape) + + +def sanitize_start_index( + start: int, + dimension_size: int, + # the rest is used for detailed exception message + dimension: int, + input_shape: Tuple[int, ...], + index_str: str, +) -> int: + """Sanitize and check start index of a slice. + + Args: + start (int): start index being sanitized + dimension_size (int): size of the dimension the slice is applied to + dimension (int): index of the dimension being sliced (for better messages) + input_shape (Tuple[int, ...]): shape of the whole input (for better messages) + index_str (str): string representation of the whole index (for better messages) + + Returns: + int: sanitized start index + """ + + start = start if start >= 0 else start + dimension_size + if not 0 <= start < dimension_size: + raise ValueError( + f"Tensor of shape {input_shape} cannot be indexed with {index_str} " + f"because start index is out of range for dimension {dimension}" + ) + return start + + +def sanitize_stop_index( + stop: int, + dimension_size: int, + # the rest is used for detailed exception message + dimension: int, + input_shape: Tuple[int, ...], + index_str: str, +) -> int: + """Sanitize and check stop index of a slice. + + Args: + stop (int): stop index being sanitized + dimension_size (int): size of the dimension the slice is applied to + dimension (int): index of the dimension being sliced (for better messages) + input_shape (Tuple[int, ...]): shape of the whole input (for better messages) + index_str (str): string representation of the whole index (for better messages) + + Returns: + int: sanitized stop index + """ + + stop = stop if stop >= 0 else stop + dimension_size + if not 0 <= stop <= dimension_size: + raise ValueError( + f"Tensor of shape {input_shape} cannot be indexed with {index_str} " + f"because stop index is out of range for dimension {dimension}" + ) + return stop + + +def determine_new_dimension_size( + slice_: slice, + dimension_size: int, + # the rest is used for detailed exception message + dimension: int, + input_shape: Tuple[int, ...], + index_str: str, +) -> int: + """Determine the new size of a dimension from the old size and the slice applied to it. + + e.g., for `slice_=1:4` and `dimension_size=5`, returns `3` + for `slice_=::-1` and `dimension_size=5`, returns `5` + + You may want to check this page to learn more about how this function works + https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing + + Args: + slice_ (slice): slice being applied to the dimension + dimension_size (int): size of the dimension the slice is applied to + dimension (int): index of the dimension being sliced (for better messages) + input_shape (Tuple[int, ...]): shape of the whole input (for better messages) + index_str (str): string representation of the whole index (for better messages) + + Returns: + int: new size of the dimension + """ + + step = slice_.step if slice_.step is not None else 1 + + if step > 0: + start = slice_.start if slice_.start is not None else 0 + stop = slice_.stop if slice_.stop is not None else dimension_size + + start = sanitize_start_index(start, dimension_size, dimension, input_shape, index_str) + stop = sanitize_stop_index(stop, dimension_size, dimension, input_shape, index_str) + + if start >= stop: + raise ValueError( + f"Tensor of shape {input_shape} cannot be indexed with {index_str} " + f"because start index is not less than stop index for dimension {dimension}" + ) + + size_before_stepping = stop - start + elif step < 0: + start = slice_.start if slice_.start is not None else dimension_size - 1 + stop = slice_.stop + + start = sanitize_start_index(start, dimension_size, dimension, input_shape, index_str) + + if stop is None: + # this is a weird case but it works as expected + # the issue is that it's impossible to slice whole vector reversed + # with a stop value different than none + + # if `x.shape == (6,)` the only one that works is `x[::-1].shape == (6,)` + # here is what doesn't work (and this is expected it's just weird) + # + # ... + # `x[:-2:-1].shape == (1,)` + # `x[:-1:-1].shape == (0,)` (note that this is a hard error for us) + # `x[:0:-1].shape == (5,)` + # `x[:1:-1].shape == (4,)` + # ... + + size_before_stepping = start + 1 + else: + stop = sanitize_stop_index(stop, dimension_size, dimension, input_shape, index_str) + + if stop >= start: + raise ValueError( + f"Tensor of shape {input_shape} cannot be indexed with {index_str} " + f"because step is negative and " + f"stop index is not less than start index for dimension {dimension}" + ) + + size_before_stepping = start - stop + else: + raise ValueError( + f"Tensor of shape {input_shape} cannot be indexed with {index_str} " + f"because step is zero for dimension {dimension}" + ) + + quotient = size_before_stepping // abs(step) + remainder = size_before_stepping % abs(step) + + return quotient + (remainder != 0) diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index d8694709f..4dcc60d86 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -71,6 +71,11 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool) if not value_is_unsigned_integer(inputs[0]) or not value_is_unsigned_integer(inputs[1]): return "only unsigned integer dot product is supported" + elif isinstance(node, intermediate.IndexConstant): # constraints for constant indexing + assert_true(len(outputs) == 1) + if not value_is_unsigned_integer(outputs[0]): + return "only unsigned integer tensor constant indexing is supported" + else: # pragma: no cover assert_not_reached("Non IntermediateNode object in the OPGraph") diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index c96029b64..884801b47 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections import deque from copy import deepcopy -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union from loguru import logger @@ -14,7 +14,15 @@ from ..data_types.dtypes_helpers import ( ) from ..data_types.integers import Integer from ..debugging.custom_assert import assert_true -from ..values import BaseValue, ClearScalar, EncryptedScalar, TensorValue +from ..helpers import indexing_helpers +from ..values import ( + BaseValue, + ClearScalar, + ClearTensor, + EncryptedScalar, + EncryptedTensor, + TensorValue, +) IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func" @@ -197,6 +205,56 @@ class Constant(IntermediateNode): return str(self.constant_data) +class IndexConstant(IntermediateNode): + """Node representing a constant indexing in the program. + + What we mean by constant indexing is that the index part of the operation is a constant. + Here are some examples: `x[2]`, `x[0, 1]`, `y[:, 0]`, `y[3:, :5]` + + The opposite is to have dynamic indexing, which this node does not support. + Some examples of dynamic indexing are: `x[y]`, `x[y, z]`, `x[:, y]` + """ + + _n_in: int = 1 + + index: Tuple[Union[int, slice], ...] + + def __init__( + self, + input_: BaseValue, + index: Union[int, slice, Tuple[Union[int, slice], ...]], + ) -> None: + super().__init__((input_,)) + + if not isinstance(self.inputs[0], TensorValue) or self.inputs[0].is_scalar: + raise TypeError(f"Only tensors can be indexed but you tried to index {self.inputs[0]}") + + self.index = indexing_helpers.validate_index(index) + + output_dtype = self.inputs[0].dtype + output_shape = indexing_helpers.determine_output_shape(self.inputs[0].shape, self.index) + + self.outputs = [ + EncryptedTensor(output_dtype, output_shape) + if self.inputs[0].is_encrypted + else ClearTensor(output_dtype, output_shape) + ] + + def evaluate(self, inputs: Dict[int, Any]) -> Any: + return inputs[0][self.index] + + def label(self) -> str: + """Label of the node to show during drawings. + + It can be used for some other places after `"value"` below is replaced by `""`. + This note will no longer be necessary after #707 is addressed. + + """ + elements = [indexing_helpers.format_indexing_element(element) for element in self.index] + index = ", ".join(elements) + return f"value[{index}]" + + def flood_replace_none_values(table: list): """Use a flooding algorithm to replace None values. diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index 4a6f450e4..68bacbabe 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -7,6 +7,7 @@ from ..debugging.custom_assert import assert_true from ..representation.intermediate import ( IR_MIX_VALUES_FUNC_ARG_NAME, Add, + IndexConstant, IntermediateNode, Mul, Sub, @@ -161,3 +162,7 @@ class BaseTracer(ABC): # the order, we need to do as in __rmul__, ie mostly a copy of __mul__ + # some changes __rmul__ = __mul__ + + def __getitem__(self, item): + traced_computation = IndexConstant(self.output, item) + return self.__class__([self], traced_computation, 0) diff --git a/concrete/numpy/np_indexing_helpers.py b/concrete/numpy/np_indexing_helpers.py new file mode 100644 index 000000000..945f66412 --- /dev/null +++ b/concrete/numpy/np_indexing_helpers.py @@ -0,0 +1,59 @@ +"""Helpers for indexing with numpy values functionality.""" + +from typing import Any + +import numpy + + +def should_sanitize(indexing_element: Any) -> bool: + """Decide whether to sanitize an indexing element or not. + + Sanitizing in this context means converting supported numpy values into python values. + + Args: + indexing_element (Any): the indexing element to decide sanitization. + + Returns: + bool: True if indexing element should be sanitized otherwise False. + """ + + return isinstance(indexing_element, numpy.integer) or ( + isinstance(indexing_element, numpy.ndarray) + and issubclass(indexing_element.dtype.type, numpy.integer) + and indexing_element.shape == () + ) + + +def process_indexing_element(indexing_element: Any) -> Any: + """Process an indexing element. + + Processing in this context means converting supported numpy values into python values. + (if they are decided to be sanitized) + + Args: + indexing_element (Any): the indexing element to sanitize. + + Returns: + Any: the sanitized indexing element. + """ + + if isinstance(indexing_element, slice): + + start = indexing_element.start + if should_sanitize(start): + start = int(start) + + stop = indexing_element.stop + if should_sanitize(stop): + stop = int(stop) + + step = indexing_element.step + if should_sanitize(step): + step = int(step) + + indexing_element = slice(start, stop, step) + + elif should_sanitize(indexing_element): + indexing_element = int(indexing_element) + + return indexing_element diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index f8316dc89..dbb2f25b6 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -18,6 +18,7 @@ from .np_dtypes_helpers import ( get_base_value_for_numpy_or_python_constant_data, get_numpy_function_output_dtype, ) +from .np_indexing_helpers import process_indexing_element SUPPORTED_TYPES_FOR_TRACING = (int, float, numpy.ndarray) + tuple( SUPPORTED_NUMPY_DTYPES_CLASS_TYPES @@ -264,6 +265,14 @@ class NPTracer(BaseTracer): ) return output_tracer + def __getitem__(self, item): + if isinstance(item, tuple): + item = tuple(process_indexing_element(indexing_element) for indexing_element in item) + else: + item = process_indexing_element(item) + + return BaseTracer.__getitem__(self, item) + # Supported functions are either univariate or bivariate for which one of the two # sources is a constant # diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py index c20588f20..0228525e8 100644 --- a/tests/common/representation/test_intermediate.py +++ b/tests/common/representation/test_intermediate.py @@ -116,6 +116,54 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En 20, id="Dot, np.array([1, 2, 3, 4]), np.array([4, 3, 2, 1])", ), + pytest.param( + ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (0,)), + [ + numpy.array([1, 2, 3, 4], dtype=numpy.int32), + ], + 1, + id="IndexConstant, np.array([1, 2, 3, 4])[0]", + ), + pytest.param( + ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (slice(1, 3, None),)), + [ + numpy.array([1, 2, 3, 4], dtype=numpy.int32), + ], + numpy.array([2, 3]), + id="IndexConstant, np.array([1, 2, 3, 4])[1:3]", + ), + pytest.param( + ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (slice(3, 1, -1),)), + [ + numpy.array([1, 2, 3, 4], dtype=numpy.int32), + ], + numpy.array([4, 3], dtype=numpy.int32), + id="IndexConstant, np.array([1, 2, 3, 4])[3:1:-1]", + ), + pytest.param( + ir.IndexConstant( + EncryptedTensor(Integer(5, True), shape=(4, 4)), (slice(1, 3, 1), slice(2, 0, -1)) + ), + [ + numpy.array( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ], + dtype=numpy.int32, + ), + ], + numpy.array( + [ + [7, 6], + [11, 10], + ], + dtype=numpy.int32, + ), + id="IndexConstant, np.array([[1, 2, 3, 4]...[13, 14, 15, 16]])[1:3, 2:0:-1]", + ), ], ) def test_evaluate( @@ -124,7 +172,10 @@ def test_evaluate( expected_result: int, ): """Test evaluate methods on IntermediateNodes""" - assert node.evaluate(input_data) == expected_result + if isinstance(expected_result, numpy.ndarray): + assert (node.evaluate(input_data) == expected_result).all() + else: + assert node.evaluate(input_data) == expected_result @pytest.mark.parametrize( diff --git a/tests/conftest.py b/tests/conftest.py index c30fce90b..56cfd08bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,7 @@ from concrete.common.representation.intermediate import ( Add, Constant, Dot, + IndexConstant, Input, IntermediateNode, Mul, @@ -147,6 +148,15 @@ def is_equivalent_input(lhs: Input, rhs: object) -> bool: ) +def is_equivalent_index_constant(lhs: IndexConstant, rhs: object) -> bool: + """Helper function to check if an IndexConstant node is equivalent to an other object.""" + return ( + isinstance(rhs, IndexConstant) + and lhs.index == rhs.index + and is_equivalent_intermediate_node(lhs, rhs) + ) + + def is_equivalent_mul(lhs: Mul, rhs: object) -> bool: """Helper function to check if a Mul node is equivalent to an other object.""" return _is_equivalent_to_binary_commutative(lhs, rhs) @@ -171,6 +181,7 @@ EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = { UnivariateFunction: is_equivalent_arbitrary_function, Constant: is_equivalent_constant, Dot: is_equivalent_dot, + IndexConstant: is_equivalent_index_constant, Input: is_equivalent_input, Mul: is_equivalent_mul, Sub: is_equivalent_sub, diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index bf583f13e..688875797 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -852,6 +852,20 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura "return(%2)\n" ), ), + pytest.param( + lambda x: x[0], + {"x": EncryptedTensor(Integer(3, is_signed=True), shape=(2, 2))}, + [(numpy.random.randint(-4, 2 ** 2, size=(2, 2)),) for i in range(10)], + ( + "function you are trying to compile isn't supported for MLIR lowering\n" + "\n" + "%0 = x # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported\n" # noqa: E501 # pylint: disable=line-too-long + "%1 = IndexConstant(%0[0]) # EncryptedTensor, shape=(2,)>\n" # noqa: E501 # pylint: disable=line-too-long + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer tensor constant indexing is supported\n" # noqa: E501 # pylint: disable=line-too-long + "return(%1)\n" + ), + ), ], ) def test_fail_compile(function, parameters, inputset, match, default_compilation_configuration): diff --git a/tests/numpy/test_compile_constant_indexing.py b/tests/numpy/test_compile_constant_indexing.py new file mode 100644 index 000000000..78acd99ed --- /dev/null +++ b/tests/numpy/test_compile_constant_indexing.py @@ -0,0 +1,598 @@ +"""Test module for constant indexing.""" + +import numpy as np +import pytest + +from concrete.common.data_types import UnsignedInteger +from concrete.common.values import EncryptedScalar, EncryptedTensor +from concrete.numpy import compile_numpy_function_into_op_graph + + +@pytest.mark.parametrize( + "input_value,function_with_indexing,output_value", + [ + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-3], + EncryptedScalar(UnsignedInteger(1)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-2], + EncryptedScalar(UnsignedInteger(1)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-1], + EncryptedScalar(UnsignedInteger(1)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[0], + EncryptedScalar(UnsignedInteger(1)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[1], + EncryptedScalar(UnsignedInteger(1)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[2], + EncryptedScalar(UnsignedInteger(1)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[:], + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-3:], + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-2:], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-1:], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[0:], + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[1:], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[2:], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[:-1], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[:-2], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[:1], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[:2], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[:3], + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-3:-2], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-3:-1], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-3:1], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-3:2], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-3:3], + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-2:-1], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-2:2], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-2:3], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-1:3], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[0:-2], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[0:-1], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[0:1], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[0:2], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[0:3], + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[1:-1], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[1:2], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[1:3], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[2:3], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[::-1], + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-3::-1], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-2::-1], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-1::-1], + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[0::-1], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[1::-1], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[2::-1], + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[:-3:-1], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[:-2:-1], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[:0:-1], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[:1:-1], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[2:0:-1], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[2:1:-1], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-1:1:-1], + EncryptedTensor(UnsignedInteger(1), shape=(1,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[-1:0:-1], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), + lambda x: x[:, :, :], + EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), + lambda x: x[0, :, :], + EncryptedTensor(UnsignedInteger(1), shape=(4, 5)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), + lambda x: x[:, 0, :], + EncryptedTensor(UnsignedInteger(1), shape=(3, 5)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), + lambda x: x[:, :, 0], + EncryptedTensor(UnsignedInteger(1), shape=(3, 4)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), + lambda x: x[0, 0, :], + EncryptedTensor(UnsignedInteger(1), shape=(5,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), + lambda x: x[0, :, 0], + EncryptedTensor(UnsignedInteger(1), shape=(4,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), + lambda x: x[:, 0, 0], + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), + lambda x: x[0:, 1:, 2:], + EncryptedTensor(UnsignedInteger(1), shape=(3, 3, 3)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), + lambda x: x[2:, 1:, 0:], + EncryptedTensor(UnsignedInteger(1), shape=(1, 3, 5)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), + lambda x: x[0], + EncryptedTensor(UnsignedInteger(1), shape=(4, 5)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), + lambda x: x[0, 0], + EncryptedTensor(UnsignedInteger(1), shape=(5,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), + lambda x: x[0, 0, 0], + EncryptedScalar(UnsignedInteger(1)), + ), + ], +) +def test_constant_indexing( + default_compilation_configuration, + input_value, + function_with_indexing, + output_value, +): + """Test compile_numpy_function_into_op_graph with constant indexing""" + + inputset = [ + ( + np.random.randint( + input_value.dtype.min_value(), + input_value.dtype.max_value() + 1, + size=input_value.shape, + ), + ) + for _ in range(10) + ] + + opgraph = compile_numpy_function_into_op_graph( + function_with_indexing, + {"x": input_value}, + inputset, + default_compilation_configuration, + ) + + assert len(opgraph.output_nodes) == 1 + output_node = opgraph.output_nodes[0] + + assert len(output_node.outputs) == 1 + assert output_value == output_node.outputs[0] + + +@pytest.mark.parametrize( + "input_value,function_with_indexing,expected_error_type,expected_error_message", + [ + pytest.param( + EncryptedScalar(UnsignedInteger(1)), + lambda x: x[0], + TypeError, + "Only tensors can be indexed " + "but you tried to index EncryptedScalar>", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[0.5], + TypeError, + "Only integers and integer slices can be used for indexing " + "but you tried to use 0.5 for indexing", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[1:5:0.5], # type: ignore + TypeError, + "Only integers and integer slices can be used for indexing " + "but you tried to use 1:5:0.5 for indexing", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[0, 1], + ValueError, + "Tensor of shape (3,) cannot be indexed with [0, 1] " + "as the index has more elements than the number of dimensions of the tensor", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[5], + ValueError, + "Tensor of shape (3,) cannot be indexed with [5] " + "because index is out of range for dimension 0", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[5:], + ValueError, + "Tensor of shape (3,) cannot be indexed with [5:] " + "because start index is out of range for dimension 0", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[:10], + ValueError, + "Tensor of shape (3,) cannot be indexed with [:10] " + "because stop index is out of range for dimension 0", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[2:0], + ValueError, + "Tensor of shape (3,) cannot be indexed with [2:0] " + "because start index is not less than stop index for dimension 0", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[5::-1], + ValueError, + "Tensor of shape (3,) cannot be indexed with [5::-1] " + "because start index is out of range for dimension 0", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[:10:-1], + ValueError, + "Tensor of shape (3,) cannot be indexed with [:10:-1] " + "because stop index is out of range for dimension 0", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[0:2:-1], + ValueError, + "Tensor of shape (3,) cannot be indexed with [0:2:-1] " + "because step is negative and stop index is not less than start index for dimension 0", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[::0], + ValueError, + "Tensor of shape (3,) cannot be indexed with [::0] " + "because step is zero for dimension 0", + ), + ], +) +def test_invalid_constant_indexing( + default_compilation_configuration, + input_value, + function_with_indexing, + expected_error_type, + expected_error_message, +): + """Test compile_numpy_function_into_op_graph with invalid constant indexing""" + + with pytest.raises(expected_error_type): + try: + inputset = [ + ( + np.random.randint( + input_value.dtype.min_value(), + input_value.dtype.max_value() + 1, + size=input_value.shape, + ), + ) + for _ in range(10) + ] + compile_numpy_function_into_op_graph( + function_with_indexing, + {"x": input_value}, + inputset, + default_compilation_configuration, + ) + except Exception as error: + assert str(error) == expected_error_message + raise + + +@pytest.mark.parametrize( + "input_value,function_with_indexing,output_value", + [ + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[np.uint32(0)], + EncryptedScalar(UnsignedInteger(1)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[slice(np.uint32(2), np.int32(0), np.int8(-1))], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[np.array(0)], + EncryptedScalar(UnsignedInteger(1)), + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[slice(np.array(2), np.array(0), np.array(-1))], + EncryptedTensor(UnsignedInteger(1), shape=(2,)), + ), + ], +) +def test_constant_indexing_with_numpy_integers( + default_compilation_configuration, + input_value, + function_with_indexing, + output_value, +): + """Test compile_numpy_function_into_op_graph with constant indexing with numpy integers""" + + inputset = [ + ( + np.random.randint( + input_value.dtype.min_value(), + input_value.dtype.max_value() + 1, + size=input_value.shape, + ), + ) + for _ in range(10) + ] + + opgraph = compile_numpy_function_into_op_graph( + function_with_indexing, + {"x": input_value}, + inputset, + default_compilation_configuration, + ) + + assert len(opgraph.output_nodes) == 1 + output_node = opgraph.output_nodes[0] + + assert len(output_node.outputs) == 1 + assert output_value == output_node.outputs[0] + + +@pytest.mark.parametrize( + "input_value,function_with_indexing,expected_error_type,expected_error_message", + [ + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[np.float32(1.5)], + TypeError, + "Only integers and integer slices can be used for indexing " + "but you tried to use 1.5 for indexing", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[np.array(1.5)], + TypeError, + "Only integers and integer slices can be used for indexing " + "but you tried to use 1.5 for indexing", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(1), shape=(3,)), + lambda x: x[np.array([1, 2])], + TypeError, + "Only integers and integer slices can be used for indexing " + "but you tried to use [1 2] for indexing", + ), + ], +) +def test_invalid_constant_indexing_with_numpy_values( + default_compilation_configuration, + input_value, + function_with_indexing, + expected_error_type, + expected_error_message, +): + """Test compile_numpy_function_into_op_graph with invalid constant indexing with numpy values""" + + with pytest.raises(expected_error_type): + try: + inputset = [ + ( + np.random.randint( + input_value.dtype.min_value(), + input_value.dtype.max_value() + 1, + size=input_value.shape, + ), + ) + for _ in range(10) + ] + compile_numpy_function_into_op_graph( + function_with_indexing, + {"x": input_value}, + inputset, + default_compilation_configuration, + ) + except Exception as error: + assert str(error) == expected_error_message + raise