mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-10 05:18:00 -05:00
feat(frontend): add support for np.min and np.max
This commit is contained in:
@@ -113,9 +113,11 @@ jobs:
|
||||
|
||||
find .testenv/lib/python3.10/site-packages -not \( -path .testenv/lib/python3.10/site-packages/concrete -prune \) -name 'lib*omp5.dylib' -or -name 'lib*omp.dylib' | xargs -n 1 ln -f -s $(pwd)/.testenv/lib/python3.10/site-packages/concrete/.dylibs/libomp.dylib
|
||||
|
||||
cp -R $GITHUB_WORKSPACE/frontends/concrete-python/examples ./examples
|
||||
cp -R $GITHUB_WORKSPACE/frontends/concrete-python/tests ./tests
|
||||
|
||||
cp $GITHUB_WORKSPACE/frontends/concrete-python/Makefile .
|
||||
KEY_CACHE_DIRECTORY=./KeySetCache PYTEST_MARKERS="not dataflow and not graphviz" make pytest
|
||||
KEY_CACHE_DIRECTORY=./KeySetCache PYTEST_MARKERS="not dataflow and not graphviz" make pytest-macos
|
||||
|
||||
- name: Cleanup host
|
||||
if: success() || failure()
|
||||
|
||||
@@ -6,30 +6,6 @@ This document introduces several common techniques for optimizing code to fit Fu
|
||||
All code snippets provided here are temporary workarounds. In future versions of Concrete, some functions described here could be directly available in a more generic and efficient form. These code snippets are coming from support answers in our [community forum](https://community.zama.ai)
|
||||
{% endhint %}
|
||||
|
||||
## Minimum/Maximum for multiple values
|
||||
|
||||
Concrete supports `np.minimum`/`np.maximum` natively, but not `np.min`/`np.max` yet. To achieve the functionality, you can do a series of `np.minimum`/`np.maximum`s:
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from concrete import fhe
|
||||
|
||||
@fhe.compiler({"args": "encrypted"})
|
||||
def fhe_min(args):
|
||||
remaining = list(args)
|
||||
while len(remaining) > 1:
|
||||
a = remaining.pop()
|
||||
b = remaining.pop()
|
||||
remaining.insert(0, np.minimum(a, b))
|
||||
return remaining[0]
|
||||
|
||||
inputset = [np.random.randint(0, 16, size=5) for _ in range(50)]
|
||||
circuit = fhe_min.compile(inputset, min_max_strategy_preference=fhe.MinMaxStrategy.ONE_TLU_PROMOTED)
|
||||
|
||||
x1, x2, x3, x4, x5 = np.random.randint(0, 16, size=5)
|
||||
assert circuit.encrypt_run_decrypt([x1, x2, x3, x4, x5]) == min(x1, x2, x3, x4, x5)
|
||||
```
|
||||
|
||||
## Retrieving a value within an encrypted array with an encrypted index
|
||||
This example demonstrates how to retrieve a value from an array using an encrypted index. The method creates a "selection" array filled with `0`s except for the requested index, which will be `1`. It then sums the products of all array values with this selection array:
|
||||
|
||||
|
||||
@@ -110,7 +110,9 @@ Some operations are not supported between two encrypted values. If attempted, a
|
||||
* [np.logical\_or](https://numpy.org/doc/stable/reference/generated/numpy.logical\_or.html)
|
||||
* [np.logical\_xor](https://numpy.org/doc/stable/reference/generated/numpy.logical\_xor.html)
|
||||
* [np.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html)
|
||||
* [np.max](https://numpy.org/doc/stable/reference/generated/numpy.max.html)
|
||||
* [np.maximum](https://numpy.org/doc/stable/reference/generated/numpy.maximum.html)
|
||||
* [np.min](https://numpy.org/doc/stable/reference/generated/numpy.min.html)
|
||||
* [np.minimum](https://numpy.org/doc/stable/reference/generated/numpy.minimum.html)
|
||||
* [np.multiply](https://numpy.org/doc/stable/reference/generated/numpy.multiply.html)
|
||||
* [np.negative](https://numpy.org/doc/stable/reference/generated/numpy.negative.html)
|
||||
|
||||
@@ -16,7 +16,7 @@ ignore = [
|
||||
"**/__init__.py" = ["F401"]
|
||||
"concrete/fhe/compilation/configuration.py" = ["ARG002"]
|
||||
"concrete/fhe/mlir/processors/all.py" = ["F401"]
|
||||
"concrete/fhe/mlir/processors/assign_bit_widths.py" = ["ARG002"]
|
||||
"concrete/fhe/mlir/processors/assign_bit_widths.py" = ["ARG002", "RUF012"]
|
||||
"concrete/fhe/mlir/converter.py" = ["ARG002", "B011", "F403", "F405"]
|
||||
"concrete/**" = ["RUF010"]
|
||||
"examples/**" = ["PLR2004", "RUF010"]
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
PYTHON=python
|
||||
PIP=$(PYTHON) -m pip
|
||||
|
||||
COMPILER_BUILD_DIRECTORY ?= $(PWD)/../../compilers/concrete-compiler/compiler/build
|
||||
BINDINGS_DIRECTORY=${COMPILER_BUILD_DIRECTORY}/tools/concretelang/python_packages/concretelang_core/
|
||||
TFHERS_UTILS_DIRECTORY ?= $(PWD)/tests/tfhers-utils/
|
||||
|
||||
OS=undefined
|
||||
COVERAGE_OPT=""
|
||||
ifeq ($(shell uname), Linux)
|
||||
OS=linux
|
||||
COVERAGE_OPT="--cov=concrete.fhe --cov-fail-under=100 --cov-report=term-missing:skip-covered"
|
||||
RUNTIME_LIBRARY?=${COMPILER_BUILD_DIRECTORY}/lib/libConcretelangRuntime.so
|
||||
else ifeq ($(shell uname), Darwin)
|
||||
OS=darwin
|
||||
RUNTIME_LIBRARY?=${COMPILER_BUILD_DIRECTORY}/lib/libConcretelangRuntime.dylib
|
||||
endif
|
||||
|
||||
|
||||
COMPILER_BUILD_DIRECTORY ?= $(PWD)/../../compilers/concrete-compiler/compiler/build
|
||||
BINDINGS_DIRECTORY=${COMPILER_BUILD_DIRECTORY}/tools/concretelang/python_packages/concretelang_core/
|
||||
RUNTIME_LIBRARY?=${COMPILER_BUILD_DIRECTORY}/lib/libConcretelangRuntime.so
|
||||
TFHERS_UTILS_DIRECTORY ?= $(PWD)/tests/tfhers-utils/
|
||||
|
||||
CONCRETE_VERSION?="" # empty mean latest
|
||||
# E.g. to use a previous version: `make CONCRETE_VERSION="<2.7.0" venv`
|
||||
# E.g. to use a nightly: `make CONCRETE_VERSION="==2.7.0dev20240801`
|
||||
@@ -76,6 +76,11 @@ pytest-default: tfhers-utils
|
||||
--key-cache "${KEY_CACHE_DIRECTORY}" \
|
||||
-m "${PYTEST_MARKERS}"
|
||||
|
||||
pytest-macos:
|
||||
pytest tests -svv -n auto \
|
||||
--key-cache "${KEY_CACHE_DIRECTORY}" \
|
||||
-m "${PYTEST_MARKERS}"
|
||||
|
||||
pytest-single: tfhers-utils
|
||||
eval $(shell make silent_cp_activate)
|
||||
# test single precision, mono params
|
||||
|
||||
@@ -2197,7 +2197,11 @@ class Context:
|
||||
if cached_conversion is None:
|
||||
cached_conversion = Conversion(
|
||||
self.converting,
|
||||
arith.ConstantOp(resulting_type, attribute, loc=self.location()),
|
||||
arith.ConstantOp( # pylint: disable=too-many-function-args
|
||||
resulting_type,
|
||||
attribute,
|
||||
loc=self.location(),
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -2813,6 +2817,25 @@ class Context:
|
||||
|
||||
return self.to_signedness(result, of=resulting_type)
|
||||
|
||||
def min_max(
|
||||
self,
|
||||
resulting_type: ConversionType,
|
||||
x: Conversion,
|
||||
axes: Union[int, np.integer, Sequence[Union[int, np.integer]]] = (),
|
||||
keep_dims: bool = False,
|
||||
*,
|
||||
operation: str,
|
||||
):
|
||||
# This import needs to happen here to avoid circular imports.
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
|
||||
from .operations.min_max import min_max
|
||||
|
||||
return min_max(self, resulting_type, x, axes, keep_dims, operation=operation)
|
||||
|
||||
# pylint: enable=import-outside-toplevel
|
||||
|
||||
def minimum(
|
||||
self,
|
||||
resulting_type: ConversionType,
|
||||
@@ -3988,7 +4011,7 @@ class Context:
|
||||
|
||||
def get_partition_name(self, partition: tfhers.CryptoParams) -> str:
|
||||
if partition not in self.tfhers_partition.keys():
|
||||
self.tfhers_partition[partition] = f"tfhers_{randint(0, 2**32)}" # noqa: S311
|
||||
self.tfhers_partition[partition] = f"tfhers_{randint(0, 2 ** 32)}" # noqa: S311
|
||||
return self.tfhers_partition[partition]
|
||||
|
||||
def change_partition(
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Declaration of `ConversionType` and `Conversion` classes.
|
||||
"""
|
||||
|
||||
# pylint: disable=import-error,
|
||||
# pylint: disable=import-error,no-name-in-module
|
||||
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
@@ -12,7 +12,7 @@ from mlir.ir import Type as MlirType
|
||||
|
||||
from ..representation import Node
|
||||
|
||||
# pylint: enable=import-error
|
||||
# pylint: enable=import-error,no-name-in-module
|
||||
|
||||
|
||||
SCALAR_INT_SEARCH_REGEX = re.compile(r"^i([0-9]+)$")
|
||||
|
||||
@@ -326,6 +326,12 @@ class Converter:
|
||||
assert len(preds) == 2
|
||||
return ctx.add(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
def amax(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
return self.max(ctx, node, preds)
|
||||
|
||||
def amin(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
return self.min(ctx, node, preds)
|
||||
|
||||
def array(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) > 0
|
||||
return ctx.array(ctx.typeof(node), elements=preds)
|
||||
@@ -530,6 +536,20 @@ class Converter:
|
||||
assert len(preds) == 2
|
||||
return ctx.matmul(ctx.typeof(node), preds[0], preds[1])
|
||||
|
||||
def max(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 1
|
||||
|
||||
if all(pred.is_encrypted for pred in preds):
|
||||
return ctx.min_max(
|
||||
ctx.typeof(node),
|
||||
preds[0],
|
||||
axes=node.properties["kwargs"].get("axis", ()),
|
||||
keep_dims=node.properties["kwargs"].get("keepdims", False),
|
||||
operation="max",
|
||||
)
|
||||
|
||||
return self.tlu(ctx, node, preds)
|
||||
|
||||
def maximum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
|
||||
@@ -556,6 +576,20 @@ class Converter:
|
||||
ctx.error({node: "3-dimensional maxpooling is not supported at the moment"})
|
||||
assert False, "unreachable" # pragma: no cover
|
||||
|
||||
def min(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 1
|
||||
|
||||
if all(pred.is_encrypted for pred in preds):
|
||||
return ctx.min_max(
|
||||
ctx.typeof(node),
|
||||
preds[0],
|
||||
axes=node.properties["kwargs"].get("axis", ()),
|
||||
keep_dims=node.properties["kwargs"].get("keepdims", False),
|
||||
operation="min",
|
||||
)
|
||||
|
||||
return self.tlu(ctx, node, preds)
|
||||
|
||||
def minimum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 2
|
||||
|
||||
|
||||
@@ -490,6 +490,7 @@ def fancy_indexing(
|
||||
resulting_type,
|
||||
x.result,
|
||||
indices.result,
|
||||
original_bit_width=x.original_bit_width,
|
||||
)
|
||||
|
||||
|
||||
@@ -640,6 +641,7 @@ def indexing(
|
||||
MlirDenseI64ArrayAttr.get(static_offsets),
|
||||
MlirDenseI64ArrayAttr.get(static_sizes),
|
||||
MlirDenseI64ArrayAttr.get(static_strides),
|
||||
original_bit_width=x.original_bit_width,
|
||||
)
|
||||
|
||||
reassociaton = []
|
||||
@@ -669,4 +671,5 @@ def indexing(
|
||||
for indices in reassociaton
|
||||
],
|
||||
),
|
||||
original_bit_width=x.original_bit_width,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,428 @@
|
||||
"""
|
||||
Conversion of min and max operations.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Sequence, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..context import Context
|
||||
from ..conversion import Conversion, ConversionType
|
||||
|
||||
|
||||
def min_max(
|
||||
ctx: Context,
|
||||
resulting_type: ConversionType,
|
||||
x: Conversion,
|
||||
axes: Union[int, np.integer, Sequence[Union[int, np.integer]]] = (),
|
||||
keep_dims: bool = False,
|
||||
*,
|
||||
operation: str,
|
||||
) -> Conversion:
|
||||
"""
|
||||
Convert min or max operation.
|
||||
|
||||
Args:
|
||||
ctx (Context):
|
||||
conversion context
|
||||
|
||||
resulting_type (ConversionType):
|
||||
resulting type of the operation
|
||||
|
||||
x (Conversion):
|
||||
input of the operation
|
||||
|
||||
axes (Union[int, np.integer, Sequence[Union[int, np.integer]]], default = ()):
|
||||
axes to reduce over
|
||||
|
||||
keep_dims (bool, default = False):
|
||||
whether to keep the reduced axes
|
||||
|
||||
operation (str):
|
||||
"min" or "max"
|
||||
|
||||
Returns:
|
||||
Conversion:
|
||||
np.min or np.max on x depending on operation
|
||||
"""
|
||||
|
||||
# if the input is clear
|
||||
if x.is_clear:
|
||||
# raise error as computing min/max of clear values is not supported
|
||||
highlights = {
|
||||
x.origin: "value is clear",
|
||||
ctx.converting: f"but computing {operation} of clear values is not supported",
|
||||
}
|
||||
ctx.error(highlights)
|
||||
|
||||
# if the value is scalar
|
||||
if x.is_scalar:
|
||||
# return it as it's the min/max
|
||||
return x
|
||||
|
||||
# if axes is not specified, use all axes
|
||||
if axes is None:
|
||||
axes = []
|
||||
|
||||
# compute the list of unique axes to reduce
|
||||
# empty list means reduce all axes
|
||||
axes = list(
|
||||
set(
|
||||
# if axes was a single integer, only it will be reduced
|
||||
[int(axes)]
|
||||
if isinstance(axes, (int, np.integer))
|
||||
# if axes was a sequence, every axis in it will be reduced
|
||||
else [int(axis) for axis in axes]
|
||||
)
|
||||
)
|
||||
|
||||
# sanitize negative axis
|
||||
# `axis=-1` is the same as `axis=(-1 + x.ndim))`
|
||||
input_dimensions = len(x.shape)
|
||||
for i, axis in enumerate(axes):
|
||||
if axis < 0:
|
||||
axes[i] += input_dimensions
|
||||
assert 0 <= axes[i] < input_dimensions
|
||||
|
||||
# if all axes are reduced
|
||||
if len(axes) == 0 or len(axes) == len(x.shape):
|
||||
# we flatten the input to use the `reduce` implementation
|
||||
x = ctx.flatten(x)
|
||||
|
||||
# if the input is a vector
|
||||
if len(x.shape) == 1:
|
||||
# we reduce the vector to its min/max value
|
||||
result = reduce(ctx, ctx.element_typeof(resulting_type), x, operation=operation)
|
||||
|
||||
# if the user wants to keep the reduced dimensions
|
||||
if keep_dims:
|
||||
# reshape the result into the resulting shape
|
||||
result = ctx.reshape(result, shape=resulting_type.shape)
|
||||
|
||||
# return the result
|
||||
return result
|
||||
|
||||
# if the reduce implementation is not used
|
||||
# mock implementation will be used instead
|
||||
#
|
||||
# the idea is to let numpy compute the indices that will be compared to obtain the result
|
||||
#
|
||||
# for example, if input.shape == (2, 3) and axis == 1
|
||||
# we'll be computing
|
||||
# [
|
||||
# [
|
||||
# {(0, 2, 0), (0, 0, 0), (0, 3, 0), (0, 1, 0)}
|
||||
# {(0, 2, 1), (0, 0, 1), (0, 3, 1), (0, 1, 1)}
|
||||
# {(0, 1, 2), (0, 2, 2), (0, 3, 2), (0, 0, 2)}
|
||||
# ]
|
||||
# [
|
||||
# {(1, 0, 0), (1, 3, 0), (1, 1, 0), (1, 2, 0)}
|
||||
# {(1, 2, 1), (1, 0, 1), (1, 1, 1), (1, 3, 1)}
|
||||
# {(1, 3, 2), (1, 2, 2), (1, 0, 2), (1, 1, 2)}
|
||||
# ]
|
||||
# ]
|
||||
#
|
||||
# this means to compute output[0, 0] we need to compare
|
||||
# - input[0, 2, 0]
|
||||
# - input[0, 0, 0]
|
||||
# - input[0, 3, 0]
|
||||
# - input[0, 1, 0]
|
||||
#
|
||||
# or to compute output[1, 2] we need to compare
|
||||
# - input[1, 3, 2]
|
||||
# - input[1, 2, 2]
|
||||
# - input[1, 0, 2]
|
||||
# - input[1, 1, 2]
|
||||
#
|
||||
# notice that number of comparisons (say n) are always the same in each output cell
|
||||
# this opens up the possibility to use fancy indexing to extract n slices from the input
|
||||
# and compare the slices in a tree-like fashion to obtain the result
|
||||
#
|
||||
# in the end, we'll be performing
|
||||
#
|
||||
# slice1 = [
|
||||
# [input[0, 2, 0], input[0, 2, 1], input[0, 1, 2]],
|
||||
# [input[1, 0, 0], input[1, 2, 1], input[1, 3, 2]],
|
||||
# ]
|
||||
# slice2 = [
|
||||
# [input[0, 0, 0], input[0, 0, 1], input[0, 2, 2]],
|
||||
# [input[1, 3, 0], input[1, 0, 1], input[1, 2, 2]],
|
||||
# ]
|
||||
# slice3 = [
|
||||
# [input[0, 3, 0], input[0, 3, 1], input[0, 3, 2]],
|
||||
# [input[1, 1, 0], input[1, 1, 1], input[1, 0, 2]],
|
||||
# ]
|
||||
# slice4 = [
|
||||
# [input[0, 1, 0], input[0, 1, 1], input[0, 0, 2]],
|
||||
# [input[1, 2, 0], input[1, 3, 1], input[1, 1, 2]],
|
||||
# ]
|
||||
#
|
||||
# minimum_slice1_slice2 = np.minimum(slice1, slice2)
|
||||
# minimum_slice3_slice4 = np.minimum(slice3, slice4)
|
||||
#
|
||||
# result = np.minimum(minimum_slice1_slice2, minimum_slice3_slice4)
|
||||
#
|
||||
# notice that all slices have the same shape as the result
|
||||
|
||||
class Mock:
|
||||
"""
|
||||
Class to track accumulation of the operation.
|
||||
"""
|
||||
|
||||
# list of indices that have been accumulated
|
||||
indices: Set[Tuple[int, ...]]
|
||||
|
||||
# initialize the mock with a starting index
|
||||
def __init__(self, index: Tuple[int, ...]):
|
||||
self.indices = {index}
|
||||
|
||||
# get the representation of the mock
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.indices}"
|
||||
|
||||
# combine the indices of the mock with another mock into a new mock
|
||||
def combine(self, other: "Mock") -> "Mock":
|
||||
result = deepcopy(self)
|
||||
for index in other.indices:
|
||||
result.indices.add(index)
|
||||
return result
|
||||
|
||||
# create the mock input
|
||||
#
|
||||
# [[[{(0, 0, 0)} {(0, 0, 1)} {(0, 0, 2)}]
|
||||
# [{(0, 1, 0)} {(0, 1, 1)} {(0, 1, 2)}]
|
||||
# [{(0, 2, 0)} {(0, 2, 1)} {(0, 2, 2)}]
|
||||
# [{(0, 3, 0)} {(0, 3, 1)} {(0, 3, 2)}]]
|
||||
#
|
||||
# [[{(1, 0, 0)} {(1, 0, 1)} {(1, 0, 2)}]
|
||||
# [{(1, 1, 0)} {(1, 1, 1)} {(1, 1, 2)}]
|
||||
# [{(1, 2, 0)} {(1, 2, 1)} {(1, 2, 2)}]
|
||||
# [{(1, 3, 0)} {(1, 3, 1)} {(1, 3, 2)}]]]
|
||||
|
||||
mock_input = []
|
||||
for index in np.ndindex(x.shape):
|
||||
mock_input.append(Mock(index))
|
||||
mock_input = np.array(mock_input).reshape(x.shape)
|
||||
|
||||
# use numpy reduction to compute the mock output
|
||||
#
|
||||
# [[{(0, 2, 0), (0, 0, 0), (0, 3, 0), (0, 1, 0)}
|
||||
# {(0, 2, 1), (0, 0, 1), (0, 3, 1), (0, 1, 1)}
|
||||
# {(0, 1, 2), (0, 2, 2), (0, 3, 2), (0, 0, 2)}]
|
||||
# [{(1, 0, 0), (1, 3, 0), (1, 1, 0), (1, 2, 0)}
|
||||
# {(1, 2, 1), (1, 0, 1), (1, 1, 1), (1, 3, 1)}
|
||||
# {(1, 3, 2), (1, 2, 2), (1, 0, 2), (1, 1, 2)}]]
|
||||
|
||||
mock_output = np.frompyfunc(lambda mock1, mock2: mock1.combine(mock2), 2, 1).reduce(
|
||||
mock_input,
|
||||
axis=tuple(axes),
|
||||
keepdims=keep_dims,
|
||||
)
|
||||
|
||||
# extract a sample mock from the mock output
|
||||
sample_mock = mock_output.flat[0]
|
||||
|
||||
# extract the indices of the sample mock
|
||||
sample_mock_indices = sample_mock.indices
|
||||
|
||||
# compute number of comparisons
|
||||
number_of_comparisons = len(sample_mock_indices)
|
||||
|
||||
# extract a sample index from the sample mock indices
|
||||
sample_mock_index = next(iter(sample_mock_indices))
|
||||
|
||||
# compute the number of indices
|
||||
number_of_indices = len(sample_mock_index)
|
||||
|
||||
# compute the shape of fancy indexing indices
|
||||
index_shape = resulting_type.shape
|
||||
|
||||
# compute the fancy indices to extract for comparison
|
||||
to_compare = []
|
||||
for _ in range(number_of_comparisons):
|
||||
indices = []
|
||||
for _ in range(number_of_indices):
|
||||
index = np.zeros(index_shape, dtype=np.int64) # type: ignore
|
||||
indices.append(index)
|
||||
to_compare.append(tuple(indices))
|
||||
for position in np.ndindex(mock_output.shape):
|
||||
mock_indices = list(mock_output[position].indices)
|
||||
for i in range(number_of_comparisons):
|
||||
for j in range(number_of_indices):
|
||||
to_compare[i][j][position] = mock_indices[i][j] # type: ignore
|
||||
|
||||
# to_compare will look like
|
||||
# [
|
||||
# # for the first slice
|
||||
# (
|
||||
# [[0, 0, 0], [1, 1, 1]], # i
|
||||
# [[2, 2, 1], [0, 2, 3]], # j
|
||||
# [[0, 1, 2], [0, 1, 2]], # k
|
||||
# ),
|
||||
# # for the second slice
|
||||
# (
|
||||
# [[0, 0, 0], [1, 1, 1]], # i
|
||||
# [[0, 0, 2], [3, 0, 2]], # j
|
||||
# [[0, 1, 2], [0, 1, 2]], # k
|
||||
# ),
|
||||
# # for the third slice
|
||||
# (
|
||||
# [[0, 0, 0], [1, 1, 1]], # i
|
||||
# [[3, 3, 3], [1, 1, 0]], # j
|
||||
# [[0, 1, 2], [0, 1, 2]], # k
|
||||
# ),
|
||||
# # for the fourth slice
|
||||
# (
|
||||
# [[0, 0, 0], [1, 1, 1]], # i
|
||||
# [[1, 1, 0], [2, 3, 1]], # j
|
||||
# [[0, 1, 2], [0, 1, 2]], # k
|
||||
# ),
|
||||
# ]
|
||||
|
||||
# find the type of the slices
|
||||
slices_type = ctx.tensor(ctx.element_typeof(x), shape=resulting_type.shape)
|
||||
|
||||
# extract the slices
|
||||
slices = []
|
||||
for index in to_compare:
|
||||
slices.append(ctx.index(slices_type, x, index)) # type: ignore
|
||||
|
||||
# while there are more than 1 slices
|
||||
while len(slices) > 1:
|
||||
# pop the last two slices
|
||||
a = slices.pop()
|
||||
b = slices.pop()
|
||||
|
||||
# compare the last two slices
|
||||
if operation == "min":
|
||||
c = ctx.minimum(resulting_type, a, b)
|
||||
else:
|
||||
c = ctx.maximum(resulting_type, a, b)
|
||||
|
||||
# we need to set the original bit width manually
|
||||
# as minimum/maximum doesn't constraint their output bit width.
|
||||
c.set_original_bit_width(x.original_bit_width)
|
||||
|
||||
# insert the slice back at the beginning of the slice queue
|
||||
slices.insert(0, c)
|
||||
|
||||
# return the result
|
||||
return slices[0]
|
||||
|
||||
|
||||
def reduce(
|
||||
ctx: Context,
|
||||
resulting_type: ConversionType,
|
||||
values: Conversion,
|
||||
*,
|
||||
operation: str,
|
||||
) -> Conversion:
|
||||
"""
|
||||
Reduce a vector of values to its min/max value.
|
||||
"""
|
||||
|
||||
# make sure the operation is valid
|
||||
assert operation in {"min", "max"}
|
||||
|
||||
# make sure the value is valid
|
||||
assert values.is_tensor
|
||||
assert len(values.shape) == 1
|
||||
assert values.is_encrypted
|
||||
|
||||
# make sure the resulting type is valid
|
||||
assert resulting_type.is_scalar
|
||||
assert resulting_type.is_encrypted
|
||||
|
||||
# let's say the vector was [1, 4, 2, 3, 0]
|
||||
# and we're computing np.min(vector)
|
||||
|
||||
# find the element type of the vector = fhe.uint3
|
||||
values_element_type = ctx.element_typeof(values)
|
||||
|
||||
# find the middle of the vector = 2
|
||||
middle = values.shape[0] // 2
|
||||
|
||||
# we'll be splitting the array into two halves
|
||||
# [1, 4] and [2, 3] in our case
|
||||
# then we'll compute np.minimum(first_half, second_half)
|
||||
# [1, 3] in our case
|
||||
# if the reduction is a scalar, we'll return it
|
||||
# otherwise, we'll reduce recursively until we obtain a scalar
|
||||
# 1 in our case
|
||||
|
||||
# find the half type of the vector which is
|
||||
# fhe.tensor[fhe.uint3, 2] in the first iteration
|
||||
# fhe.uint3 in the last iteration
|
||||
half_type = (
|
||||
ctx.tensor(values_element_type, shape=(middle,)) if middle != 1 else values_element_type
|
||||
)
|
||||
|
||||
# find the accumulated type of the vector which is
|
||||
# fhe.tensor[fhe.uint3, 2] in the first iteration
|
||||
# fhe.uint3 in the last iteration
|
||||
accumulated_type = (
|
||||
ctx.tensor(resulting_type, shape=(middle,)) if middle != 1 else resulting_type
|
||||
)
|
||||
|
||||
# if there is only one element in each half (e.g., vector = [1, 3], halfs = [1], [3])
|
||||
if middle == 1:
|
||||
# extract the first element in the vector
|
||||
first_half = ctx.index(half_type, values, index=[0])
|
||||
# extract the second element in the vector
|
||||
second_half = ctx.index(half_type, values, index=[1])
|
||||
else:
|
||||
# extract the elements from 0 to middle as the first half
|
||||
first_half = ctx.index(half_type, values, index=[slice(0, middle)])
|
||||
# extract the elements from middle to 2*middle as the second half
|
||||
second_half = ctx.index(half_type, values, index=[slice(middle, 2 * middle)])
|
||||
|
||||
# compare halfs
|
||||
if operation == "min":
|
||||
# [1, 3] in the first iteration
|
||||
# 1 in the last iteration
|
||||
reduced = ctx.minimum(accumulated_type, first_half, second_half)
|
||||
else:
|
||||
reduced = ctx.maximum(accumulated_type, first_half, second_half)
|
||||
|
||||
# set the original bit width of the reduced so the following operation work as intended
|
||||
# this is required here since ctx.minimum and ctx.maximum does not constraint output bit width
|
||||
reduced.set_original_bit_width(values.original_bit_width)
|
||||
|
||||
result = (
|
||||
# if reduced value is a scalar, we end the recursion
|
||||
reduced
|
||||
if reduced.is_scalar
|
||||
# otherwise, we reduce the result of comparison of halfs
|
||||
else reduce(ctx, resulting_type, reduced, operation=operation)
|
||||
)
|
||||
|
||||
# if we have one more element that wasn't in the halfs
|
||||
if values.shape[0] % 2 == 1:
|
||||
# we extract it
|
||||
last_value = ctx.index(values_element_type, values, index=[-1])
|
||||
# and compare it with the result we obtained from the halfs
|
||||
result = (
|
||||
ctx.minimum(resulting_type, result, last_value)
|
||||
if operation == "min"
|
||||
else ctx.maximum(resulting_type, result, last_value)
|
||||
)
|
||||
# again, we need to set the original bit width
|
||||
result.set_original_bit_width(values.original_bit_width)
|
||||
|
||||
# here is the visualization of the algorithm
|
||||
#
|
||||
# [ 1, 4, 2, 3, 0 ]
|
||||
#
|
||||
# [1, 4][2, 3],0
|
||||
# \ / |
|
||||
# [1, 3] |
|
||||
# \ /
|
||||
# 1 /
|
||||
# \ /
|
||||
# 0
|
||||
#
|
||||
# it has O(log(n)) - 1 number of tensor comparisons (of sizes n/2, n/4, ...)
|
||||
# and up to 1 + O(log(n)) number of scalar comparisons (depending on the oddity of n/2, n/4, ..)
|
||||
|
||||
return result
|
||||
@@ -351,7 +351,7 @@ class AdditionalConstraints:
|
||||
node.properties["strategy"] = strategy
|
||||
break
|
||||
|
||||
def min_max(self, node: Node, preds: List[Node]):
|
||||
def minimum_maximum(self, node: Node, preds: List[Node]):
|
||||
assert len(preds) == 2
|
||||
|
||||
x = preds[0]
|
||||
@@ -392,69 +392,83 @@ class AdditionalConstraints:
|
||||
node.properties["strategy"] = strategy
|
||||
break
|
||||
|
||||
def min_max(self, node: Node, preds: List[Node]):
|
||||
assert len(preds) == 1
|
||||
self.minimum_maximum(node, [preds[0], preds[0]])
|
||||
|
||||
# ==========
|
||||
# Operations
|
||||
# ==========
|
||||
|
||||
add = { # noqa: RUF012
|
||||
add = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
array = { # noqa: RUF012
|
||||
amax = {
|
||||
min_max,
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
assign_dynamic = { # noqa: RUF012
|
||||
amin = {
|
||||
min_max,
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
assign_static = { # noqa: RUF012
|
||||
array = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
bitwise_and = { # noqa: RUF012
|
||||
assign_dynamic = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
assign_static = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
bitwise_and = {
|
||||
all_inputs_are_encrypted: {
|
||||
bitwise,
|
||||
},
|
||||
}
|
||||
|
||||
bitwise_or = { # noqa: RUF012
|
||||
bitwise_or = {
|
||||
all_inputs_are_encrypted: {
|
||||
bitwise,
|
||||
},
|
||||
}
|
||||
|
||||
bitwise_xor = { # noqa: RUF012
|
||||
bitwise_xor = {
|
||||
all_inputs_are_encrypted: {
|
||||
bitwise,
|
||||
},
|
||||
}
|
||||
|
||||
broadcast_to = { # noqa: RUF012
|
||||
broadcast_to = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
concatenate = { # noqa: RUF012
|
||||
concatenate = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
conv1d = { # noqa: RUF012
|
||||
conv1d = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
conv2d = { # noqa: RUF012
|
||||
conv2d = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
conv3d = { # noqa: RUF012
|
||||
conv3d = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
copy = { # noqa: RUF012
|
||||
copy = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
dot = { # noqa: RUF012
|
||||
dot = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_share_precision,
|
||||
inputs_require_one_more_bit,
|
||||
@@ -464,51 +478,51 @@ class AdditionalConstraints:
|
||||
},
|
||||
}
|
||||
|
||||
equal = { # noqa: RUF012
|
||||
equal = {
|
||||
all_inputs_are_encrypted: {
|
||||
comparison,
|
||||
},
|
||||
}
|
||||
|
||||
expand_dims = { # noqa: RUF012
|
||||
expand_dims = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
greater = { # noqa: RUF012
|
||||
greater = {
|
||||
all_inputs_are_encrypted: {
|
||||
comparison,
|
||||
},
|
||||
}
|
||||
|
||||
greater_equal = { # noqa: RUF012
|
||||
greater_equal = {
|
||||
all_inputs_are_encrypted: {
|
||||
comparison,
|
||||
},
|
||||
}
|
||||
|
||||
index_static = { # noqa: RUF012
|
||||
index_static = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
left_shift = { # noqa: RUF012
|
||||
left_shift = {
|
||||
all_inputs_are_encrypted: {
|
||||
bitwise,
|
||||
},
|
||||
}
|
||||
|
||||
less = { # noqa: RUF012
|
||||
less = {
|
||||
all_inputs_are_encrypted: {
|
||||
comparison,
|
||||
},
|
||||
}
|
||||
|
||||
less_equal = { # noqa: RUF012
|
||||
less_equal = {
|
||||
all_inputs_are_encrypted: {
|
||||
comparison,
|
||||
},
|
||||
}
|
||||
|
||||
matmul = { # noqa: RUF012
|
||||
matmul = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_share_precision,
|
||||
inputs_require_one_more_bit,
|
||||
@@ -518,34 +532,44 @@ class AdditionalConstraints:
|
||||
},
|
||||
}
|
||||
|
||||
maximum = { # noqa: RUF012
|
||||
max = {
|
||||
min_max,
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
maximum = {
|
||||
all_inputs_are_encrypted: {
|
||||
min_max,
|
||||
minimum_maximum,
|
||||
},
|
||||
}
|
||||
|
||||
maxpool1d = { # noqa: RUF012
|
||||
maxpool1d = {
|
||||
inputs_and_output_share_precision,
|
||||
inputs_require_one_more_bit,
|
||||
}
|
||||
|
||||
maxpool2d = { # noqa: RUF012
|
||||
maxpool2d = {
|
||||
inputs_and_output_share_precision,
|
||||
inputs_require_one_more_bit,
|
||||
}
|
||||
|
||||
maxpool3d = { # noqa: RUF012
|
||||
maxpool3d = {
|
||||
inputs_and_output_share_precision,
|
||||
inputs_require_one_more_bit,
|
||||
}
|
||||
|
||||
minimum = { # noqa: RUF012
|
||||
min = {
|
||||
min_max,
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
minimum = {
|
||||
all_inputs_are_encrypted: {
|
||||
min_max,
|
||||
minimum_maximum,
|
||||
},
|
||||
}
|
||||
|
||||
multiply = { # noqa: RUF012
|
||||
multiply = {
|
||||
all_inputs_are_encrypted: {
|
||||
inputs_share_precision,
|
||||
inputs_require_one_more_bit,
|
||||
@@ -555,48 +579,48 @@ class AdditionalConstraints:
|
||||
},
|
||||
}
|
||||
|
||||
negative = { # noqa: RUF012
|
||||
negative = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
not_equal = { # noqa: RUF012
|
||||
not_equal = {
|
||||
all_inputs_are_encrypted: {
|
||||
comparison,
|
||||
},
|
||||
}
|
||||
|
||||
reshape = { # noqa: RUF012
|
||||
reshape = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
right_shift = { # noqa: RUF012
|
||||
right_shift = {
|
||||
all_inputs_are_encrypted: {
|
||||
bitwise,
|
||||
},
|
||||
}
|
||||
|
||||
round_bit_pattern = { # noqa: RUF012
|
||||
round_bit_pattern = {
|
||||
has_overflow_protection: {
|
||||
inputs_and_output_share_precision,
|
||||
},
|
||||
}
|
||||
|
||||
subtract = { # noqa: RUF012
|
||||
subtract = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
sum = { # noqa: RUF012
|
||||
sum = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
squeeze = { # noqa: RUF012
|
||||
squeeze = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
transpose = { # noqa: RUF012
|
||||
transpose = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
truncate_bit_pattern = { # noqa: RUF012
|
||||
truncate_bit_pattern = {
|
||||
inputs_and_output_share_precision,
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Declaration of various functions and constants related to MLIR conversion.
|
||||
"""
|
||||
|
||||
# pylint: disable=import-error
|
||||
# pylint: disable=import-error,no-name-in-module
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from copy import deepcopy
|
||||
@@ -26,7 +26,7 @@ from ..dtypes import Integer
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Node, Operation
|
||||
|
||||
# pylint: enable=import-error
|
||||
# pylint: enable=import-error,no-name-in-module
|
||||
|
||||
|
||||
class HashableNdarray:
|
||||
|
||||
@@ -705,18 +705,27 @@ class Graph:
|
||||
bounds of each node in the `Graph`
|
||||
"""
|
||||
|
||||
for node in self.graph.nodes():
|
||||
for node in self.query_nodes(ordered=True):
|
||||
if node in bounds:
|
||||
min_bound = bounds[node]["min"]
|
||||
max_bound = bounds[node]["max"]
|
||||
|
||||
node.bounds = (min_bound, max_bound)
|
||||
node.bounds = (min_bound, max_bound) # type: ignore
|
||||
|
||||
new_value = deepcopy(node.output)
|
||||
|
||||
if isinstance(min_bound, (np.integer, int)):
|
||||
assert isinstance(new_value.dtype, Integer)
|
||||
new_value.dtype.update_to_represent(np.array([min_bound, max_bound]))
|
||||
|
||||
if node.operation == Operation.Generic and node.properties["name"] in {
|
||||
"amin",
|
||||
"amax",
|
||||
"min",
|
||||
"max",
|
||||
}:
|
||||
assert isinstance(node.inputs[0].dtype, Integer)
|
||||
new_value.dtype.is_signed = node.inputs[0].dtype.is_signed
|
||||
else:
|
||||
new_value.dtype = {
|
||||
np.bool_: UnsignedInteger(1),
|
||||
|
||||
@@ -281,7 +281,9 @@ class Tracer:
|
||||
np.logical_or,
|
||||
np.logical_xor,
|
||||
np.matmul,
|
||||
np.max,
|
||||
np.maximum,
|
||||
np.min,
|
||||
np.minimum,
|
||||
np.mod,
|
||||
np.multiply,
|
||||
@@ -331,6 +333,14 @@ class Tracer:
|
||||
np.expand_dims: {
|
||||
"axis",
|
||||
},
|
||||
np.max: {
|
||||
"axis",
|
||||
"keepdims",
|
||||
},
|
||||
np.min: {
|
||||
"axis",
|
||||
"keepdims",
|
||||
},
|
||||
np.ones_like: {
|
||||
"dtype",
|
||||
},
|
||||
@@ -471,6 +481,12 @@ class Tracer:
|
||||
sanitized_args = [self.sanitize(args[0])]
|
||||
if len(args) > 1:
|
||||
kwargs["shape"] = args[1]
|
||||
elif func in {np.min, np.max}:
|
||||
sanitized_args = [self.sanitize(args[0])]
|
||||
for i, keyword in enumerate(["axis", "out", "keepdims", "initial", "where"]):
|
||||
position = i + 1
|
||||
if len(args) > position:
|
||||
kwargs[keyword] = args[position]
|
||||
elif func is np.reshape:
|
||||
sanitized_args = [self.sanitize(args[0])]
|
||||
if len(args) > 1:
|
||||
|
||||
@@ -19,7 +19,7 @@ def ast_iterator(root):
|
||||
while nodes:
|
||||
current_node = nodes.pop(0)
|
||||
yield current_node
|
||||
if hasattr(current_node, "children"):
|
||||
if hasattr(current_node, "children") and current_node.children is not None:
|
||||
nodes += current_node.children
|
||||
|
||||
|
||||
|
||||
@@ -318,6 +318,10 @@ class Helpers:
|
||||
if i == retries - 1:
|
||||
message = f"""
|
||||
|
||||
Sample
|
||||
===============
|
||||
{sample}
|
||||
|
||||
Expected Output
|
||||
===============
|
||||
{expected}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Tests of execution of minimum and maximum operations.
|
||||
Tests of execution of min and max operations.
|
||||
"""
|
||||
|
||||
import random
|
||||
@@ -11,219 +11,145 @@ from concrete import fhe
|
||||
from concrete.fhe.dtypes import Integer
|
||||
from concrete.fhe.values import ValueDescription
|
||||
|
||||
cases = [
|
||||
[
|
||||
# operation
|
||||
(
|
||||
"minimum_optimized_x",
|
||||
lambda x, y: np.minimum(fhe.hint(x, bit_width=5), y), # type: ignore
|
||||
),
|
||||
# bit widths
|
||||
4,
|
||||
4,
|
||||
# signednesses
|
||||
False,
|
||||
True,
|
||||
# shapes
|
||||
(),
|
||||
(),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
],
|
||||
[
|
||||
# operation
|
||||
(
|
||||
"minimum_optimized_x",
|
||||
lambda x, y: np.minimum(fhe.hint(x, bit_width=5), y), # type: ignore
|
||||
),
|
||||
# bit widths
|
||||
4,
|
||||
4,
|
||||
# signednesses
|
||||
True,
|
||||
False,
|
||||
# shapes
|
||||
(2,),
|
||||
(),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
],
|
||||
[
|
||||
# operation
|
||||
(
|
||||
"maximum_optimized_y",
|
||||
lambda x, y: np.maximum(x, fhe.hint(y, bit_width=4)), # type: ignore
|
||||
),
|
||||
# bit widths
|
||||
4,
|
||||
3,
|
||||
# signednesses
|
||||
True,
|
||||
False,
|
||||
# shapes
|
||||
(),
|
||||
(2, 3),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
],
|
||||
[
|
||||
# operation
|
||||
(
|
||||
"maximum_optimized_y",
|
||||
lambda x, y: np.maximum(x, fhe.hint(y, bit_width=4)), # type: ignore
|
||||
),
|
||||
# bit widths
|
||||
4,
|
||||
3,
|
||||
# signednesses
|
||||
False,
|
||||
True,
|
||||
# shapes
|
||||
(),
|
||||
(),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
],
|
||||
]
|
||||
cases += [
|
||||
[
|
||||
# operation
|
||||
operation,
|
||||
# bit widths
|
||||
1,
|
||||
1,
|
||||
# signednesses
|
||||
lhs_is_signed,
|
||||
rhs_is_signed,
|
||||
# shapes
|
||||
(),
|
||||
(),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
]
|
||||
for lhs_is_signed in [False, True]
|
||||
for rhs_is_signed in [False, True]
|
||||
for operation in [
|
||||
(
|
||||
"maximum",
|
||||
lambda x, y: np.maximum(x, y),
|
||||
),
|
||||
]
|
||||
]
|
||||
cases = [
|
||||
[
|
||||
# operation
|
||||
("maximum_increased_bit_widths", lambda x, y: (np.maximum(x, y), x + 100, y + 100)),
|
||||
# bit widths
|
||||
7,
|
||||
7,
|
||||
# signednesses
|
||||
True,
|
||||
False,
|
||||
# shapes
|
||||
(),
|
||||
(),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
],
|
||||
[
|
||||
# operation
|
||||
("maximum_increased_bit_widths", lambda x, y: (np.maximum(x, y), x + 100, y + 100)),
|
||||
# bit widths
|
||||
7,
|
||||
7,
|
||||
# signednesses
|
||||
False,
|
||||
True,
|
||||
# shapes
|
||||
(),
|
||||
(),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
],
|
||||
]
|
||||
for lhs_bit_width in range(1, 5):
|
||||
for rhs_bit_width in range(1, 5):
|
||||
strategies = []
|
||||
if lhs_bit_width <= 3 and rhs_bit_width <= 3:
|
||||
strategies += [
|
||||
fhe.MinMaxStrategy.ONE_TLU_PROMOTED,
|
||||
fhe.MinMaxStrategy.THREE_TLU_CASTED,
|
||||
]
|
||||
else:
|
||||
strategies += [
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
]
|
||||
|
||||
for lhs_is_signed in [False, True]:
|
||||
for rhs_is_signed in [False, True]:
|
||||
cases += [
|
||||
[
|
||||
# operation
|
||||
operation,
|
||||
# bit widths
|
||||
lhs_bit_width,
|
||||
rhs_bit_width,
|
||||
# signednesses
|
||||
lhs_is_signed,
|
||||
rhs_is_signed,
|
||||
# shapes
|
||||
random.choice([(), (2,), (3, 2)]),
|
||||
random.choice([(), (2,), (3, 2)]),
|
||||
# strategy
|
||||
strategy,
|
||||
]
|
||||
for operation in [
|
||||
("minimum", lambda x, y: np.minimum(x, y)),
|
||||
("maximum", lambda x, y: np.maximum(x, y)),
|
||||
]
|
||||
for strategy in strategies
|
||||
]
|
||||
cases = []
|
||||
for operation in [("max", lambda x: np.max(x)), ("min", lambda x: np.min(x))]:
|
||||
for bit_width in range(1, 5):
|
||||
for is_signed in [False, True]:
|
||||
for shape in [(), (4,), (3, 3)]:
|
||||
for keepdims in [False, True]:
|
||||
for strategy in [
|
||||
fhe.MinMaxStrategy.ONE_TLU_PROMOTED,
|
||||
fhe.MinMaxStrategy.THREE_TLU_CASTED,
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
]:
|
||||
cases.append(
|
||||
[
|
||||
operation,
|
||||
bit_width,
|
||||
is_signed,
|
||||
shape,
|
||||
None,
|
||||
keepdims,
|
||||
strategy,
|
||||
],
|
||||
)
|
||||
for axis in range(len(shape)):
|
||||
cases.append(
|
||||
[
|
||||
operation,
|
||||
bit_width,
|
||||
is_signed,
|
||||
shape,
|
||||
axis,
|
||||
keepdims,
|
||||
strategy,
|
||||
],
|
||||
)
|
||||
cases.append(
|
||||
[
|
||||
operation,
|
||||
bit_width,
|
||||
is_signed,
|
||||
shape,
|
||||
-1,
|
||||
keepdims,
|
||||
strategy,
|
||||
],
|
||||
)
|
||||
if len(shape) == 2:
|
||||
cases.append(
|
||||
[
|
||||
operation,
|
||||
bit_width,
|
||||
is_signed,
|
||||
shape,
|
||||
(0, 1),
|
||||
keepdims,
|
||||
strategy,
|
||||
],
|
||||
)
|
||||
cases.append(
|
||||
[
|
||||
operation,
|
||||
bit_width,
|
||||
is_signed,
|
||||
shape,
|
||||
-2,
|
||||
keepdims,
|
||||
strategy,
|
||||
],
|
||||
)
|
||||
if len(shape) == 3:
|
||||
cases.append(
|
||||
[
|
||||
operation,
|
||||
bit_width,
|
||||
is_signed,
|
||||
shape,
|
||||
(0, 1),
|
||||
keepdims,
|
||||
strategy,
|
||||
],
|
||||
)
|
||||
cases.append(
|
||||
[
|
||||
operation,
|
||||
bit_width,
|
||||
is_signed,
|
||||
shape,
|
||||
(0, 2),
|
||||
keepdims,
|
||||
strategy,
|
||||
],
|
||||
)
|
||||
cases.append(
|
||||
[
|
||||
operation,
|
||||
bit_width,
|
||||
is_signed,
|
||||
shape,
|
||||
(1, 2),
|
||||
keepdims,
|
||||
strategy,
|
||||
],
|
||||
)
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operation,"
|
||||
"lhs_bit_width,rhs_bit_width,"
|
||||
"lhs_is_signed,rhs_is_signed,"
|
||||
"lhs_shape,rhs_shape,"
|
||||
"strategy",
|
||||
cases,
|
||||
"operation,bit_width,is_signed,shape,axis,keepdims,strategy",
|
||||
random.sample(cases, 100),
|
||||
)
|
||||
def test_minimum_maximum(
|
||||
def test_min_max(
|
||||
operation,
|
||||
lhs_bit_width,
|
||||
rhs_bit_width,
|
||||
lhs_is_signed,
|
||||
rhs_is_signed,
|
||||
lhs_shape,
|
||||
rhs_shape,
|
||||
bit_width,
|
||||
is_signed,
|
||||
shape,
|
||||
axis,
|
||||
keepdims,
|
||||
strategy,
|
||||
helpers,
|
||||
):
|
||||
"""
|
||||
Test comparison operations between encrypted integers.
|
||||
Test np.min/np.max on encrypted values.
|
||||
"""
|
||||
|
||||
name, function = operation
|
||||
|
||||
lhs_dtype = Integer(is_signed=lhs_is_signed, bit_width=lhs_bit_width)
|
||||
rhs_dtype = Integer(is_signed=rhs_is_signed, bit_width=rhs_bit_width)
|
||||
|
||||
lhs_description = ValueDescription(lhs_dtype, shape=lhs_shape, is_encrypted=True)
|
||||
rhs_description = ValueDescription(rhs_dtype, shape=rhs_shape, is_encrypted=True)
|
||||
dtype = Integer(is_signed=is_signed, bit_width=bit_width)
|
||||
description = ValueDescription(dtype, shape=shape, is_encrypted=True)
|
||||
|
||||
print()
|
||||
print()
|
||||
print(
|
||||
f"{name}({lhs_description}, {rhs_description})"
|
||||
f"np.{name}({description}, axis={axis}, keepdims={keepdims})"
|
||||
+ (f" {{{strategy}}}" if strategy is not None else "")
|
||||
)
|
||||
print()
|
||||
print()
|
||||
|
||||
parameter_encryption_statuses = {"x": "encrypted", "y": "encrypted"}
|
||||
parameter_encryption_statuses = {"x": "encrypted"}
|
||||
configuration = helpers.configuration()
|
||||
|
||||
if strategy is not None:
|
||||
@@ -231,59 +157,17 @@ def test_minimum_maximum(
|
||||
|
||||
compiler = fhe.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = [
|
||||
(
|
||||
np.random.randint(lhs_dtype.min(), lhs_dtype.max() + 1, size=lhs_shape),
|
||||
np.random.randint(rhs_dtype.min(), rhs_dtype.max() + 1, size=rhs_shape),
|
||||
)
|
||||
for _ in range(100)
|
||||
]
|
||||
inputset = [np.random.randint(dtype.min(), dtype.max() + 1, size=shape) for _ in range(100)]
|
||||
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
samples = [
|
||||
[
|
||||
np.zeros(lhs_shape, dtype=np.int64),
|
||||
np.zeros(rhs_shape, dtype=np.int64),
|
||||
],
|
||||
[
|
||||
np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.min(),
|
||||
np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.min(),
|
||||
],
|
||||
[
|
||||
np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.max(),
|
||||
np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.min(),
|
||||
],
|
||||
[
|
||||
np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.max(),
|
||||
np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.max(),
|
||||
],
|
||||
[
|
||||
np.random.randint(lhs_dtype.min(), lhs_dtype.max() + 1, size=lhs_shape),
|
||||
np.random.randint(rhs_dtype.min(), rhs_dtype.max() + 1, size=rhs_shape),
|
||||
],
|
||||
np.zeros(shape, dtype=np.int64),
|
||||
np.ones(shape, dtype=np.int64) * dtype.min(),
|
||||
np.ones(shape, dtype=np.int64) * dtype.max(),
|
||||
np.random.randint(dtype.min(), dtype.max() + 1, size=shape),
|
||||
np.random.randint(dtype.min(), dtype.max() + 1, size=shape),
|
||||
np.random.randint(dtype.min(), dtype.max() + 1, size=shape),
|
||||
]
|
||||
for sample in samples:
|
||||
helpers.check_execution(circuit, function, sample, retries=5)
|
||||
|
||||
|
||||
def test_internal_signed_tlu_padding(helpers):
|
||||
"""Test that the signed input LUT is correctly padded in the case of substraction trick."""
|
||||
|
||||
inputset = [(i, j) for i in [0, 1] for j in [0, 1]]
|
||||
|
||||
@fhe.compiler({"a": "encrypted", "b": "encrypted"})
|
||||
def min2(a, b):
|
||||
min_12 = np.minimum(a, b)
|
||||
return (min_12, a + 3, b + 3)
|
||||
|
||||
c = min2.compile(inputset, helpers.configuration())
|
||||
min_0_1, _, _ = c.encrypt_run_decrypt(0, 1)
|
||||
|
||||
assert min_0_1 == 0
|
||||
|
||||
# Some extra checks to verify that the test is relevant (substraction trick).
|
||||
assert c.mlir.count("to_signed") == 2 # check substraction trick is used
|
||||
assert c.mlir.count("sub_eint") == 1 # check substraction trick is used
|
||||
assert c.mlir.count("<[0, 0, -2, -1, 0, 0, 0, 0]>") == 0 # lut wrongly padded at the end
|
||||
assert c.mlir.count("<[0, 0, 0, 0, 0, 0, -2, -1]>") == 1 # lut correctly padded in the middle
|
||||
|
||||
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
Tests of execution of minimum and maximum operations.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete import fhe
|
||||
from concrete.fhe.dtypes import Integer
|
||||
from concrete.fhe.values import ValueDescription
|
||||
|
||||
cases = [
|
||||
[
|
||||
# operation
|
||||
(
|
||||
"minimum_optimized_x",
|
||||
lambda x, y: np.minimum(fhe.hint(x, bit_width=5), y), # type: ignore
|
||||
),
|
||||
# bit widths
|
||||
4,
|
||||
4,
|
||||
# signednesses
|
||||
False,
|
||||
True,
|
||||
# shapes
|
||||
(),
|
||||
(),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
],
|
||||
[
|
||||
# operation
|
||||
(
|
||||
"minimum_optimized_x",
|
||||
lambda x, y: np.minimum(fhe.hint(x, bit_width=5), y), # type: ignore
|
||||
),
|
||||
# bit widths
|
||||
4,
|
||||
4,
|
||||
# signednesses
|
||||
True,
|
||||
False,
|
||||
# shapes
|
||||
(2,),
|
||||
(),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
],
|
||||
[
|
||||
# operation
|
||||
(
|
||||
"maximum_optimized_y",
|
||||
lambda x, y: np.maximum(x, fhe.hint(y, bit_width=4)), # type: ignore
|
||||
),
|
||||
# bit widths
|
||||
4,
|
||||
3,
|
||||
# signednesses
|
||||
True,
|
||||
False,
|
||||
# shapes
|
||||
(),
|
||||
(2, 3),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
],
|
||||
[
|
||||
# operation
|
||||
(
|
||||
"maximum_optimized_y",
|
||||
lambda x, y: np.maximum(x, fhe.hint(y, bit_width=4)), # type: ignore
|
||||
),
|
||||
# bit widths
|
||||
4,
|
||||
3,
|
||||
# signednesses
|
||||
False,
|
||||
True,
|
||||
# shapes
|
||||
(),
|
||||
(),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
],
|
||||
]
|
||||
cases += [
|
||||
[
|
||||
# operation
|
||||
operation,
|
||||
# bit widths
|
||||
1,
|
||||
1,
|
||||
# signednesses
|
||||
lhs_is_signed,
|
||||
rhs_is_signed,
|
||||
# shapes
|
||||
(),
|
||||
(),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
]
|
||||
for lhs_is_signed in [False, True]
|
||||
for rhs_is_signed in [False, True]
|
||||
for operation in [
|
||||
(
|
||||
"maximum",
|
||||
lambda x, y: np.maximum(x, y),
|
||||
),
|
||||
]
|
||||
]
|
||||
cases = [
|
||||
[
|
||||
# operation
|
||||
("maximum_increased_bit_widths", lambda x, y: (np.maximum(x, y), x + 100, y + 100)),
|
||||
# bit widths
|
||||
7,
|
||||
7,
|
||||
# signednesses
|
||||
True,
|
||||
False,
|
||||
# shapes
|
||||
(),
|
||||
(),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
],
|
||||
[
|
||||
# operation
|
||||
("maximum_increased_bit_widths", lambda x, y: (np.maximum(x, y), x + 100, y + 100)),
|
||||
# bit widths
|
||||
7,
|
||||
7,
|
||||
# signednesses
|
||||
False,
|
||||
True,
|
||||
# shapes
|
||||
(),
|
||||
(),
|
||||
# strategy
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
],
|
||||
]
|
||||
for lhs_bit_width in range(1, 5):
|
||||
for rhs_bit_width in range(1, 5):
|
||||
strategies = []
|
||||
if lhs_bit_width <= 3 and rhs_bit_width <= 3:
|
||||
strategies += [
|
||||
fhe.MinMaxStrategy.ONE_TLU_PROMOTED,
|
||||
fhe.MinMaxStrategy.THREE_TLU_CASTED,
|
||||
]
|
||||
else:
|
||||
strategies += [
|
||||
fhe.MinMaxStrategy.CHUNKED,
|
||||
]
|
||||
|
||||
for lhs_is_signed in [False, True]:
|
||||
for rhs_is_signed in [False, True]:
|
||||
cases += [
|
||||
[
|
||||
# operation
|
||||
operation,
|
||||
# bit widths
|
||||
lhs_bit_width,
|
||||
rhs_bit_width,
|
||||
# signednesses
|
||||
lhs_is_signed,
|
||||
rhs_is_signed,
|
||||
# shapes
|
||||
random.choice([(), (2,), (3, 2)]),
|
||||
random.choice([(), (2,), (3, 2)]),
|
||||
# strategy
|
||||
strategy,
|
||||
]
|
||||
for operation in [
|
||||
("minimum", lambda x, y: np.minimum(x, y)),
|
||||
("maximum", lambda x, y: np.maximum(x, y)),
|
||||
]
|
||||
for strategy in strategies
|
||||
]
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operation,"
|
||||
"lhs_bit_width,rhs_bit_width,"
|
||||
"lhs_is_signed,rhs_is_signed,"
|
||||
"lhs_shape,rhs_shape,"
|
||||
"strategy",
|
||||
cases,
|
||||
)
|
||||
def test_minimum_maximum(
|
||||
operation,
|
||||
lhs_bit_width,
|
||||
rhs_bit_width,
|
||||
lhs_is_signed,
|
||||
rhs_is_signed,
|
||||
lhs_shape,
|
||||
rhs_shape,
|
||||
strategy,
|
||||
helpers,
|
||||
):
|
||||
"""
|
||||
Test comparison operations between encrypted integers.
|
||||
"""
|
||||
|
||||
name, function = operation
|
||||
|
||||
lhs_dtype = Integer(is_signed=lhs_is_signed, bit_width=lhs_bit_width)
|
||||
rhs_dtype = Integer(is_signed=rhs_is_signed, bit_width=rhs_bit_width)
|
||||
|
||||
lhs_description = ValueDescription(lhs_dtype, shape=lhs_shape, is_encrypted=True)
|
||||
rhs_description = ValueDescription(rhs_dtype, shape=rhs_shape, is_encrypted=True)
|
||||
|
||||
print()
|
||||
print()
|
||||
print(
|
||||
f"{name}({lhs_description}, {rhs_description})"
|
||||
+ (f" {{{strategy}}}" if strategy is not None else "")
|
||||
)
|
||||
print()
|
||||
print()
|
||||
|
||||
parameter_encryption_statuses = {"x": "encrypted", "y": "encrypted"}
|
||||
configuration = helpers.configuration()
|
||||
|
||||
if strategy is not None:
|
||||
configuration = configuration.fork(min_max_strategy_preference=[strategy])
|
||||
|
||||
compiler = fhe.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = [
|
||||
(
|
||||
np.random.randint(lhs_dtype.min(), lhs_dtype.max() + 1, size=lhs_shape),
|
||||
np.random.randint(rhs_dtype.min(), rhs_dtype.max() + 1, size=rhs_shape),
|
||||
)
|
||||
for _ in range(100)
|
||||
]
|
||||
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
samples = [
|
||||
[
|
||||
np.zeros(lhs_shape, dtype=np.int64),
|
||||
np.zeros(rhs_shape, dtype=np.int64),
|
||||
],
|
||||
[
|
||||
np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.min(),
|
||||
np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.min(),
|
||||
],
|
||||
[
|
||||
np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.max(),
|
||||
np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.min(),
|
||||
],
|
||||
[
|
||||
np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.max(),
|
||||
np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.max(),
|
||||
],
|
||||
[
|
||||
np.random.randint(lhs_dtype.min(), lhs_dtype.max() + 1, size=lhs_shape),
|
||||
np.random.randint(rhs_dtype.min(), rhs_dtype.max() + 1, size=rhs_shape),
|
||||
],
|
||||
]
|
||||
for sample in samples:
|
||||
helpers.check_execution(circuit, function, sample, retries=5)
|
||||
|
||||
|
||||
def test_internal_signed_tlu_padding(helpers):
|
||||
"""Test that the signed input LUT is correctly padded in the case of substraction trick."""
|
||||
|
||||
inputset = [(i, j) for i in [0, 1] for j in [0, 1]]
|
||||
|
||||
@fhe.compiler({"a": "encrypted", "b": "encrypted"})
|
||||
def min2(a, b):
|
||||
min_12 = np.minimum(a, b)
|
||||
return (min_12, a + 3, b + 3)
|
||||
|
||||
c = min2.compile(inputset, helpers.configuration())
|
||||
min_0_1, _, _ = c.encrypt_run_decrypt(0, 1)
|
||||
|
||||
assert min_0_1 == 0
|
||||
|
||||
# Some extra checks to verify that the test is relevant (substraction trick).
|
||||
assert c.mlir.count("to_signed") == 2 # check substraction trick is used
|
||||
assert c.mlir.count("sub_eint") == 1 # check substraction trick is used
|
||||
assert c.mlir.count("<[0, 0, -2, -1, 0, 0, 0, 0]>") == 0 # lut wrongly padded at the end
|
||||
assert c.mlir.count("<[0, 0, 0, 0, 0, 0, -2, -1]>") == 1 # lut correctly padded in the middle
|
||||
Reference in New Issue
Block a user