diff --git a/.github/workflows/concrete_python_test_macos.yml b/.github/workflows/concrete_python_test_macos.yml index 5136208e0..276ab1018 100644 --- a/.github/workflows/concrete_python_test_macos.yml +++ b/.github/workflows/concrete_python_test_macos.yml @@ -113,9 +113,11 @@ jobs: find .testenv/lib/python3.10/site-packages -not \( -path .testenv/lib/python3.10/site-packages/concrete -prune \) -name 'lib*omp5.dylib' -or -name 'lib*omp.dylib' | xargs -n 1 ln -f -s $(pwd)/.testenv/lib/python3.10/site-packages/concrete/.dylibs/libomp.dylib + cp -R $GITHUB_WORKSPACE/frontends/concrete-python/examples ./examples cp -R $GITHUB_WORKSPACE/frontends/concrete-python/tests ./tests + cp $GITHUB_WORKSPACE/frontends/concrete-python/Makefile . - KEY_CACHE_DIRECTORY=./KeySetCache PYTEST_MARKERS="not dataflow and not graphviz" make pytest + KEY_CACHE_DIRECTORY=./KeySetCache PYTEST_MARKERS="not dataflow and not graphviz" make pytest-macos - name: Cleanup host if: success() || failure() diff --git a/docs/core-features/workarounds.md b/docs/core-features/workarounds.md index d16885406..076fbf7e8 100644 --- a/docs/core-features/workarounds.md +++ b/docs/core-features/workarounds.md @@ -6,30 +6,6 @@ This document introduces several common techniques for optimizing code to fit Fu All code snippets provided here are temporary workarounds. In future versions of Concrete, some functions described here could be directly available in a more generic and efficient form. These code snippets are coming from support answers in our [community forum](https://community.zama.ai) {% endhint %} -## Minimum/Maximum for multiple values - -Concrete supports `np.minimum`/`np.maximum` natively, but not `np.min`/`np.max` yet. To achieve the functionality, you can do a series of `np.minimum`/`np.maximum`s: - -```python -import numpy as np -from concrete import fhe - -@fhe.compiler({"args": "encrypted"}) -def fhe_min(args): - remaining = list(args) - while len(remaining) > 1: - a = remaining.pop() - b = remaining.pop() - remaining.insert(0, np.minimum(a, b)) - return remaining[0] - -inputset = [np.random.randint(0, 16, size=5) for _ in range(50)] -circuit = fhe_min.compile(inputset, min_max_strategy_preference=fhe.MinMaxStrategy.ONE_TLU_PROMOTED) - -x1, x2, x3, x4, x5 = np.random.randint(0, 16, size=5) -assert circuit.encrypt_run_decrypt([x1, x2, x3, x4, x5]) == min(x1, x2, x3, x4, x5) -``` - ## Retrieving a value within an encrypted array with an encrypted index This example demonstrates how to retrieve a value from an array using an encrypted index. The method creates a "selection" array filled with `0`s except for the requested index, which will be `1`. It then sums the products of all array values with this selection array: diff --git a/docs/dev/compatibility.md b/docs/dev/compatibility.md index a905102b2..053c3c4a3 100644 --- a/docs/dev/compatibility.md +++ b/docs/dev/compatibility.md @@ -110,7 +110,9 @@ Some operations are not supported between two encrypted values. If attempted, a * [np.logical\_or](https://numpy.org/doc/stable/reference/generated/numpy.logical\_or.html) * [np.logical\_xor](https://numpy.org/doc/stable/reference/generated/numpy.logical\_xor.html) * [np.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html) +* [np.max](https://numpy.org/doc/stable/reference/generated/numpy.max.html) * [np.maximum](https://numpy.org/doc/stable/reference/generated/numpy.maximum.html) +* [np.min](https://numpy.org/doc/stable/reference/generated/numpy.min.html) * [np.minimum](https://numpy.org/doc/stable/reference/generated/numpy.minimum.html) * [np.multiply](https://numpy.org/doc/stable/reference/generated/numpy.multiply.html) * [np.negative](https://numpy.org/doc/stable/reference/generated/numpy.negative.html) diff --git a/frontends/concrete-python/.ruff.toml b/frontends/concrete-python/.ruff.toml index c43ae0d0f..1a697ad69 100644 --- a/frontends/concrete-python/.ruff.toml +++ b/frontends/concrete-python/.ruff.toml @@ -16,7 +16,7 @@ ignore = [ "**/__init__.py" = ["F401"] "concrete/fhe/compilation/configuration.py" = ["ARG002"] "concrete/fhe/mlir/processors/all.py" = ["F401"] -"concrete/fhe/mlir/processors/assign_bit_widths.py" = ["ARG002"] +"concrete/fhe/mlir/processors/assign_bit_widths.py" = ["ARG002", "RUF012"] "concrete/fhe/mlir/converter.py" = ["ARG002", "B011", "F403", "F405"] "concrete/**" = ["RUF010"] "examples/**" = ["PLR2004", "RUF010"] diff --git a/frontends/concrete-python/Makefile b/frontends/concrete-python/Makefile index a6a5526f4..356b156d5 100644 --- a/frontends/concrete-python/Makefile +++ b/frontends/concrete-python/Makefile @@ -1,21 +1,21 @@ PYTHON=python PIP=$(PYTHON) -m pip +COMPILER_BUILD_DIRECTORY ?= $(PWD)/../../compilers/concrete-compiler/compiler/build +BINDINGS_DIRECTORY=${COMPILER_BUILD_DIRECTORY}/tools/concretelang/python_packages/concretelang_core/ +TFHERS_UTILS_DIRECTORY ?= $(PWD)/tests/tfhers-utils/ + OS=undefined COVERAGE_OPT="" ifeq ($(shell uname), Linux) OS=linux COVERAGE_OPT="--cov=concrete.fhe --cov-fail-under=100 --cov-report=term-missing:skip-covered" + RUNTIME_LIBRARY?=${COMPILER_BUILD_DIRECTORY}/lib/libConcretelangRuntime.so else ifeq ($(shell uname), Darwin) OS=darwin + RUNTIME_LIBRARY?=${COMPILER_BUILD_DIRECTORY}/lib/libConcretelangRuntime.dylib endif - -COMPILER_BUILD_DIRECTORY ?= $(PWD)/../../compilers/concrete-compiler/compiler/build -BINDINGS_DIRECTORY=${COMPILER_BUILD_DIRECTORY}/tools/concretelang/python_packages/concretelang_core/ -RUNTIME_LIBRARY?=${COMPILER_BUILD_DIRECTORY}/lib/libConcretelangRuntime.so -TFHERS_UTILS_DIRECTORY ?= $(PWD)/tests/tfhers-utils/ - CONCRETE_VERSION?="" # empty mean latest # E.g. to use a previous version: `make CONCRETE_VERSION="<2.7.0" venv` # E.g. to use a nightly: `make CONCRETE_VERSION="==2.7.0dev20240801` @@ -76,6 +76,11 @@ pytest-default: tfhers-utils --key-cache "${KEY_CACHE_DIRECTORY}" \ -m "${PYTEST_MARKERS}" +pytest-macos: + pytest tests -svv -n auto \ + --key-cache "${KEY_CACHE_DIRECTORY}" \ + -m "${PYTEST_MARKERS}" + pytest-single: tfhers-utils eval $(shell make silent_cp_activate) # test single precision, mono params diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index 6ce0e928a..ff30fc952 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -2197,7 +2197,11 @@ class Context: if cached_conversion is None: cached_conversion = Conversion( self.converting, - arith.ConstantOp(resulting_type, attribute, loc=self.location()), + arith.ConstantOp( # pylint: disable=too-many-function-args + resulting_type, + attribute, + loc=self.location(), + ), ) try: @@ -2813,6 +2817,25 @@ class Context: return self.to_signedness(result, of=resulting_type) + def min_max( + self, + resulting_type: ConversionType, + x: Conversion, + axes: Union[int, np.integer, Sequence[Union[int, np.integer]]] = (), + keep_dims: bool = False, + *, + operation: str, + ): + # This import needs to happen here to avoid circular imports. + + # pylint: disable=import-outside-toplevel + + from .operations.min_max import min_max + + return min_max(self, resulting_type, x, axes, keep_dims, operation=operation) + + # pylint: enable=import-outside-toplevel + def minimum( self, resulting_type: ConversionType, @@ -3988,7 +4011,7 @@ class Context: def get_partition_name(self, partition: tfhers.CryptoParams) -> str: if partition not in self.tfhers_partition.keys(): - self.tfhers_partition[partition] = f"tfhers_{randint(0, 2**32)}" # noqa: S311 + self.tfhers_partition[partition] = f"tfhers_{randint(0, 2 ** 32)}" # noqa: S311 return self.tfhers_partition[partition] def change_partition( diff --git a/frontends/concrete-python/concrete/fhe/mlir/conversion.py b/frontends/concrete-python/concrete/fhe/mlir/conversion.py index 33fd4e16e..cf5a1eff2 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/conversion.py +++ b/frontends/concrete-python/concrete/fhe/mlir/conversion.py @@ -2,7 +2,7 @@ Declaration of `ConversionType` and `Conversion` classes. """ -# pylint: disable=import-error, +# pylint: disable=import-error,no-name-in-module import re from typing import Optional, Tuple @@ -12,7 +12,7 @@ from mlir.ir import Type as MlirType from ..representation import Node -# pylint: enable=import-error +# pylint: enable=import-error,no-name-in-module SCALAR_INT_SEARCH_REGEX = re.compile(r"^i([0-9]+)$") diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index b9b4215c7..d7daae819 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -326,6 +326,12 @@ class Converter: assert len(preds) == 2 return ctx.add(ctx.typeof(node), preds[0], preds[1]) + def amax(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: + return self.max(ctx, node, preds) + + def amin(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: + return self.min(ctx, node, preds) + def array(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) > 0 return ctx.array(ctx.typeof(node), elements=preds) @@ -530,6 +536,20 @@ class Converter: assert len(preds) == 2 return ctx.matmul(ctx.typeof(node), preds[0], preds[1]) + def max(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: + assert len(preds) == 1 + + if all(pred.is_encrypted for pred in preds): + return ctx.min_max( + ctx.typeof(node), + preds[0], + axes=node.properties["kwargs"].get("axis", ()), + keep_dims=node.properties["kwargs"].get("keepdims", False), + operation="max", + ) + + return self.tlu(ctx, node, preds) + def maximum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 2 @@ -556,6 +576,20 @@ class Converter: ctx.error({node: "3-dimensional maxpooling is not supported at the moment"}) assert False, "unreachable" # pragma: no cover + def min(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: + assert len(preds) == 1 + + if all(pred.is_encrypted for pred in preds): + return ctx.min_max( + ctx.typeof(node), + preds[0], + axes=node.properties["kwargs"].get("axis", ()), + keep_dims=node.properties["kwargs"].get("keepdims", False), + operation="min", + ) + + return self.tlu(ctx, node, preds) + def minimum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 2 diff --git a/frontends/concrete-python/concrete/fhe/mlir/operations/indexing.py b/frontends/concrete-python/concrete/fhe/mlir/operations/indexing.py index b84a01d7b..167fdb47c 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/operations/indexing.py +++ b/frontends/concrete-python/concrete/fhe/mlir/operations/indexing.py @@ -490,6 +490,7 @@ def fancy_indexing( resulting_type, x.result, indices.result, + original_bit_width=x.original_bit_width, ) @@ -640,6 +641,7 @@ def indexing( MlirDenseI64ArrayAttr.get(static_offsets), MlirDenseI64ArrayAttr.get(static_sizes), MlirDenseI64ArrayAttr.get(static_strides), + original_bit_width=x.original_bit_width, ) reassociaton = [] @@ -669,4 +671,5 @@ def indexing( for indices in reassociaton ], ), + original_bit_width=x.original_bit_width, ) diff --git a/frontends/concrete-python/concrete/fhe/mlir/operations/min_max.py b/frontends/concrete-python/concrete/fhe/mlir/operations/min_max.py new file mode 100644 index 000000000..b1a6b5482 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/mlir/operations/min_max.py @@ -0,0 +1,428 @@ +""" +Conversion of min and max operations. +""" + +from copy import deepcopy +from typing import Sequence, Set, Tuple, Union + +import numpy as np + +from ..context import Context +from ..conversion import Conversion, ConversionType + + +def min_max( + ctx: Context, + resulting_type: ConversionType, + x: Conversion, + axes: Union[int, np.integer, Sequence[Union[int, np.integer]]] = (), + keep_dims: bool = False, + *, + operation: str, +) -> Conversion: + """ + Convert min or max operation. + + Args: + ctx (Context): + conversion context + + resulting_type (ConversionType): + resulting type of the operation + + x (Conversion): + input of the operation + + axes (Union[int, np.integer, Sequence[Union[int, np.integer]]], default = ()): + axes to reduce over + + keep_dims (bool, default = False): + whether to keep the reduced axes + + operation (str): + "min" or "max" + + Returns: + Conversion: + np.min or np.max on x depending on operation + """ + + # if the input is clear + if x.is_clear: + # raise error as computing min/max of clear values is not supported + highlights = { + x.origin: "value is clear", + ctx.converting: f"but computing {operation} of clear values is not supported", + } + ctx.error(highlights) + + # if the value is scalar + if x.is_scalar: + # return it as it's the min/max + return x + + # if axes is not specified, use all axes + if axes is None: + axes = [] + + # compute the list of unique axes to reduce + # empty list means reduce all axes + axes = list( + set( + # if axes was a single integer, only it will be reduced + [int(axes)] + if isinstance(axes, (int, np.integer)) + # if axes was a sequence, every axis in it will be reduced + else [int(axis) for axis in axes] + ) + ) + + # sanitize negative axis + # `axis=-1` is the same as `axis=(-1 + x.ndim))` + input_dimensions = len(x.shape) + for i, axis in enumerate(axes): + if axis < 0: + axes[i] += input_dimensions + assert 0 <= axes[i] < input_dimensions + + # if all axes are reduced + if len(axes) == 0 or len(axes) == len(x.shape): + # we flatten the input to use the `reduce` implementation + x = ctx.flatten(x) + + # if the input is a vector + if len(x.shape) == 1: + # we reduce the vector to its min/max value + result = reduce(ctx, ctx.element_typeof(resulting_type), x, operation=operation) + + # if the user wants to keep the reduced dimensions + if keep_dims: + # reshape the result into the resulting shape + result = ctx.reshape(result, shape=resulting_type.shape) + + # return the result + return result + + # if the reduce implementation is not used + # mock implementation will be used instead + # + # the idea is to let numpy compute the indices that will be compared to obtain the result + # + # for example, if input.shape == (2, 3) and axis == 1 + # we'll be computing + # [ + # [ + # {(0, 2, 0), (0, 0, 0), (0, 3, 0), (0, 1, 0)} + # {(0, 2, 1), (0, 0, 1), (0, 3, 1), (0, 1, 1)} + # {(0, 1, 2), (0, 2, 2), (0, 3, 2), (0, 0, 2)} + # ] + # [ + # {(1, 0, 0), (1, 3, 0), (1, 1, 0), (1, 2, 0)} + # {(1, 2, 1), (1, 0, 1), (1, 1, 1), (1, 3, 1)} + # {(1, 3, 2), (1, 2, 2), (1, 0, 2), (1, 1, 2)} + # ] + # ] + # + # this means to compute output[0, 0] we need to compare + # - input[0, 2, 0] + # - input[0, 0, 0] + # - input[0, 3, 0] + # - input[0, 1, 0] + # + # or to compute output[1, 2] we need to compare + # - input[1, 3, 2] + # - input[1, 2, 2] + # - input[1, 0, 2] + # - input[1, 1, 2] + # + # notice that number of comparisons (say n) are always the same in each output cell + # this opens up the possibility to use fancy indexing to extract n slices from the input + # and compare the slices in a tree-like fashion to obtain the result + # + # in the end, we'll be performing + # + # slice1 = [ + # [input[0, 2, 0], input[0, 2, 1], input[0, 1, 2]], + # [input[1, 0, 0], input[1, 2, 1], input[1, 3, 2]], + # ] + # slice2 = [ + # [input[0, 0, 0], input[0, 0, 1], input[0, 2, 2]], + # [input[1, 3, 0], input[1, 0, 1], input[1, 2, 2]], + # ] + # slice3 = [ + # [input[0, 3, 0], input[0, 3, 1], input[0, 3, 2]], + # [input[1, 1, 0], input[1, 1, 1], input[1, 0, 2]], + # ] + # slice4 = [ + # [input[0, 1, 0], input[0, 1, 1], input[0, 0, 2]], + # [input[1, 2, 0], input[1, 3, 1], input[1, 1, 2]], + # ] + # + # minimum_slice1_slice2 = np.minimum(slice1, slice2) + # minimum_slice3_slice4 = np.minimum(slice3, slice4) + # + # result = np.minimum(minimum_slice1_slice2, minimum_slice3_slice4) + # + # notice that all slices have the same shape as the result + + class Mock: + """ + Class to track accumulation of the operation. + """ + + # list of indices that have been accumulated + indices: Set[Tuple[int, ...]] + + # initialize the mock with a starting index + def __init__(self, index: Tuple[int, ...]): + self.indices = {index} + + # get the representation of the mock + def __repr__(self) -> str: + return f"{self.indices}" + + # combine the indices of the mock with another mock into a new mock + def combine(self, other: "Mock") -> "Mock": + result = deepcopy(self) + for index in other.indices: + result.indices.add(index) + return result + + # create the mock input + # + # [[[{(0, 0, 0)} {(0, 0, 1)} {(0, 0, 2)}] + # [{(0, 1, 0)} {(0, 1, 1)} {(0, 1, 2)}] + # [{(0, 2, 0)} {(0, 2, 1)} {(0, 2, 2)}] + # [{(0, 3, 0)} {(0, 3, 1)} {(0, 3, 2)}]] + # + # [[{(1, 0, 0)} {(1, 0, 1)} {(1, 0, 2)}] + # [{(1, 1, 0)} {(1, 1, 1)} {(1, 1, 2)}] + # [{(1, 2, 0)} {(1, 2, 1)} {(1, 2, 2)}] + # [{(1, 3, 0)} {(1, 3, 1)} {(1, 3, 2)}]]] + + mock_input = [] + for index in np.ndindex(x.shape): + mock_input.append(Mock(index)) + mock_input = np.array(mock_input).reshape(x.shape) + + # use numpy reduction to compute the mock output + # + # [[{(0, 2, 0), (0, 0, 0), (0, 3, 0), (0, 1, 0)} + # {(0, 2, 1), (0, 0, 1), (0, 3, 1), (0, 1, 1)} + # {(0, 1, 2), (0, 2, 2), (0, 3, 2), (0, 0, 2)}] + # [{(1, 0, 0), (1, 3, 0), (1, 1, 0), (1, 2, 0)} + # {(1, 2, 1), (1, 0, 1), (1, 1, 1), (1, 3, 1)} + # {(1, 3, 2), (1, 2, 2), (1, 0, 2), (1, 1, 2)}]] + + mock_output = np.frompyfunc(lambda mock1, mock2: mock1.combine(mock2), 2, 1).reduce( + mock_input, + axis=tuple(axes), + keepdims=keep_dims, + ) + + # extract a sample mock from the mock output + sample_mock = mock_output.flat[0] + + # extract the indices of the sample mock + sample_mock_indices = sample_mock.indices + + # compute number of comparisons + number_of_comparisons = len(sample_mock_indices) + + # extract a sample index from the sample mock indices + sample_mock_index = next(iter(sample_mock_indices)) + + # compute the number of indices + number_of_indices = len(sample_mock_index) + + # compute the shape of fancy indexing indices + index_shape = resulting_type.shape + + # compute the fancy indices to extract for comparison + to_compare = [] + for _ in range(number_of_comparisons): + indices = [] + for _ in range(number_of_indices): + index = np.zeros(index_shape, dtype=np.int64) # type: ignore + indices.append(index) + to_compare.append(tuple(indices)) + for position in np.ndindex(mock_output.shape): + mock_indices = list(mock_output[position].indices) + for i in range(number_of_comparisons): + for j in range(number_of_indices): + to_compare[i][j][position] = mock_indices[i][j] # type: ignore + + # to_compare will look like + # [ + # # for the first slice + # ( + # [[0, 0, 0], [1, 1, 1]], # i + # [[2, 2, 1], [0, 2, 3]], # j + # [[0, 1, 2], [0, 1, 2]], # k + # ), + # # for the second slice + # ( + # [[0, 0, 0], [1, 1, 1]], # i + # [[0, 0, 2], [3, 0, 2]], # j + # [[0, 1, 2], [0, 1, 2]], # k + # ), + # # for the third slice + # ( + # [[0, 0, 0], [1, 1, 1]], # i + # [[3, 3, 3], [1, 1, 0]], # j + # [[0, 1, 2], [0, 1, 2]], # k + # ), + # # for the fourth slice + # ( + # [[0, 0, 0], [1, 1, 1]], # i + # [[1, 1, 0], [2, 3, 1]], # j + # [[0, 1, 2], [0, 1, 2]], # k + # ), + # ] + + # find the type of the slices + slices_type = ctx.tensor(ctx.element_typeof(x), shape=resulting_type.shape) + + # extract the slices + slices = [] + for index in to_compare: + slices.append(ctx.index(slices_type, x, index)) # type: ignore + + # while there are more than 1 slices + while len(slices) > 1: + # pop the last two slices + a = slices.pop() + b = slices.pop() + + # compare the last two slices + if operation == "min": + c = ctx.minimum(resulting_type, a, b) + else: + c = ctx.maximum(resulting_type, a, b) + + # we need to set the original bit width manually + # as minimum/maximum doesn't constraint their output bit width. + c.set_original_bit_width(x.original_bit_width) + + # insert the slice back at the beginning of the slice queue + slices.insert(0, c) + + # return the result + return slices[0] + + +def reduce( + ctx: Context, + resulting_type: ConversionType, + values: Conversion, + *, + operation: str, +) -> Conversion: + """ + Reduce a vector of values to its min/max value. + """ + + # make sure the operation is valid + assert operation in {"min", "max"} + + # make sure the value is valid + assert values.is_tensor + assert len(values.shape) == 1 + assert values.is_encrypted + + # make sure the resulting type is valid + assert resulting_type.is_scalar + assert resulting_type.is_encrypted + + # let's say the vector was [1, 4, 2, 3, 0] + # and we're computing np.min(vector) + + # find the element type of the vector = fhe.uint3 + values_element_type = ctx.element_typeof(values) + + # find the middle of the vector = 2 + middle = values.shape[0] // 2 + + # we'll be splitting the array into two halves + # [1, 4] and [2, 3] in our case + # then we'll compute np.minimum(first_half, second_half) + # [1, 3] in our case + # if the reduction is a scalar, we'll return it + # otherwise, we'll reduce recursively until we obtain a scalar + # 1 in our case + + # find the half type of the vector which is + # fhe.tensor[fhe.uint3, 2] in the first iteration + # fhe.uint3 in the last iteration + half_type = ( + ctx.tensor(values_element_type, shape=(middle,)) if middle != 1 else values_element_type + ) + + # find the accumulated type of the vector which is + # fhe.tensor[fhe.uint3, 2] in the first iteration + # fhe.uint3 in the last iteration + accumulated_type = ( + ctx.tensor(resulting_type, shape=(middle,)) if middle != 1 else resulting_type + ) + + # if there is only one element in each half (e.g., vector = [1, 3], halfs = [1], [3]) + if middle == 1: + # extract the first element in the vector + first_half = ctx.index(half_type, values, index=[0]) + # extract the second element in the vector + second_half = ctx.index(half_type, values, index=[1]) + else: + # extract the elements from 0 to middle as the first half + first_half = ctx.index(half_type, values, index=[slice(0, middle)]) + # extract the elements from middle to 2*middle as the second half + second_half = ctx.index(half_type, values, index=[slice(middle, 2 * middle)]) + + # compare halfs + if operation == "min": + # [1, 3] in the first iteration + # 1 in the last iteration + reduced = ctx.minimum(accumulated_type, first_half, second_half) + else: + reduced = ctx.maximum(accumulated_type, first_half, second_half) + + # set the original bit width of the reduced so the following operation work as intended + # this is required here since ctx.minimum and ctx.maximum does not constraint output bit width + reduced.set_original_bit_width(values.original_bit_width) + + result = ( + # if reduced value is a scalar, we end the recursion + reduced + if reduced.is_scalar + # otherwise, we reduce the result of comparison of halfs + else reduce(ctx, resulting_type, reduced, operation=operation) + ) + + # if we have one more element that wasn't in the halfs + if values.shape[0] % 2 == 1: + # we extract it + last_value = ctx.index(values_element_type, values, index=[-1]) + # and compare it with the result we obtained from the halfs + result = ( + ctx.minimum(resulting_type, result, last_value) + if operation == "min" + else ctx.maximum(resulting_type, result, last_value) + ) + # again, we need to set the original bit width + result.set_original_bit_width(values.original_bit_width) + + # here is the visualization of the algorithm + # + # [ 1, 4, 2, 3, 0 ] + # + # [1, 4][2, 3],0 + # \ / | + # [1, 3] | + # \ / + # 1 / + # \ / + # 0 + # + # it has O(log(n)) - 1 number of tensor comparisons (of sizes n/2, n/4, ...) + # and up to 1 + O(log(n)) number of scalar comparisons (depending on the oddity of n/2, n/4, ..) + + return result diff --git a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py index 6b5e9de7e..b3a78da64 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py @@ -351,7 +351,7 @@ class AdditionalConstraints: node.properties["strategy"] = strategy break - def min_max(self, node: Node, preds: List[Node]): + def minimum_maximum(self, node: Node, preds: List[Node]): assert len(preds) == 2 x = preds[0] @@ -392,69 +392,83 @@ class AdditionalConstraints: node.properties["strategy"] = strategy break + def min_max(self, node: Node, preds: List[Node]): + assert len(preds) == 1 + self.minimum_maximum(node, [preds[0], preds[0]]) + # ========== # Operations # ========== - add = { # noqa: RUF012 + add = { inputs_and_output_share_precision, } - array = { # noqa: RUF012 + amax = { + min_max, inputs_and_output_share_precision, } - assign_dynamic = { # noqa: RUF012 + amin = { + min_max, inputs_and_output_share_precision, } - assign_static = { # noqa: RUF012 + array = { inputs_and_output_share_precision, } - bitwise_and = { # noqa: RUF012 + assign_dynamic = { + inputs_and_output_share_precision, + } + + assign_static = { + inputs_and_output_share_precision, + } + + bitwise_and = { all_inputs_are_encrypted: { bitwise, }, } - bitwise_or = { # noqa: RUF012 + bitwise_or = { all_inputs_are_encrypted: { bitwise, }, } - bitwise_xor = { # noqa: RUF012 + bitwise_xor = { all_inputs_are_encrypted: { bitwise, }, } - broadcast_to = { # noqa: RUF012 + broadcast_to = { inputs_and_output_share_precision, } - concatenate = { # noqa: RUF012 + concatenate = { inputs_and_output_share_precision, } - conv1d = { # noqa: RUF012 + conv1d = { inputs_and_output_share_precision, } - conv2d = { # noqa: RUF012 + conv2d = { inputs_and_output_share_precision, } - conv3d = { # noqa: RUF012 + conv3d = { inputs_and_output_share_precision, } - copy = { # noqa: RUF012 + copy = { inputs_and_output_share_precision, } - dot = { # noqa: RUF012 + dot = { all_inputs_are_encrypted: { inputs_share_precision, inputs_require_one_more_bit, @@ -464,51 +478,51 @@ class AdditionalConstraints: }, } - equal = { # noqa: RUF012 + equal = { all_inputs_are_encrypted: { comparison, }, } - expand_dims = { # noqa: RUF012 + expand_dims = { inputs_and_output_share_precision, } - greater = { # noqa: RUF012 + greater = { all_inputs_are_encrypted: { comparison, }, } - greater_equal = { # noqa: RUF012 + greater_equal = { all_inputs_are_encrypted: { comparison, }, } - index_static = { # noqa: RUF012 + index_static = { inputs_and_output_share_precision, } - left_shift = { # noqa: RUF012 + left_shift = { all_inputs_are_encrypted: { bitwise, }, } - less = { # noqa: RUF012 + less = { all_inputs_are_encrypted: { comparison, }, } - less_equal = { # noqa: RUF012 + less_equal = { all_inputs_are_encrypted: { comparison, }, } - matmul = { # noqa: RUF012 + matmul = { all_inputs_are_encrypted: { inputs_share_precision, inputs_require_one_more_bit, @@ -518,34 +532,44 @@ class AdditionalConstraints: }, } - maximum = { # noqa: RUF012 + max = { + min_max, + inputs_and_output_share_precision, + } + + maximum = { all_inputs_are_encrypted: { - min_max, + minimum_maximum, }, } - maxpool1d = { # noqa: RUF012 + maxpool1d = { inputs_and_output_share_precision, inputs_require_one_more_bit, } - maxpool2d = { # noqa: RUF012 + maxpool2d = { inputs_and_output_share_precision, inputs_require_one_more_bit, } - maxpool3d = { # noqa: RUF012 + maxpool3d = { inputs_and_output_share_precision, inputs_require_one_more_bit, } - minimum = { # noqa: RUF012 + min = { + min_max, + inputs_and_output_share_precision, + } + + minimum = { all_inputs_are_encrypted: { - min_max, + minimum_maximum, }, } - multiply = { # noqa: RUF012 + multiply = { all_inputs_are_encrypted: { inputs_share_precision, inputs_require_one_more_bit, @@ -555,48 +579,48 @@ class AdditionalConstraints: }, } - negative = { # noqa: RUF012 + negative = { inputs_and_output_share_precision, } - not_equal = { # noqa: RUF012 + not_equal = { all_inputs_are_encrypted: { comparison, }, } - reshape = { # noqa: RUF012 + reshape = { inputs_and_output_share_precision, } - right_shift = { # noqa: RUF012 + right_shift = { all_inputs_are_encrypted: { bitwise, }, } - round_bit_pattern = { # noqa: RUF012 + round_bit_pattern = { has_overflow_protection: { inputs_and_output_share_precision, }, } - subtract = { # noqa: RUF012 + subtract = { inputs_and_output_share_precision, } - sum = { # noqa: RUF012 + sum = { inputs_and_output_share_precision, } - squeeze = { # noqa: RUF012 + squeeze = { inputs_and_output_share_precision, } - transpose = { # noqa: RUF012 + transpose = { inputs_and_output_share_precision, } - truncate_bit_pattern = { # noqa: RUF012 + truncate_bit_pattern = { inputs_and_output_share_precision, } diff --git a/frontends/concrete-python/concrete/fhe/mlir/utils.py b/frontends/concrete-python/concrete/fhe/mlir/utils.py index ddc7ec537..ae2f21034 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/utils.py +++ b/frontends/concrete-python/concrete/fhe/mlir/utils.py @@ -2,7 +2,7 @@ Declaration of various functions and constants related to MLIR conversion. """ -# pylint: disable=import-error +# pylint: disable=import-error,no-name-in-module from collections import defaultdict, deque from copy import deepcopy @@ -26,7 +26,7 @@ from ..dtypes import Integer from ..internal.utils import assert_that from ..representation import Node, Operation -# pylint: enable=import-error +# pylint: enable=import-error,no-name-in-module class HashableNdarray: diff --git a/frontends/concrete-python/concrete/fhe/representation/graph.py b/frontends/concrete-python/concrete/fhe/representation/graph.py index d5f5e8616..a683ad02d 100644 --- a/frontends/concrete-python/concrete/fhe/representation/graph.py +++ b/frontends/concrete-python/concrete/fhe/representation/graph.py @@ -705,18 +705,27 @@ class Graph: bounds of each node in the `Graph` """ - for node in self.graph.nodes(): + for node in self.query_nodes(ordered=True): if node in bounds: min_bound = bounds[node]["min"] max_bound = bounds[node]["max"] - node.bounds = (min_bound, max_bound) + node.bounds = (min_bound, max_bound) # type: ignore new_value = deepcopy(node.output) if isinstance(min_bound, (np.integer, int)): assert isinstance(new_value.dtype, Integer) new_value.dtype.update_to_represent(np.array([min_bound, max_bound])) + + if node.operation == Operation.Generic and node.properties["name"] in { + "amin", + "amax", + "min", + "max", + }: + assert isinstance(node.inputs[0].dtype, Integer) + new_value.dtype.is_signed = node.inputs[0].dtype.is_signed else: new_value.dtype = { np.bool_: UnsignedInteger(1), diff --git a/frontends/concrete-python/concrete/fhe/tracing/tracer.py b/frontends/concrete-python/concrete/fhe/tracing/tracer.py index 3054e68f9..a867114a4 100644 --- a/frontends/concrete-python/concrete/fhe/tracing/tracer.py +++ b/frontends/concrete-python/concrete/fhe/tracing/tracer.py @@ -281,7 +281,9 @@ class Tracer: np.logical_or, np.logical_xor, np.matmul, + np.max, np.maximum, + np.min, np.minimum, np.mod, np.multiply, @@ -331,6 +333,14 @@ class Tracer: np.expand_dims: { "axis", }, + np.max: { + "axis", + "keepdims", + }, + np.min: { + "axis", + "keepdims", + }, np.ones_like: { "dtype", }, @@ -471,6 +481,12 @@ class Tracer: sanitized_args = [self.sanitize(args[0])] if len(args) > 1: kwargs["shape"] = args[1] + elif func in {np.min, np.max}: + sanitized_args = [self.sanitize(args[0])] + for i, keyword in enumerate(["axis", "out", "keepdims", "initial", "where"]): + position = i + 1 + if len(args) > position: + kwargs[keyword] = args[position] elif func is np.reshape: sanitized_args = [self.sanitize(args[0])] if len(args) > 1: diff --git a/frontends/concrete-python/scripts/links/check_headers.py b/frontends/concrete-python/scripts/links/check_headers.py index 04ba65cb1..074ca2b69 100644 --- a/frontends/concrete-python/scripts/links/check_headers.py +++ b/frontends/concrete-python/scripts/links/check_headers.py @@ -19,7 +19,7 @@ def ast_iterator(root): while nodes: current_node = nodes.pop(0) yield current_node - if hasattr(current_node, "children"): + if hasattr(current_node, "children") and current_node.children is not None: nodes += current_node.children diff --git a/frontends/concrete-python/tests/conftest.py b/frontends/concrete-python/tests/conftest.py index 9911ff413..cbe9ad042 100644 --- a/frontends/concrete-python/tests/conftest.py +++ b/frontends/concrete-python/tests/conftest.py @@ -318,6 +318,10 @@ class Helpers: if i == retries - 1: message = f""" + Sample + =============== + {sample} + Expected Output =============== {expected} diff --git a/frontends/concrete-python/tests/execution/test_min_max.py b/frontends/concrete-python/tests/execution/test_min_max.py index 17fc0b3ba..c89dff2df 100644 --- a/frontends/concrete-python/tests/execution/test_min_max.py +++ b/frontends/concrete-python/tests/execution/test_min_max.py @@ -1,5 +1,5 @@ """ -Tests of execution of minimum and maximum operations. +Tests of execution of min and max operations. """ import random @@ -11,219 +11,145 @@ from concrete import fhe from concrete.fhe.dtypes import Integer from concrete.fhe.values import ValueDescription -cases = [ - [ - # operation - ( - "minimum_optimized_x", - lambda x, y: np.minimum(fhe.hint(x, bit_width=5), y), # type: ignore - ), - # bit widths - 4, - 4, - # signednesses - False, - True, - # shapes - (), - (), - # strategy - fhe.MinMaxStrategy.CHUNKED, - ], - [ - # operation - ( - "minimum_optimized_x", - lambda x, y: np.minimum(fhe.hint(x, bit_width=5), y), # type: ignore - ), - # bit widths - 4, - 4, - # signednesses - True, - False, - # shapes - (2,), - (), - # strategy - fhe.MinMaxStrategy.CHUNKED, - ], - [ - # operation - ( - "maximum_optimized_y", - lambda x, y: np.maximum(x, fhe.hint(y, bit_width=4)), # type: ignore - ), - # bit widths - 4, - 3, - # signednesses - True, - False, - # shapes - (), - (2, 3), - # strategy - fhe.MinMaxStrategy.CHUNKED, - ], - [ - # operation - ( - "maximum_optimized_y", - lambda x, y: np.maximum(x, fhe.hint(y, bit_width=4)), # type: ignore - ), - # bit widths - 4, - 3, - # signednesses - False, - True, - # shapes - (), - (), - # strategy - fhe.MinMaxStrategy.CHUNKED, - ], -] -cases += [ - [ - # operation - operation, - # bit widths - 1, - 1, - # signednesses - lhs_is_signed, - rhs_is_signed, - # shapes - (), - (), - # strategy - fhe.MinMaxStrategy.CHUNKED, - ] - for lhs_is_signed in [False, True] - for rhs_is_signed in [False, True] - for operation in [ - ( - "maximum", - lambda x, y: np.maximum(x, y), - ), - ] -] -cases = [ - [ - # operation - ("maximum_increased_bit_widths", lambda x, y: (np.maximum(x, y), x + 100, y + 100)), - # bit widths - 7, - 7, - # signednesses - True, - False, - # shapes - (), - (), - # strategy - fhe.MinMaxStrategy.CHUNKED, - ], - [ - # operation - ("maximum_increased_bit_widths", lambda x, y: (np.maximum(x, y), x + 100, y + 100)), - # bit widths - 7, - 7, - # signednesses - False, - True, - # shapes - (), - (), - # strategy - fhe.MinMaxStrategy.CHUNKED, - ], -] -for lhs_bit_width in range(1, 5): - for rhs_bit_width in range(1, 5): - strategies = [] - if lhs_bit_width <= 3 and rhs_bit_width <= 3: - strategies += [ - fhe.MinMaxStrategy.ONE_TLU_PROMOTED, - fhe.MinMaxStrategy.THREE_TLU_CASTED, - ] - else: - strategies += [ - fhe.MinMaxStrategy.CHUNKED, - ] - - for lhs_is_signed in [False, True]: - for rhs_is_signed in [False, True]: - cases += [ - [ - # operation - operation, - # bit widths - lhs_bit_width, - rhs_bit_width, - # signednesses - lhs_is_signed, - rhs_is_signed, - # shapes - random.choice([(), (2,), (3, 2)]), - random.choice([(), (2,), (3, 2)]), - # strategy - strategy, - ] - for operation in [ - ("minimum", lambda x, y: np.minimum(x, y)), - ("maximum", lambda x, y: np.maximum(x, y)), - ] - for strategy in strategies - ] +cases = [] +for operation in [("max", lambda x: np.max(x)), ("min", lambda x: np.min(x))]: + for bit_width in range(1, 5): + for is_signed in [False, True]: + for shape in [(), (4,), (3, 3)]: + for keepdims in [False, True]: + for strategy in [ + fhe.MinMaxStrategy.ONE_TLU_PROMOTED, + fhe.MinMaxStrategy.THREE_TLU_CASTED, + fhe.MinMaxStrategy.CHUNKED, + ]: + cases.append( + [ + operation, + bit_width, + is_signed, + shape, + None, + keepdims, + strategy, + ], + ) + for axis in range(len(shape)): + cases.append( + [ + operation, + bit_width, + is_signed, + shape, + axis, + keepdims, + strategy, + ], + ) + cases.append( + [ + operation, + bit_width, + is_signed, + shape, + -1, + keepdims, + strategy, + ], + ) + if len(shape) == 2: + cases.append( + [ + operation, + bit_width, + is_signed, + shape, + (0, 1), + keepdims, + strategy, + ], + ) + cases.append( + [ + operation, + bit_width, + is_signed, + shape, + -2, + keepdims, + strategy, + ], + ) + if len(shape) == 3: + cases.append( + [ + operation, + bit_width, + is_signed, + shape, + (0, 1), + keepdims, + strategy, + ], + ) + cases.append( + [ + operation, + bit_width, + is_signed, + shape, + (0, 2), + keepdims, + strategy, + ], + ) + cases.append( + [ + operation, + bit_width, + is_signed, + shape, + (1, 2), + keepdims, + strategy, + ], + ) # pylint: disable=redefined-outer-name @pytest.mark.parametrize( - "operation," - "lhs_bit_width,rhs_bit_width," - "lhs_is_signed,rhs_is_signed," - "lhs_shape,rhs_shape," - "strategy", - cases, + "operation,bit_width,is_signed,shape,axis,keepdims,strategy", + random.sample(cases, 100), ) -def test_minimum_maximum( +def test_min_max( operation, - lhs_bit_width, - rhs_bit_width, - lhs_is_signed, - rhs_is_signed, - lhs_shape, - rhs_shape, + bit_width, + is_signed, + shape, + axis, + keepdims, strategy, helpers, ): """ - Test comparison operations between encrypted integers. + Test np.min/np.max on encrypted values. """ name, function = operation - lhs_dtype = Integer(is_signed=lhs_is_signed, bit_width=lhs_bit_width) - rhs_dtype = Integer(is_signed=rhs_is_signed, bit_width=rhs_bit_width) - - lhs_description = ValueDescription(lhs_dtype, shape=lhs_shape, is_encrypted=True) - rhs_description = ValueDescription(rhs_dtype, shape=rhs_shape, is_encrypted=True) + dtype = Integer(is_signed=is_signed, bit_width=bit_width) + description = ValueDescription(dtype, shape=shape, is_encrypted=True) print() print() print( - f"{name}({lhs_description}, {rhs_description})" + f"np.{name}({description}, axis={axis}, keepdims={keepdims})" + (f" {{{strategy}}}" if strategy is not None else "") ) print() print() - parameter_encryption_statuses = {"x": "encrypted", "y": "encrypted"} + parameter_encryption_statuses = {"x": "encrypted"} configuration = helpers.configuration() if strategy is not None: @@ -231,59 +157,17 @@ def test_minimum_maximum( compiler = fhe.Compiler(function, parameter_encryption_statuses) - inputset = [ - ( - np.random.randint(lhs_dtype.min(), lhs_dtype.max() + 1, size=lhs_shape), - np.random.randint(rhs_dtype.min(), rhs_dtype.max() + 1, size=rhs_shape), - ) - for _ in range(100) - ] + inputset = [np.random.randint(dtype.min(), dtype.max() + 1, size=shape) for _ in range(100)] circuit = compiler.compile(inputset, configuration) samples = [ - [ - np.zeros(lhs_shape, dtype=np.int64), - np.zeros(rhs_shape, dtype=np.int64), - ], - [ - np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.min(), - np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.min(), - ], - [ - np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.max(), - np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.min(), - ], - [ - np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.max(), - np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.max(), - ], - [ - np.random.randint(lhs_dtype.min(), lhs_dtype.max() + 1, size=lhs_shape), - np.random.randint(rhs_dtype.min(), rhs_dtype.max() + 1, size=rhs_shape), - ], + np.zeros(shape, dtype=np.int64), + np.ones(shape, dtype=np.int64) * dtype.min(), + np.ones(shape, dtype=np.int64) * dtype.max(), + np.random.randint(dtype.min(), dtype.max() + 1, size=shape), + np.random.randint(dtype.min(), dtype.max() + 1, size=shape), + np.random.randint(dtype.min(), dtype.max() + 1, size=shape), ] for sample in samples: helpers.check_execution(circuit, function, sample, retries=5) - - -def test_internal_signed_tlu_padding(helpers): - """Test that the signed input LUT is correctly padded in the case of substraction trick.""" - - inputset = [(i, j) for i in [0, 1] for j in [0, 1]] - - @fhe.compiler({"a": "encrypted", "b": "encrypted"}) - def min2(a, b): - min_12 = np.minimum(a, b) - return (min_12, a + 3, b + 3) - - c = min2.compile(inputset, helpers.configuration()) - min_0_1, _, _ = c.encrypt_run_decrypt(0, 1) - - assert min_0_1 == 0 - - # Some extra checks to verify that the test is relevant (substraction trick). - assert c.mlir.count("to_signed") == 2 # check substraction trick is used - assert c.mlir.count("sub_eint") == 1 # check substraction trick is used - assert c.mlir.count("<[0, 0, -2, -1, 0, 0, 0, 0]>") == 0 # lut wrongly padded at the end - assert c.mlir.count("<[0, 0, 0, 0, 0, 0, -2, -1]>") == 1 # lut correctly padded in the middle diff --git a/frontends/concrete-python/tests/execution/test_minimum_maximum.py b/frontends/concrete-python/tests/execution/test_minimum_maximum.py new file mode 100644 index 000000000..17fc0b3ba --- /dev/null +++ b/frontends/concrete-python/tests/execution/test_minimum_maximum.py @@ -0,0 +1,289 @@ +""" +Tests of execution of minimum and maximum operations. +""" + +import random + +import numpy as np +import pytest + +from concrete import fhe +from concrete.fhe.dtypes import Integer +from concrete.fhe.values import ValueDescription + +cases = [ + [ + # operation + ( + "minimum_optimized_x", + lambda x, y: np.minimum(fhe.hint(x, bit_width=5), y), # type: ignore + ), + # bit widths + 4, + 4, + # signednesses + False, + True, + # shapes + (), + (), + # strategy + fhe.MinMaxStrategy.CHUNKED, + ], + [ + # operation + ( + "minimum_optimized_x", + lambda x, y: np.minimum(fhe.hint(x, bit_width=5), y), # type: ignore + ), + # bit widths + 4, + 4, + # signednesses + True, + False, + # shapes + (2,), + (), + # strategy + fhe.MinMaxStrategy.CHUNKED, + ], + [ + # operation + ( + "maximum_optimized_y", + lambda x, y: np.maximum(x, fhe.hint(y, bit_width=4)), # type: ignore + ), + # bit widths + 4, + 3, + # signednesses + True, + False, + # shapes + (), + (2, 3), + # strategy + fhe.MinMaxStrategy.CHUNKED, + ], + [ + # operation + ( + "maximum_optimized_y", + lambda x, y: np.maximum(x, fhe.hint(y, bit_width=4)), # type: ignore + ), + # bit widths + 4, + 3, + # signednesses + False, + True, + # shapes + (), + (), + # strategy + fhe.MinMaxStrategy.CHUNKED, + ], +] +cases += [ + [ + # operation + operation, + # bit widths + 1, + 1, + # signednesses + lhs_is_signed, + rhs_is_signed, + # shapes + (), + (), + # strategy + fhe.MinMaxStrategy.CHUNKED, + ] + for lhs_is_signed in [False, True] + for rhs_is_signed in [False, True] + for operation in [ + ( + "maximum", + lambda x, y: np.maximum(x, y), + ), + ] +] +cases = [ + [ + # operation + ("maximum_increased_bit_widths", lambda x, y: (np.maximum(x, y), x + 100, y + 100)), + # bit widths + 7, + 7, + # signednesses + True, + False, + # shapes + (), + (), + # strategy + fhe.MinMaxStrategy.CHUNKED, + ], + [ + # operation + ("maximum_increased_bit_widths", lambda x, y: (np.maximum(x, y), x + 100, y + 100)), + # bit widths + 7, + 7, + # signednesses + False, + True, + # shapes + (), + (), + # strategy + fhe.MinMaxStrategy.CHUNKED, + ], +] +for lhs_bit_width in range(1, 5): + for rhs_bit_width in range(1, 5): + strategies = [] + if lhs_bit_width <= 3 and rhs_bit_width <= 3: + strategies += [ + fhe.MinMaxStrategy.ONE_TLU_PROMOTED, + fhe.MinMaxStrategy.THREE_TLU_CASTED, + ] + else: + strategies += [ + fhe.MinMaxStrategy.CHUNKED, + ] + + for lhs_is_signed in [False, True]: + for rhs_is_signed in [False, True]: + cases += [ + [ + # operation + operation, + # bit widths + lhs_bit_width, + rhs_bit_width, + # signednesses + lhs_is_signed, + rhs_is_signed, + # shapes + random.choice([(), (2,), (3, 2)]), + random.choice([(), (2,), (3, 2)]), + # strategy + strategy, + ] + for operation in [ + ("minimum", lambda x, y: np.minimum(x, y)), + ("maximum", lambda x, y: np.maximum(x, y)), + ] + for strategy in strategies + ] + +# pylint: disable=redefined-outer-name + + +@pytest.mark.parametrize( + "operation," + "lhs_bit_width,rhs_bit_width," + "lhs_is_signed,rhs_is_signed," + "lhs_shape,rhs_shape," + "strategy", + cases, +) +def test_minimum_maximum( + operation, + lhs_bit_width, + rhs_bit_width, + lhs_is_signed, + rhs_is_signed, + lhs_shape, + rhs_shape, + strategy, + helpers, +): + """ + Test comparison operations between encrypted integers. + """ + + name, function = operation + + lhs_dtype = Integer(is_signed=lhs_is_signed, bit_width=lhs_bit_width) + rhs_dtype = Integer(is_signed=rhs_is_signed, bit_width=rhs_bit_width) + + lhs_description = ValueDescription(lhs_dtype, shape=lhs_shape, is_encrypted=True) + rhs_description = ValueDescription(rhs_dtype, shape=rhs_shape, is_encrypted=True) + + print() + print() + print( + f"{name}({lhs_description}, {rhs_description})" + + (f" {{{strategy}}}" if strategy is not None else "") + ) + print() + print() + + parameter_encryption_statuses = {"x": "encrypted", "y": "encrypted"} + configuration = helpers.configuration() + + if strategy is not None: + configuration = configuration.fork(min_max_strategy_preference=[strategy]) + + compiler = fhe.Compiler(function, parameter_encryption_statuses) + + inputset = [ + ( + np.random.randint(lhs_dtype.min(), lhs_dtype.max() + 1, size=lhs_shape), + np.random.randint(rhs_dtype.min(), rhs_dtype.max() + 1, size=rhs_shape), + ) + for _ in range(100) + ] + + circuit = compiler.compile(inputset, configuration) + + samples = [ + [ + np.zeros(lhs_shape, dtype=np.int64), + np.zeros(rhs_shape, dtype=np.int64), + ], + [ + np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.min(), + np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.min(), + ], + [ + np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.max(), + np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.min(), + ], + [ + np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.max(), + np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.max(), + ], + [ + np.random.randint(lhs_dtype.min(), lhs_dtype.max() + 1, size=lhs_shape), + np.random.randint(rhs_dtype.min(), rhs_dtype.max() + 1, size=rhs_shape), + ], + ] + for sample in samples: + helpers.check_execution(circuit, function, sample, retries=5) + + +def test_internal_signed_tlu_padding(helpers): + """Test that the signed input LUT is correctly padded in the case of substraction trick.""" + + inputset = [(i, j) for i in [0, 1] for j in [0, 1]] + + @fhe.compiler({"a": "encrypted", "b": "encrypted"}) + def min2(a, b): + min_12 = np.minimum(a, b) + return (min_12, a + 3, b + 3) + + c = min2.compile(inputset, helpers.configuration()) + min_0_1, _, _ = c.encrypt_run_decrypt(0, 1) + + assert min_0_1 == 0 + + # Some extra checks to verify that the test is relevant (substraction trick). + assert c.mlir.count("to_signed") == 2 # check substraction trick is used + assert c.mlir.count("sub_eint") == 1 # check substraction trick is used + assert c.mlir.count("<[0, 0, -2, -1, 0, 0, 0, 0]>") == 0 # lut wrongly padded at the end + assert c.mlir.count("<[0, 0, 0, 0, 0, 0, -2, -1]>") == 1 # lut correctly padded in the middle