feat(tracing): implement tracing of constant indexing

This commit is contained in:
Umut
2021-10-20 17:23:16 +03:00
parent afb342aec3
commit 65af96253b
13 changed files with 1105 additions and 4 deletions

View File

@@ -16,6 +16,7 @@ from ..representation.intermediate import (
Add,
Constant,
Dot,
IndexConstant,
Input,
Mul,
Sub,
@@ -29,6 +30,7 @@ IR_NODE_COLOR_MAPPING = {
Sub: "yellow",
Mul: "green",
UnivariateFunction: "orange",
IndexConstant: "black",
Dot: "purple",
"UnivariateFunction": "orange",
"TLU": "grey",

View File

@@ -6,7 +6,13 @@ import networkx as nx
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction
from ..representation.intermediate import (
Constant,
IndexConstant,
Input,
IntermediateNode,
UnivariateFunction,
)
def output_data_type_to_string(node):
@@ -124,6 +130,9 @@ def get_printable_graph(
what_to_print += prefix_to_add_to_what_to_print
what_to_print += ", ".join(["%" + x[1] for x in list_of_arg_name])
what_to_print += suffix_to_add_to_what_to_print
what_to_print += (
f"{node.label().replace('value', '')}" if isinstance(node, IndexConstant) else ""
)
what_to_print += ")"
# This code doesn't work with more than a single output

View File

@@ -0,0 +1,3 @@
"""Helpers for all kinds of tasks."""
from . import indexing_helpers

View File

@@ -0,0 +1,277 @@
"""Helpers for indexing functionality."""
from typing import Tuple, Union
def format_indexing_element(indexing_element: Union[int, slice]) -> str:
"""Format an indexing element.
This is required mainly for slices. The reason is that string representation of slices
are very long and verbose. To give an example, `x[:, 2:]` will have the following index
`[slice(None, None, None), slice(2, None, None)]` if printed naively. With this helper,
it will be formatted as `[:, 2:]`.
Args:
indexing_element (Union[int, slice]): indexing element to be formatted
Returns:
str: formatted element
"""
result = ""
if isinstance(indexing_element, slice):
if indexing_element.start is not None:
result += str(indexing_element.start)
result += ":"
if indexing_element.stop is not None:
result += str(indexing_element.stop)
if indexing_element.step is not None:
result += ":"
result += str(indexing_element.step)
else:
result += str(indexing_element)
return result.replace("\n", " ")
def validate_index(
index: Union[int, slice, Tuple[Union[int, slice], ...]],
) -> Tuple[Union[int, slice], ...]:
"""Make sure index is valid and convert it to the tuple form.
For example in `x[2]`, `index` is passed as `2`.
To make it easier to work with, this function converts index to `(2,)`.
Args:
index (Union[int, slice, Tuple[Union[int, slice], ...]]): index to validate, improve
and return
Returns:
Tuple[Union[int, slice], ...]: validated and improved index
"""
if not isinstance(index, tuple):
index = (index,)
for indexing_element in index:
valid = isinstance(indexing_element, (int, slice))
if isinstance(indexing_element, slice):
if (
not (indexing_element.start is None or isinstance(indexing_element.start, int))
or not (indexing_element.stop is None or isinstance(indexing_element.stop, int))
or not (indexing_element.step is None or isinstance(indexing_element.step, int))
):
valid = False
if not valid:
raise TypeError(
f"Only integers and integer slices can be used for indexing "
f"but you tried to use {format_indexing_element(indexing_element)} for indexing"
)
return index
def determine_output_shape(
input_shape: Tuple[int, ...],
index: Tuple[Union[int, slice], ...],
) -> Tuple[int, ...]:
"""Determine the output shape from the input shape and the index.
e.g., for `input_shape=(3, 2)` and `index=(:, 0)`, returns `(3,)`
for `input_shape=(4, 3, 2)` and `index=(2:,)`, returns `(2, 3, 2)`
Args:
input_shape (Tuple[int, ...]): shape of the input tensor that is indexed
index (Tuple[Union[int, slice], ...]): desired and validated index
Returns:
Tuple[int, ...]: shape of the result of indexing
"""
indexing_elements = [format_indexing_element(indexing_element) for indexing_element in index]
index_str = f"[{', '.join(indexing_elements)}]"
if len(index) > len(input_shape):
raise ValueError(
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
f"as the index has more elements than the number of dimensions of the tensor"
)
# indexing (3, 4, 5) with [1] is the same as indexing it with [1, :, :]
# indexing (3, 4, 5) with [1, 2] is the same as indexing it with [1, 2, :]
# so let's replicate that behavior to make the rest of the code generic
index += (slice(None, None, None),) * (len(input_shape) - len(index))
output_shape = []
for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)):
if isinstance(indexing_element, int): # indexing removes the dimension
indexing_element = (
indexing_element if indexing_element >= 0 else indexing_element + dimension_size
)
if not 0 <= indexing_element < dimension_size:
raise ValueError(
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
f"because index is out of range for dimension {dimension}"
)
elif isinstance(indexing_element, slice): # indexing possibly shrinks the dimension
output_shape.append(
determine_new_dimension_size(
indexing_element,
dimension_size,
dimension,
input_shape,
index_str,
)
)
return tuple(output_shape)
def sanitize_start_index(
start: int,
dimension_size: int,
# the rest is used for detailed exception message
dimension: int,
input_shape: Tuple[int, ...],
index_str: str,
) -> int:
"""Sanitize and check start index of a slice.
Args:
start (int): start index being sanitized
dimension_size (int): size of the dimension the slice is applied to
dimension (int): index of the dimension being sliced (for better messages)
input_shape (Tuple[int, ...]): shape of the whole input (for better messages)
index_str (str): string representation of the whole index (for better messages)
Returns:
int: sanitized start index
"""
start = start if start >= 0 else start + dimension_size
if not 0 <= start < dimension_size:
raise ValueError(
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
f"because start index is out of range for dimension {dimension}"
)
return start
def sanitize_stop_index(
stop: int,
dimension_size: int,
# the rest is used for detailed exception message
dimension: int,
input_shape: Tuple[int, ...],
index_str: str,
) -> int:
"""Sanitize and check stop index of a slice.
Args:
stop (int): stop index being sanitized
dimension_size (int): size of the dimension the slice is applied to
dimension (int): index of the dimension being sliced (for better messages)
input_shape (Tuple[int, ...]): shape of the whole input (for better messages)
index_str (str): string representation of the whole index (for better messages)
Returns:
int: sanitized stop index
"""
stop = stop if stop >= 0 else stop + dimension_size
if not 0 <= stop <= dimension_size:
raise ValueError(
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
f"because stop index is out of range for dimension {dimension}"
)
return stop
def determine_new_dimension_size(
slice_: slice,
dimension_size: int,
# the rest is used for detailed exception message
dimension: int,
input_shape: Tuple[int, ...],
index_str: str,
) -> int:
"""Determine the new size of a dimension from the old size and the slice applied to it.
e.g., for `slice_=1:4` and `dimension_size=5`, returns `3`
for `slice_=::-1` and `dimension_size=5`, returns `5`
You may want to check this page to learn more about how this function works
https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing
Args:
slice_ (slice): slice being applied to the dimension
dimension_size (int): size of the dimension the slice is applied to
dimension (int): index of the dimension being sliced (for better messages)
input_shape (Tuple[int, ...]): shape of the whole input (for better messages)
index_str (str): string representation of the whole index (for better messages)
Returns:
int: new size of the dimension
"""
step = slice_.step if slice_.step is not None else 1
if step > 0:
start = slice_.start if slice_.start is not None else 0
stop = slice_.stop if slice_.stop is not None else dimension_size
start = sanitize_start_index(start, dimension_size, dimension, input_shape, index_str)
stop = sanitize_stop_index(stop, dimension_size, dimension, input_shape, index_str)
if start >= stop:
raise ValueError(
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
f"because start index is not less than stop index for dimension {dimension}"
)
size_before_stepping = stop - start
elif step < 0:
start = slice_.start if slice_.start is not None else dimension_size - 1
stop = slice_.stop
start = sanitize_start_index(start, dimension_size, dimension, input_shape, index_str)
if stop is None:
# this is a weird case but it works as expected
# the issue is that it's impossible to slice whole vector reversed
# with a stop value different than none
# if `x.shape == (6,)` the only one that works is `x[::-1].shape == (6,)`
# here is what doesn't work (and this is expected it's just weird)
#
# ...
# `x[:-2:-1].shape == (1,)`
# `x[:-1:-1].shape == (0,)` (note that this is a hard error for us)
# `x[:0:-1].shape == (5,)`
# `x[:1:-1].shape == (4,)`
# ...
size_before_stepping = start + 1
else:
stop = sanitize_stop_index(stop, dimension_size, dimension, input_shape, index_str)
if stop >= start:
raise ValueError(
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
f"because step is negative and "
f"stop index is not less than start index for dimension {dimension}"
)
size_before_stepping = start - stop
else:
raise ValueError(
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
f"because step is zero for dimension {dimension}"
)
quotient = size_before_stepping // abs(step)
remainder = size_before_stepping % abs(step)
return quotient + (remainder != 0)

View File

@@ -71,6 +71,11 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
if not value_is_unsigned_integer(inputs[0]) or not value_is_unsigned_integer(inputs[1]):
return "only unsigned integer dot product is supported"
elif isinstance(node, intermediate.IndexConstant): # constraints for constant indexing
assert_true(len(outputs) == 1)
if not value_is_unsigned_integer(outputs[0]):
return "only unsigned integer tensor constant indexing is supported"
else: # pragma: no cover
assert_not_reached("Non IntermediateNode object in the OPGraph")

View File

@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from collections import deque
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from loguru import logger
@@ -14,7 +14,15 @@ from ..data_types.dtypes_helpers import (
)
from ..data_types.integers import Integer
from ..debugging.custom_assert import assert_true
from ..values import BaseValue, ClearScalar, EncryptedScalar, TensorValue
from ..helpers import indexing_helpers
from ..values import (
BaseValue,
ClearScalar,
ClearTensor,
EncryptedScalar,
EncryptedTensor,
TensorValue,
)
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
@@ -197,6 +205,56 @@ class Constant(IntermediateNode):
return str(self.constant_data)
class IndexConstant(IntermediateNode):
"""Node representing a constant indexing in the program.
What we mean by constant indexing is that the index part of the operation is a constant.
Here are some examples: `x[2]`, `x[0, 1]`, `y[:, 0]`, `y[3:, :5]`
The opposite is to have dynamic indexing, which this node does not support.
Some examples of dynamic indexing are: `x[y]`, `x[y, z]`, `x[:, y]`
"""
_n_in: int = 1
index: Tuple[Union[int, slice], ...]
def __init__(
self,
input_: BaseValue,
index: Union[int, slice, Tuple[Union[int, slice], ...]],
) -> None:
super().__init__((input_,))
if not isinstance(self.inputs[0], TensorValue) or self.inputs[0].is_scalar:
raise TypeError(f"Only tensors can be indexed but you tried to index {self.inputs[0]}")
self.index = indexing_helpers.validate_index(index)
output_dtype = self.inputs[0].dtype
output_shape = indexing_helpers.determine_output_shape(self.inputs[0].shape, self.index)
self.outputs = [
EncryptedTensor(output_dtype, output_shape)
if self.inputs[0].is_encrypted
else ClearTensor(output_dtype, output_shape)
]
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0][self.index]
def label(self) -> str:
"""Label of the node to show during drawings.
It can be used for some other places after `"value"` below is replaced by `""`.
This note will no longer be necessary after #707 is addressed.
"""
elements = [indexing_helpers.format_indexing_element(element) for element in self.index]
index = ", ".join(elements)
return f"value[{index}]"
def flood_replace_none_values(table: list):
"""Use a flooding algorithm to replace None values.

View File

@@ -7,6 +7,7 @@ from ..debugging.custom_assert import assert_true
from ..representation.intermediate import (
IR_MIX_VALUES_FUNC_ARG_NAME,
Add,
IndexConstant,
IntermediateNode,
Mul,
Sub,
@@ -161,3 +162,7 @@ class BaseTracer(ABC):
# the order, we need to do as in __rmul__, ie mostly a copy of __mul__ +
# some changes
__rmul__ = __mul__
def __getitem__(self, item):
traced_computation = IndexConstant(self.output, item)
return self.__class__([self], traced_computation, 0)

View File

@@ -0,0 +1,59 @@
"""Helpers for indexing with numpy values functionality."""
from typing import Any
import numpy
def should_sanitize(indexing_element: Any) -> bool:
"""Decide whether to sanitize an indexing element or not.
Sanitizing in this context means converting supported numpy values into python values.
Args:
indexing_element (Any): the indexing element to decide sanitization.
Returns:
bool: True if indexing element should be sanitized otherwise False.
"""
return isinstance(indexing_element, numpy.integer) or (
isinstance(indexing_element, numpy.ndarray)
and issubclass(indexing_element.dtype.type, numpy.integer)
and indexing_element.shape == ()
)
def process_indexing_element(indexing_element: Any) -> Any:
"""Process an indexing element.
Processing in this context means converting supported numpy values into python values.
(if they are decided to be sanitized)
Args:
indexing_element (Any): the indexing element to sanitize.
Returns:
Any: the sanitized indexing element.
"""
if isinstance(indexing_element, slice):
start = indexing_element.start
if should_sanitize(start):
start = int(start)
stop = indexing_element.stop
if should_sanitize(stop):
stop = int(stop)
step = indexing_element.step
if should_sanitize(step):
step = int(step)
indexing_element = slice(start, stop, step)
elif should_sanitize(indexing_element):
indexing_element = int(indexing_element)
return indexing_element

View File

@@ -18,6 +18,7 @@ from .np_dtypes_helpers import (
get_base_value_for_numpy_or_python_constant_data,
get_numpy_function_output_dtype,
)
from .np_indexing_helpers import process_indexing_element
SUPPORTED_TYPES_FOR_TRACING = (int, float, numpy.ndarray) + tuple(
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES
@@ -264,6 +265,14 @@ class NPTracer(BaseTracer):
)
return output_tracer
def __getitem__(self, item):
if isinstance(item, tuple):
item = tuple(process_indexing_element(indexing_element) for indexing_element in item)
else:
item = process_indexing_element(item)
return BaseTracer.__getitem__(self, item)
# Supported functions are either univariate or bivariate for which one of the two
# sources is a constant
#

View File

@@ -116,6 +116,54 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
20,
id="Dot, np.array([1, 2, 3, 4]), np.array([4, 3, 2, 1])",
),
pytest.param(
ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (0,)),
[
numpy.array([1, 2, 3, 4], dtype=numpy.int32),
],
1,
id="IndexConstant, np.array([1, 2, 3, 4])[0]",
),
pytest.param(
ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (slice(1, 3, None),)),
[
numpy.array([1, 2, 3, 4], dtype=numpy.int32),
],
numpy.array([2, 3]),
id="IndexConstant, np.array([1, 2, 3, 4])[1:3]",
),
pytest.param(
ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (slice(3, 1, -1),)),
[
numpy.array([1, 2, 3, 4], dtype=numpy.int32),
],
numpy.array([4, 3], dtype=numpy.int32),
id="IndexConstant, np.array([1, 2, 3, 4])[3:1:-1]",
),
pytest.param(
ir.IndexConstant(
EncryptedTensor(Integer(5, True), shape=(4, 4)), (slice(1, 3, 1), slice(2, 0, -1))
),
[
numpy.array(
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
],
dtype=numpy.int32,
),
],
numpy.array(
[
[7, 6],
[11, 10],
],
dtype=numpy.int32,
),
id="IndexConstant, np.array([[1, 2, 3, 4]...[13, 14, 15, 16]])[1:3, 2:0:-1]",
),
],
)
def test_evaluate(
@@ -124,7 +172,10 @@ def test_evaluate(
expected_result: int,
):
"""Test evaluate methods on IntermediateNodes"""
assert node.evaluate(input_data) == expected_result
if isinstance(expected_result, numpy.ndarray):
assert (node.evaluate(input_data) == expected_result).all()
else:
assert node.evaluate(input_data) == expected_result
@pytest.mark.parametrize(

View File

@@ -14,6 +14,7 @@ from concrete.common.representation.intermediate import (
Add,
Constant,
Dot,
IndexConstant,
Input,
IntermediateNode,
Mul,
@@ -147,6 +148,15 @@ def is_equivalent_input(lhs: Input, rhs: object) -> bool:
)
def is_equivalent_index_constant(lhs: IndexConstant, rhs: object) -> bool:
"""Helper function to check if an IndexConstant node is equivalent to an other object."""
return (
isinstance(rhs, IndexConstant)
and lhs.index == rhs.index
and is_equivalent_intermediate_node(lhs, rhs)
)
def is_equivalent_mul(lhs: Mul, rhs: object) -> bool:
"""Helper function to check if a Mul node is equivalent to an other object."""
return _is_equivalent_to_binary_commutative(lhs, rhs)
@@ -171,6 +181,7 @@ EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
UnivariateFunction: is_equivalent_arbitrary_function,
Constant: is_equivalent_constant,
Dot: is_equivalent_dot,
IndexConstant: is_equivalent_index_constant,
Input: is_equivalent_input,
Mul: is_equivalent_mul,
Sub: is_equivalent_sub,

View File

@@ -852,6 +852,20 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
"return(%2)\n"
),
),
pytest.param(
lambda x: x[0],
{"x": EncryptedTensor(Integer(3, is_signed=True), shape=(2, 2))},
[(numpy.random.randint(-4, 2 ** 2, size=(2, 2)),) for i in range(10)],
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<signed, 3 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported\n" # noqa: E501 # pylint: disable=line-too-long
"%1 = IndexConstant(%0[0]) # EncryptedTensor<Integer<signed, 3 bits>, shape=(2,)>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer tensor constant indexing is supported\n" # noqa: E501 # pylint: disable=line-too-long
"return(%1)\n"
),
),
],
)
def test_fail_compile(function, parameters, inputset, match, default_compilation_configuration):

View File

@@ -0,0 +1,598 @@
"""Test module for constant indexing."""
import numpy as np
import pytest
from concrete.common.data_types import UnsignedInteger
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.numpy import compile_numpy_function_into_op_graph
@pytest.mark.parametrize(
"input_value,function_with_indexing,output_value",
[
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-3],
EncryptedScalar(UnsignedInteger(1)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-2],
EncryptedScalar(UnsignedInteger(1)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-1],
EncryptedScalar(UnsignedInteger(1)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[0],
EncryptedScalar(UnsignedInteger(1)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[1],
EncryptedScalar(UnsignedInteger(1)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[2],
EncryptedScalar(UnsignedInteger(1)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[:],
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-3:],
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-2:],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-1:],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[0:],
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[1:],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[2:],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[:-1],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[:-2],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[:1],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[:2],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[:3],
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-3:-2],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-3:-1],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-3:1],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-3:2],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-3:3],
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-2:-1],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-2:2],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-2:3],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-1:3],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[0:-2],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[0:-1],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[0:1],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[0:2],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[0:3],
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[1:-1],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[1:2],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[1:3],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[2:3],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[::-1],
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-3::-1],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-2::-1],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-1::-1],
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[0::-1],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[1::-1],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[2::-1],
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[:-3:-1],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[:-2:-1],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[:0:-1],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[:1:-1],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[2:0:-1],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[2:1:-1],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-1:1:-1],
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[-1:0:-1],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
lambda x: x[:, :, :],
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
lambda x: x[0, :, :],
EncryptedTensor(UnsignedInteger(1), shape=(4, 5)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
lambda x: x[:, 0, :],
EncryptedTensor(UnsignedInteger(1), shape=(3, 5)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
lambda x: x[:, :, 0],
EncryptedTensor(UnsignedInteger(1), shape=(3, 4)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
lambda x: x[0, 0, :],
EncryptedTensor(UnsignedInteger(1), shape=(5,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
lambda x: x[0, :, 0],
EncryptedTensor(UnsignedInteger(1), shape=(4,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
lambda x: x[:, 0, 0],
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
lambda x: x[0:, 1:, 2:],
EncryptedTensor(UnsignedInteger(1), shape=(3, 3, 3)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
lambda x: x[2:, 1:, 0:],
EncryptedTensor(UnsignedInteger(1), shape=(1, 3, 5)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
lambda x: x[0],
EncryptedTensor(UnsignedInteger(1), shape=(4, 5)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
lambda x: x[0, 0],
EncryptedTensor(UnsignedInteger(1), shape=(5,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
lambda x: x[0, 0, 0],
EncryptedScalar(UnsignedInteger(1)),
),
],
)
def test_constant_indexing(
default_compilation_configuration,
input_value,
function_with_indexing,
output_value,
):
"""Test compile_numpy_function_into_op_graph with constant indexing"""
inputset = [
(
np.random.randint(
input_value.dtype.min_value(),
input_value.dtype.max_value() + 1,
size=input_value.shape,
),
)
for _ in range(10)
]
opgraph = compile_numpy_function_into_op_graph(
function_with_indexing,
{"x": input_value},
inputset,
default_compilation_configuration,
)
assert len(opgraph.output_nodes) == 1
output_node = opgraph.output_nodes[0]
assert len(output_node.outputs) == 1
assert output_value == output_node.outputs[0]
@pytest.mark.parametrize(
"input_value,function_with_indexing,expected_error_type,expected_error_message",
[
pytest.param(
EncryptedScalar(UnsignedInteger(1)),
lambda x: x[0],
TypeError,
"Only tensors can be indexed "
"but you tried to index EncryptedScalar<Integer<unsigned, 1 bits>>",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[0.5],
TypeError,
"Only integers and integer slices can be used for indexing "
"but you tried to use 0.5 for indexing",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[1:5:0.5], # type: ignore
TypeError,
"Only integers and integer slices can be used for indexing "
"but you tried to use 1:5:0.5 for indexing",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[0, 1],
ValueError,
"Tensor of shape (3,) cannot be indexed with [0, 1] "
"as the index has more elements than the number of dimensions of the tensor",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[5],
ValueError,
"Tensor of shape (3,) cannot be indexed with [5] "
"because index is out of range for dimension 0",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[5:],
ValueError,
"Tensor of shape (3,) cannot be indexed with [5:] "
"because start index is out of range for dimension 0",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[:10],
ValueError,
"Tensor of shape (3,) cannot be indexed with [:10] "
"because stop index is out of range for dimension 0",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[2:0],
ValueError,
"Tensor of shape (3,) cannot be indexed with [2:0] "
"because start index is not less than stop index for dimension 0",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[5::-1],
ValueError,
"Tensor of shape (3,) cannot be indexed with [5::-1] "
"because start index is out of range for dimension 0",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[:10:-1],
ValueError,
"Tensor of shape (3,) cannot be indexed with [:10:-1] "
"because stop index is out of range for dimension 0",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[0:2:-1],
ValueError,
"Tensor of shape (3,) cannot be indexed with [0:2:-1] "
"because step is negative and stop index is not less than start index for dimension 0",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[::0],
ValueError,
"Tensor of shape (3,) cannot be indexed with [::0] "
"because step is zero for dimension 0",
),
],
)
def test_invalid_constant_indexing(
default_compilation_configuration,
input_value,
function_with_indexing,
expected_error_type,
expected_error_message,
):
"""Test compile_numpy_function_into_op_graph with invalid constant indexing"""
with pytest.raises(expected_error_type):
try:
inputset = [
(
np.random.randint(
input_value.dtype.min_value(),
input_value.dtype.max_value() + 1,
size=input_value.shape,
),
)
for _ in range(10)
]
compile_numpy_function_into_op_graph(
function_with_indexing,
{"x": input_value},
inputset,
default_compilation_configuration,
)
except Exception as error:
assert str(error) == expected_error_message
raise
@pytest.mark.parametrize(
"input_value,function_with_indexing,output_value",
[
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[np.uint32(0)],
EncryptedScalar(UnsignedInteger(1)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[slice(np.uint32(2), np.int32(0), np.int8(-1))],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[np.array(0)],
EncryptedScalar(UnsignedInteger(1)),
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[slice(np.array(2), np.array(0), np.array(-1))],
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
),
],
)
def test_constant_indexing_with_numpy_integers(
default_compilation_configuration,
input_value,
function_with_indexing,
output_value,
):
"""Test compile_numpy_function_into_op_graph with constant indexing with numpy integers"""
inputset = [
(
np.random.randint(
input_value.dtype.min_value(),
input_value.dtype.max_value() + 1,
size=input_value.shape,
),
)
for _ in range(10)
]
opgraph = compile_numpy_function_into_op_graph(
function_with_indexing,
{"x": input_value},
inputset,
default_compilation_configuration,
)
assert len(opgraph.output_nodes) == 1
output_node = opgraph.output_nodes[0]
assert len(output_node.outputs) == 1
assert output_value == output_node.outputs[0]
@pytest.mark.parametrize(
"input_value,function_with_indexing,expected_error_type,expected_error_message",
[
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[np.float32(1.5)],
TypeError,
"Only integers and integer slices can be used for indexing "
"but you tried to use 1.5 for indexing",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[np.array(1.5)],
TypeError,
"Only integers and integer slices can be used for indexing "
"but you tried to use 1.5 for indexing",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
lambda x: x[np.array([1, 2])],
TypeError,
"Only integers and integer slices can be used for indexing "
"but you tried to use [1 2] for indexing",
),
],
)
def test_invalid_constant_indexing_with_numpy_values(
default_compilation_configuration,
input_value,
function_with_indexing,
expected_error_type,
expected_error_message,
):
"""Test compile_numpy_function_into_op_graph with invalid constant indexing with numpy values"""
with pytest.raises(expected_error_type):
try:
inputset = [
(
np.random.randint(
input_value.dtype.min_value(),
input_value.dtype.max_value() + 1,
size=input_value.shape,
),
)
for _ in range(10)
]
compile_numpy_function_into_op_graph(
function_with_indexing,
{"x": input_value},
inputset,
default_compilation_configuration,
)
except Exception as error:
assert str(error) == expected_error_message
raise