mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(tracing): implement tracing of constant indexing
This commit is contained in:
@@ -16,6 +16,7 @@ from ..representation.intermediate import (
|
||||
Add,
|
||||
Constant,
|
||||
Dot,
|
||||
IndexConstant,
|
||||
Input,
|
||||
Mul,
|
||||
Sub,
|
||||
@@ -29,6 +30,7 @@ IR_NODE_COLOR_MAPPING = {
|
||||
Sub: "yellow",
|
||||
Mul: "green",
|
||||
UnivariateFunction: "orange",
|
||||
IndexConstant: "black",
|
||||
Dot: "purple",
|
||||
"UnivariateFunction": "orange",
|
||||
"TLU": "grey",
|
||||
|
||||
@@ -6,7 +6,13 @@ import networkx as nx
|
||||
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction
|
||||
from ..representation.intermediate import (
|
||||
Constant,
|
||||
IndexConstant,
|
||||
Input,
|
||||
IntermediateNode,
|
||||
UnivariateFunction,
|
||||
)
|
||||
|
||||
|
||||
def output_data_type_to_string(node):
|
||||
@@ -124,6 +130,9 @@ def get_printable_graph(
|
||||
what_to_print += prefix_to_add_to_what_to_print
|
||||
what_to_print += ", ".join(["%" + x[1] for x in list_of_arg_name])
|
||||
what_to_print += suffix_to_add_to_what_to_print
|
||||
what_to_print += (
|
||||
f"{node.label().replace('value', '')}" if isinstance(node, IndexConstant) else ""
|
||||
)
|
||||
what_to_print += ")"
|
||||
|
||||
# This code doesn't work with more than a single output
|
||||
|
||||
3
concrete/common/helpers/__init__.py
Normal file
3
concrete/common/helpers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Helpers for all kinds of tasks."""
|
||||
|
||||
from . import indexing_helpers
|
||||
277
concrete/common/helpers/indexing_helpers.py
Normal file
277
concrete/common/helpers/indexing_helpers.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""Helpers for indexing functionality."""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
|
||||
def format_indexing_element(indexing_element: Union[int, slice]) -> str:
|
||||
"""Format an indexing element.
|
||||
|
||||
This is required mainly for slices. The reason is that string representation of slices
|
||||
are very long and verbose. To give an example, `x[:, 2:]` will have the following index
|
||||
`[slice(None, None, None), slice(2, None, None)]` if printed naively. With this helper,
|
||||
it will be formatted as `[:, 2:]`.
|
||||
|
||||
Args:
|
||||
indexing_element (Union[int, slice]): indexing element to be formatted
|
||||
|
||||
Returns:
|
||||
str: formatted element
|
||||
"""
|
||||
|
||||
result = ""
|
||||
if isinstance(indexing_element, slice):
|
||||
if indexing_element.start is not None:
|
||||
result += str(indexing_element.start)
|
||||
result += ":"
|
||||
if indexing_element.stop is not None:
|
||||
result += str(indexing_element.stop)
|
||||
if indexing_element.step is not None:
|
||||
result += ":"
|
||||
result += str(indexing_element.step)
|
||||
else:
|
||||
result += str(indexing_element)
|
||||
return result.replace("\n", " ")
|
||||
|
||||
|
||||
def validate_index(
|
||||
index: Union[int, slice, Tuple[Union[int, slice], ...]],
|
||||
) -> Tuple[Union[int, slice], ...]:
|
||||
"""Make sure index is valid and convert it to the tuple form.
|
||||
|
||||
For example in `x[2]`, `index` is passed as `2`.
|
||||
To make it easier to work with, this function converts index to `(2,)`.
|
||||
|
||||
Args:
|
||||
index (Union[int, slice, Tuple[Union[int, slice], ...]]): index to validate, improve
|
||||
and return
|
||||
|
||||
Returns:
|
||||
Tuple[Union[int, slice], ...]: validated and improved index
|
||||
"""
|
||||
|
||||
if not isinstance(index, tuple):
|
||||
index = (index,)
|
||||
|
||||
for indexing_element in index:
|
||||
valid = isinstance(indexing_element, (int, slice))
|
||||
|
||||
if isinstance(indexing_element, slice):
|
||||
if (
|
||||
not (indexing_element.start is None or isinstance(indexing_element.start, int))
|
||||
or not (indexing_element.stop is None or isinstance(indexing_element.stop, int))
|
||||
or not (indexing_element.step is None or isinstance(indexing_element.step, int))
|
||||
):
|
||||
valid = False
|
||||
|
||||
if not valid:
|
||||
raise TypeError(
|
||||
f"Only integers and integer slices can be used for indexing "
|
||||
f"but you tried to use {format_indexing_element(indexing_element)} for indexing"
|
||||
)
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def determine_output_shape(
|
||||
input_shape: Tuple[int, ...],
|
||||
index: Tuple[Union[int, slice], ...],
|
||||
) -> Tuple[int, ...]:
|
||||
"""Determine the output shape from the input shape and the index.
|
||||
|
||||
e.g., for `input_shape=(3, 2)` and `index=(:, 0)`, returns `(3,)`
|
||||
for `input_shape=(4, 3, 2)` and `index=(2:,)`, returns `(2, 3, 2)`
|
||||
|
||||
Args:
|
||||
input_shape (Tuple[int, ...]): shape of the input tensor that is indexed
|
||||
index (Tuple[Union[int, slice], ...]): desired and validated index
|
||||
|
||||
Returns:
|
||||
Tuple[int, ...]: shape of the result of indexing
|
||||
"""
|
||||
|
||||
indexing_elements = [format_indexing_element(indexing_element) for indexing_element in index]
|
||||
index_str = f"[{', '.join(indexing_elements)}]"
|
||||
|
||||
if len(index) > len(input_shape):
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"as the index has more elements than the number of dimensions of the tensor"
|
||||
)
|
||||
|
||||
# indexing (3, 4, 5) with [1] is the same as indexing it with [1, :, :]
|
||||
# indexing (3, 4, 5) with [1, 2] is the same as indexing it with [1, 2, :]
|
||||
|
||||
# so let's replicate that behavior to make the rest of the code generic
|
||||
index += (slice(None, None, None),) * (len(input_shape) - len(index))
|
||||
|
||||
output_shape = []
|
||||
for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)):
|
||||
if isinstance(indexing_element, int): # indexing removes the dimension
|
||||
indexing_element = (
|
||||
indexing_element if indexing_element >= 0 else indexing_element + dimension_size
|
||||
)
|
||||
if not 0 <= indexing_element < dimension_size:
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"because index is out of range for dimension {dimension}"
|
||||
)
|
||||
elif isinstance(indexing_element, slice): # indexing possibly shrinks the dimension
|
||||
output_shape.append(
|
||||
determine_new_dimension_size(
|
||||
indexing_element,
|
||||
dimension_size,
|
||||
dimension,
|
||||
input_shape,
|
||||
index_str,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(output_shape)
|
||||
|
||||
|
||||
def sanitize_start_index(
|
||||
start: int,
|
||||
dimension_size: int,
|
||||
# the rest is used for detailed exception message
|
||||
dimension: int,
|
||||
input_shape: Tuple[int, ...],
|
||||
index_str: str,
|
||||
) -> int:
|
||||
"""Sanitize and check start index of a slice.
|
||||
|
||||
Args:
|
||||
start (int): start index being sanitized
|
||||
dimension_size (int): size of the dimension the slice is applied to
|
||||
dimension (int): index of the dimension being sliced (for better messages)
|
||||
input_shape (Tuple[int, ...]): shape of the whole input (for better messages)
|
||||
index_str (str): string representation of the whole index (for better messages)
|
||||
|
||||
Returns:
|
||||
int: sanitized start index
|
||||
"""
|
||||
|
||||
start = start if start >= 0 else start + dimension_size
|
||||
if not 0 <= start < dimension_size:
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"because start index is out of range for dimension {dimension}"
|
||||
)
|
||||
return start
|
||||
|
||||
|
||||
def sanitize_stop_index(
|
||||
stop: int,
|
||||
dimension_size: int,
|
||||
# the rest is used for detailed exception message
|
||||
dimension: int,
|
||||
input_shape: Tuple[int, ...],
|
||||
index_str: str,
|
||||
) -> int:
|
||||
"""Sanitize and check stop index of a slice.
|
||||
|
||||
Args:
|
||||
stop (int): stop index being sanitized
|
||||
dimension_size (int): size of the dimension the slice is applied to
|
||||
dimension (int): index of the dimension being sliced (for better messages)
|
||||
input_shape (Tuple[int, ...]): shape of the whole input (for better messages)
|
||||
index_str (str): string representation of the whole index (for better messages)
|
||||
|
||||
Returns:
|
||||
int: sanitized stop index
|
||||
"""
|
||||
|
||||
stop = stop if stop >= 0 else stop + dimension_size
|
||||
if not 0 <= stop <= dimension_size:
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"because stop index is out of range for dimension {dimension}"
|
||||
)
|
||||
return stop
|
||||
|
||||
|
||||
def determine_new_dimension_size(
|
||||
slice_: slice,
|
||||
dimension_size: int,
|
||||
# the rest is used for detailed exception message
|
||||
dimension: int,
|
||||
input_shape: Tuple[int, ...],
|
||||
index_str: str,
|
||||
) -> int:
|
||||
"""Determine the new size of a dimension from the old size and the slice applied to it.
|
||||
|
||||
e.g., for `slice_=1:4` and `dimension_size=5`, returns `3`
|
||||
for `slice_=::-1` and `dimension_size=5`, returns `5`
|
||||
|
||||
You may want to check this page to learn more about how this function works
|
||||
https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing
|
||||
|
||||
Args:
|
||||
slice_ (slice): slice being applied to the dimension
|
||||
dimension_size (int): size of the dimension the slice is applied to
|
||||
dimension (int): index of the dimension being sliced (for better messages)
|
||||
input_shape (Tuple[int, ...]): shape of the whole input (for better messages)
|
||||
index_str (str): string representation of the whole index (for better messages)
|
||||
|
||||
Returns:
|
||||
int: new size of the dimension
|
||||
"""
|
||||
|
||||
step = slice_.step if slice_.step is not None else 1
|
||||
|
||||
if step > 0:
|
||||
start = slice_.start if slice_.start is not None else 0
|
||||
stop = slice_.stop if slice_.stop is not None else dimension_size
|
||||
|
||||
start = sanitize_start_index(start, dimension_size, dimension, input_shape, index_str)
|
||||
stop = sanitize_stop_index(stop, dimension_size, dimension, input_shape, index_str)
|
||||
|
||||
if start >= stop:
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"because start index is not less than stop index for dimension {dimension}"
|
||||
)
|
||||
|
||||
size_before_stepping = stop - start
|
||||
elif step < 0:
|
||||
start = slice_.start if slice_.start is not None else dimension_size - 1
|
||||
stop = slice_.stop
|
||||
|
||||
start = sanitize_start_index(start, dimension_size, dimension, input_shape, index_str)
|
||||
|
||||
if stop is None:
|
||||
# this is a weird case but it works as expected
|
||||
# the issue is that it's impossible to slice whole vector reversed
|
||||
# with a stop value different than none
|
||||
|
||||
# if `x.shape == (6,)` the only one that works is `x[::-1].shape == (6,)`
|
||||
# here is what doesn't work (and this is expected it's just weird)
|
||||
#
|
||||
# ...
|
||||
# `x[:-2:-1].shape == (1,)`
|
||||
# `x[:-1:-1].shape == (0,)` (note that this is a hard error for us)
|
||||
# `x[:0:-1].shape == (5,)`
|
||||
# `x[:1:-1].shape == (4,)`
|
||||
# ...
|
||||
|
||||
size_before_stepping = start + 1
|
||||
else:
|
||||
stop = sanitize_stop_index(stop, dimension_size, dimension, input_shape, index_str)
|
||||
|
||||
if stop >= start:
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"because step is negative and "
|
||||
f"stop index is not less than start index for dimension {dimension}"
|
||||
)
|
||||
|
||||
size_before_stepping = start - stop
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"because step is zero for dimension {dimension}"
|
||||
)
|
||||
|
||||
quotient = size_before_stepping // abs(step)
|
||||
remainder = size_before_stepping % abs(step)
|
||||
|
||||
return quotient + (remainder != 0)
|
||||
@@ -71,6 +71,11 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
|
||||
if not value_is_unsigned_integer(inputs[0]) or not value_is_unsigned_integer(inputs[1]):
|
||||
return "only unsigned integer dot product is supported"
|
||||
|
||||
elif isinstance(node, intermediate.IndexConstant): # constraints for constant indexing
|
||||
assert_true(len(outputs) == 1)
|
||||
if not value_is_unsigned_integer(outputs[0]):
|
||||
return "only unsigned integer tensor constant indexing is supported"
|
||||
|
||||
else: # pragma: no cover
|
||||
assert_not_reached("Non IntermediateNode object in the OPGraph")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -14,7 +14,15 @@ from ..data_types.dtypes_helpers import (
|
||||
)
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..values import BaseValue, ClearScalar, EncryptedScalar, TensorValue
|
||||
from ..helpers import indexing_helpers
|
||||
from ..values import (
|
||||
BaseValue,
|
||||
ClearScalar,
|
||||
ClearTensor,
|
||||
EncryptedScalar,
|
||||
EncryptedTensor,
|
||||
TensorValue,
|
||||
)
|
||||
|
||||
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
|
||||
|
||||
@@ -197,6 +205,56 @@ class Constant(IntermediateNode):
|
||||
return str(self.constant_data)
|
||||
|
||||
|
||||
class IndexConstant(IntermediateNode):
|
||||
"""Node representing a constant indexing in the program.
|
||||
|
||||
What we mean by constant indexing is that the index part of the operation is a constant.
|
||||
Here are some examples: `x[2]`, `x[0, 1]`, `y[:, 0]`, `y[3:, :5]`
|
||||
|
||||
The opposite is to have dynamic indexing, which this node does not support.
|
||||
Some examples of dynamic indexing are: `x[y]`, `x[y, z]`, `x[:, y]`
|
||||
"""
|
||||
|
||||
_n_in: int = 1
|
||||
|
||||
index: Tuple[Union[int, slice], ...]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_: BaseValue,
|
||||
index: Union[int, slice, Tuple[Union[int, slice], ...]],
|
||||
) -> None:
|
||||
super().__init__((input_,))
|
||||
|
||||
if not isinstance(self.inputs[0], TensorValue) or self.inputs[0].is_scalar:
|
||||
raise TypeError(f"Only tensors can be indexed but you tried to index {self.inputs[0]}")
|
||||
|
||||
self.index = indexing_helpers.validate_index(index)
|
||||
|
||||
output_dtype = self.inputs[0].dtype
|
||||
output_shape = indexing_helpers.determine_output_shape(self.inputs[0].shape, self.index)
|
||||
|
||||
self.outputs = [
|
||||
EncryptedTensor(output_dtype, output_shape)
|
||||
if self.inputs[0].is_encrypted
|
||||
else ClearTensor(output_dtype, output_shape)
|
||||
]
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0][self.index]
|
||||
|
||||
def label(self) -> str:
|
||||
"""Label of the node to show during drawings.
|
||||
|
||||
It can be used for some other places after `"value"` below is replaced by `""`.
|
||||
This note will no longer be necessary after #707 is addressed.
|
||||
|
||||
"""
|
||||
elements = [indexing_helpers.format_indexing_element(element) for element in self.index]
|
||||
index = ", ".join(elements)
|
||||
return f"value[{index}]"
|
||||
|
||||
|
||||
def flood_replace_none_values(table: list):
|
||||
"""Use a flooding algorithm to replace None values.
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..debugging.custom_assert import assert_true
|
||||
from ..representation.intermediate import (
|
||||
IR_MIX_VALUES_FUNC_ARG_NAME,
|
||||
Add,
|
||||
IndexConstant,
|
||||
IntermediateNode,
|
||||
Mul,
|
||||
Sub,
|
||||
@@ -161,3 +162,7 @@ class BaseTracer(ABC):
|
||||
# the order, we need to do as in __rmul__, ie mostly a copy of __mul__ +
|
||||
# some changes
|
||||
__rmul__ = __mul__
|
||||
|
||||
def __getitem__(self, item):
|
||||
traced_computation = IndexConstant(self.output, item)
|
||||
return self.__class__([self], traced_computation, 0)
|
||||
|
||||
59
concrete/numpy/np_indexing_helpers.py
Normal file
59
concrete/numpy/np_indexing_helpers.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Helpers for indexing with numpy values functionality."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy
|
||||
|
||||
|
||||
def should_sanitize(indexing_element: Any) -> bool:
|
||||
"""Decide whether to sanitize an indexing element or not.
|
||||
|
||||
Sanitizing in this context means converting supported numpy values into python values.
|
||||
|
||||
Args:
|
||||
indexing_element (Any): the indexing element to decide sanitization.
|
||||
|
||||
Returns:
|
||||
bool: True if indexing element should be sanitized otherwise False.
|
||||
"""
|
||||
|
||||
return isinstance(indexing_element, numpy.integer) or (
|
||||
isinstance(indexing_element, numpy.ndarray)
|
||||
and issubclass(indexing_element.dtype.type, numpy.integer)
|
||||
and indexing_element.shape == ()
|
||||
)
|
||||
|
||||
|
||||
def process_indexing_element(indexing_element: Any) -> Any:
|
||||
"""Process an indexing element.
|
||||
|
||||
Processing in this context means converting supported numpy values into python values.
|
||||
(if they are decided to be sanitized)
|
||||
|
||||
Args:
|
||||
indexing_element (Any): the indexing element to sanitize.
|
||||
|
||||
Returns:
|
||||
Any: the sanitized indexing element.
|
||||
"""
|
||||
|
||||
if isinstance(indexing_element, slice):
|
||||
|
||||
start = indexing_element.start
|
||||
if should_sanitize(start):
|
||||
start = int(start)
|
||||
|
||||
stop = indexing_element.stop
|
||||
if should_sanitize(stop):
|
||||
stop = int(stop)
|
||||
|
||||
step = indexing_element.step
|
||||
if should_sanitize(step):
|
||||
step = int(step)
|
||||
|
||||
indexing_element = slice(start, stop, step)
|
||||
|
||||
elif should_sanitize(indexing_element):
|
||||
indexing_element = int(indexing_element)
|
||||
|
||||
return indexing_element
|
||||
@@ -18,6 +18,7 @@ from .np_dtypes_helpers import (
|
||||
get_base_value_for_numpy_or_python_constant_data,
|
||||
get_numpy_function_output_dtype,
|
||||
)
|
||||
from .np_indexing_helpers import process_indexing_element
|
||||
|
||||
SUPPORTED_TYPES_FOR_TRACING = (int, float, numpy.ndarray) + tuple(
|
||||
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES
|
||||
@@ -264,6 +265,14 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, tuple):
|
||||
item = tuple(process_indexing_element(indexing_element) for indexing_element in item)
|
||||
else:
|
||||
item = process_indexing_element(item)
|
||||
|
||||
return BaseTracer.__getitem__(self, item)
|
||||
|
||||
# Supported functions are either univariate or bivariate for which one of the two
|
||||
# sources is a constant
|
||||
#
|
||||
|
||||
@@ -116,6 +116,54 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
|
||||
20,
|
||||
id="Dot, np.array([1, 2, 3, 4]), np.array([4, 3, 2, 1])",
|
||||
),
|
||||
pytest.param(
|
||||
ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (0,)),
|
||||
[
|
||||
numpy.array([1, 2, 3, 4], dtype=numpy.int32),
|
||||
],
|
||||
1,
|
||||
id="IndexConstant, np.array([1, 2, 3, 4])[0]",
|
||||
),
|
||||
pytest.param(
|
||||
ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (slice(1, 3, None),)),
|
||||
[
|
||||
numpy.array([1, 2, 3, 4], dtype=numpy.int32),
|
||||
],
|
||||
numpy.array([2, 3]),
|
||||
id="IndexConstant, np.array([1, 2, 3, 4])[1:3]",
|
||||
),
|
||||
pytest.param(
|
||||
ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (slice(3, 1, -1),)),
|
||||
[
|
||||
numpy.array([1, 2, 3, 4], dtype=numpy.int32),
|
||||
],
|
||||
numpy.array([4, 3], dtype=numpy.int32),
|
||||
id="IndexConstant, np.array([1, 2, 3, 4])[3:1:-1]",
|
||||
),
|
||||
pytest.param(
|
||||
ir.IndexConstant(
|
||||
EncryptedTensor(Integer(5, True), shape=(4, 4)), (slice(1, 3, 1), slice(2, 0, -1))
|
||||
),
|
||||
[
|
||||
numpy.array(
|
||||
[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16],
|
||||
],
|
||||
dtype=numpy.int32,
|
||||
),
|
||||
],
|
||||
numpy.array(
|
||||
[
|
||||
[7, 6],
|
||||
[11, 10],
|
||||
],
|
||||
dtype=numpy.int32,
|
||||
),
|
||||
id="IndexConstant, np.array([[1, 2, 3, 4]...[13, 14, 15, 16]])[1:3, 2:0:-1]",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_evaluate(
|
||||
@@ -124,7 +172,10 @@ def test_evaluate(
|
||||
expected_result: int,
|
||||
):
|
||||
"""Test evaluate methods on IntermediateNodes"""
|
||||
assert node.evaluate(input_data) == expected_result
|
||||
if isinstance(expected_result, numpy.ndarray):
|
||||
assert (node.evaluate(input_data) == expected_result).all()
|
||||
else:
|
||||
assert node.evaluate(input_data) == expected_result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -14,6 +14,7 @@ from concrete.common.representation.intermediate import (
|
||||
Add,
|
||||
Constant,
|
||||
Dot,
|
||||
IndexConstant,
|
||||
Input,
|
||||
IntermediateNode,
|
||||
Mul,
|
||||
@@ -147,6 +148,15 @@ def is_equivalent_input(lhs: Input, rhs: object) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_index_constant(lhs: IndexConstant, rhs: object) -> bool:
|
||||
"""Helper function to check if an IndexConstant node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, IndexConstant)
|
||||
and lhs.index == rhs.index
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_mul(lhs: Mul, rhs: object) -> bool:
|
||||
"""Helper function to check if a Mul node is equivalent to an other object."""
|
||||
return _is_equivalent_to_binary_commutative(lhs, rhs)
|
||||
@@ -171,6 +181,7 @@ EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
|
||||
UnivariateFunction: is_equivalent_arbitrary_function,
|
||||
Constant: is_equivalent_constant,
|
||||
Dot: is_equivalent_dot,
|
||||
IndexConstant: is_equivalent_index_constant,
|
||||
Input: is_equivalent_input,
|
||||
Mul: is_equivalent_mul,
|
||||
Sub: is_equivalent_sub,
|
||||
|
||||
@@ -852,6 +852,20 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[0],
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=True), shape=(2, 2))},
|
||||
[(numpy.random.randint(-4, 2 ** 2, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = x # EncryptedTensor<Integer<signed, 3 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%1 = IndexConstant(%0[0]) # EncryptedTensor<Integer<signed, 3 bits>, shape=(2,)>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer tensor constant indexing is supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"return(%1)\n"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_fail_compile(function, parameters, inputset, match, default_compilation_configuration):
|
||||
|
||||
598
tests/numpy/test_compile_constant_indexing.py
Normal file
598
tests/numpy/test_compile_constant_indexing.py
Normal file
@@ -0,0 +1,598 @@
|
||||
"""Test module for constant indexing."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types import UnsignedInteger
|
||||
from concrete.common.values import EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value,function_with_indexing,output_value",
|
||||
[
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-2],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-1],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-2:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-1:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:-2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3:-2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3:1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3:2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-2:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-2:2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-2:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-1:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:-2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1:2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-2::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-1::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:-3:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:-2:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:0:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:1:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2:0:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2:1:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-1:1:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-1:0:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[:, :, :],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0, :, :],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(4, 5)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[:, 0, :],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 5)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[:, :, 0],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0, 0, :],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(5,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0, :, 0],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(4,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[:, 0, 0],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0:, 1:, 2:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 3, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[2:, 1:, 0:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1, 3, 5)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(4, 5)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0, 0],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(5,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0, 0, 0],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_constant_indexing(
|
||||
default_compilation_configuration,
|
||||
input_value,
|
||||
function_with_indexing,
|
||||
output_value,
|
||||
):
|
||||
"""Test compile_numpy_function_into_op_graph with constant indexing"""
|
||||
|
||||
inputset = [
|
||||
(
|
||||
np.random.randint(
|
||||
input_value.dtype.min_value(),
|
||||
input_value.dtype.max_value() + 1,
|
||||
size=input_value.shape,
|
||||
),
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
opgraph = compile_numpy_function_into_op_graph(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert len(opgraph.output_nodes) == 1
|
||||
output_node = opgraph.output_nodes[0]
|
||||
|
||||
assert len(output_node.outputs) == 1
|
||||
assert output_value == output_node.outputs[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value,function_with_indexing,expected_error_type,expected_error_message",
|
||||
[
|
||||
pytest.param(
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
lambda x: x[0],
|
||||
TypeError,
|
||||
"Only tensors can be indexed "
|
||||
"but you tried to index EncryptedScalar<Integer<unsigned, 1 bits>>",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0.5],
|
||||
TypeError,
|
||||
"Only integers and integer slices can be used for indexing "
|
||||
"but you tried to use 0.5 for indexing",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1:5:0.5], # type: ignore
|
||||
TypeError,
|
||||
"Only integers and integer slices can be used for indexing "
|
||||
"but you tried to use 1:5:0.5 for indexing",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0, 1],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [0, 1] "
|
||||
"as the index has more elements than the number of dimensions of the tensor",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[5],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [5] "
|
||||
"because index is out of range for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[5:],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [5:] "
|
||||
"because start index is out of range for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:10],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [:10] "
|
||||
"because stop index is out of range for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2:0],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [2:0] "
|
||||
"because start index is not less than stop index for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[5::-1],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [5::-1] "
|
||||
"because start index is out of range for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:10:-1],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [:10:-1] "
|
||||
"because stop index is out of range for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:2:-1],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [0:2:-1] "
|
||||
"because step is negative and stop index is not less than start index for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[::0],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [::0] "
|
||||
"because step is zero for dimension 0",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invalid_constant_indexing(
|
||||
default_compilation_configuration,
|
||||
input_value,
|
||||
function_with_indexing,
|
||||
expected_error_type,
|
||||
expected_error_message,
|
||||
):
|
||||
"""Test compile_numpy_function_into_op_graph with invalid constant indexing"""
|
||||
|
||||
with pytest.raises(expected_error_type):
|
||||
try:
|
||||
inputset = [
|
||||
(
|
||||
np.random.randint(
|
||||
input_value.dtype.min_value(),
|
||||
input_value.dtype.max_value() + 1,
|
||||
size=input_value.shape,
|
||||
),
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
compile_numpy_function_into_op_graph(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
except Exception as error:
|
||||
assert str(error) == expected_error_message
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value,function_with_indexing,output_value",
|
||||
[
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[np.uint32(0)],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[slice(np.uint32(2), np.int32(0), np.int8(-1))],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[np.array(0)],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[slice(np.array(2), np.array(0), np.array(-1))],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_constant_indexing_with_numpy_integers(
|
||||
default_compilation_configuration,
|
||||
input_value,
|
||||
function_with_indexing,
|
||||
output_value,
|
||||
):
|
||||
"""Test compile_numpy_function_into_op_graph with constant indexing with numpy integers"""
|
||||
|
||||
inputset = [
|
||||
(
|
||||
np.random.randint(
|
||||
input_value.dtype.min_value(),
|
||||
input_value.dtype.max_value() + 1,
|
||||
size=input_value.shape,
|
||||
),
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
opgraph = compile_numpy_function_into_op_graph(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert len(opgraph.output_nodes) == 1
|
||||
output_node = opgraph.output_nodes[0]
|
||||
|
||||
assert len(output_node.outputs) == 1
|
||||
assert output_value == output_node.outputs[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value,function_with_indexing,expected_error_type,expected_error_message",
|
||||
[
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[np.float32(1.5)],
|
||||
TypeError,
|
||||
"Only integers and integer slices can be used for indexing "
|
||||
"but you tried to use 1.5 for indexing",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[np.array(1.5)],
|
||||
TypeError,
|
||||
"Only integers and integer slices can be used for indexing "
|
||||
"but you tried to use 1.5 for indexing",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[np.array([1, 2])],
|
||||
TypeError,
|
||||
"Only integers and integer slices can be used for indexing "
|
||||
"but you tried to use [1 2] for indexing",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invalid_constant_indexing_with_numpy_values(
|
||||
default_compilation_configuration,
|
||||
input_value,
|
||||
function_with_indexing,
|
||||
expected_error_type,
|
||||
expected_error_message,
|
||||
):
|
||||
"""Test compile_numpy_function_into_op_graph with invalid constant indexing with numpy values"""
|
||||
|
||||
with pytest.raises(expected_error_type):
|
||||
try:
|
||||
inputset = [
|
||||
(
|
||||
np.random.randint(
|
||||
input_value.dtype.min_value(),
|
||||
input_value.dtype.max_value() + 1,
|
||||
size=input_value.shape,
|
||||
),
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
compile_numpy_function_into_op_graph(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
except Exception as error:
|
||||
assert str(error) == expected_error_message
|
||||
raise
|
||||
Reference in New Issue
Block a user