feat(frontend): add support for np.min and np.max

Umut
2024-09-05 16:14:52 +03:00
parent 532000f8be
commit d3dfdcd699
18 changed files with 1021 additions and 322 deletions

View File

@@ -113,9 +113,11 @@ jobs:
find .testenv/lib/python3.10/site-packages -not \( -path .testenv/lib/python3.10/site-packages/concrete -prune \) -name 'lib*omp5.dylib' -or -name 'lib*omp.dylib' | xargs -n 1 ln -f -s $(pwd)/.testenv/lib/python3.10/site-packages/concrete/.dylibs/libomp.dylib
cp -R $GITHUB_WORKSPACE/frontends/concrete-python/examples ./examples
cp -R $GITHUB_WORKSPACE/frontends/concrete-python/tests ./tests
cp $GITHUB_WORKSPACE/frontends/concrete-python/Makefile .
KEY_CACHE_DIRECTORY=./KeySetCache PYTEST_MARKERS="not dataflow and not graphviz" make pytest
KEY_CACHE_DIRECTORY=./KeySetCache PYTEST_MARKERS="not dataflow and not graphviz" make pytest-macos
- name: Cleanup host
if: success() || failure()

View File

@@ -6,30 +6,6 @@ This document introduces several common techniques for optimizing code to fit Fu
All code snippets provided here are temporary workarounds. In future versions of Concrete, some of the functions described here could be directly available in a more generic and efficient form. These code snippets come from support answers in our [community forum](https://community.zama.ai).
{% endhint %}
## Minimum/Maximum for multiple values
Concrete supports `np.minimum`/`np.maximum` natively, but not `np.min`/`np.max` yet. To achieve the same functionality, you can apply `np.minimum`/`np.maximum` repeatedly:
```python
import numpy as np

from concrete import fhe

@fhe.compiler({"args": "encrypted"})
def fhe_min(args):
    remaining = list(args)
    while len(remaining) > 1:
        a = remaining.pop()
        b = remaining.pop()
        remaining.insert(0, np.minimum(a, b))
    return remaining[0]

inputset = [np.random.randint(0, 16, size=5) for _ in range(50)]
circuit = fhe_min.compile(inputset, min_max_strategy_preference=fhe.MinMaxStrategy.ONE_TLU_PROMOTED)

x1, x2, x3, x4, x5 = np.random.randint(0, 16, size=5)
assert circuit.encrypt_run_decrypt([x1, x2, x3, x4, x5]) == min(x1, x2, x3, x4, x5)
```
## Retrieving a value within an encrypted array with an encrypted index
This example demonstrates how to retrieve a value from an array using an encrypted index. The method creates a "selection" array filled with `0`s except at the requested index, which is set to `1`. It then sums the products of all array values with this selection array:
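The code block for this section is truncated by the diff; below is a minimal sketch of the technique described above, assuming the same `fhe.compiler` API used elsewhere in this document (the `fhe_index` name and the inputset are illustrative):

```python
import numpy as np

from concrete import fhe

@fhe.compiler({"array": "encrypted", "index": "encrypted"})
def fhe_index(array, index):
    # selection[i] is 1 exactly when i == index, and 0 everywhere else
    selection = index == np.arange(array.size)
    # only the selected element survives the multiplication, so the sum is array[index]
    return np.sum(array * selection)

inputset = [
    (np.random.randint(0, 16, size=5), np.random.randint(0, 5))
    for _ in range(50)
]
circuit = fhe_index.compile(inputset)

array = np.random.randint(0, 16, size=5)
index = np.random.randint(0, 5)
assert circuit.encrypt_run_decrypt(array, index) == array[index]
```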

View File

@@ -110,7 +110,9 @@ Some operations are not supported between two encrypted values. If attempted, a
* [np.logical\_or](https://numpy.org/doc/stable/reference/generated/numpy.logical_or.html)
* [np.logical\_xor](https://numpy.org/doc/stable/reference/generated/numpy.logical_xor.html)
* [np.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html)
* [np.max](https://numpy.org/doc/stable/reference/generated/numpy.max.html)
* [np.maximum](https://numpy.org/doc/stable/reference/generated/numpy.maximum.html)
* [np.min](https://numpy.org/doc/stable/reference/generated/numpy.min.html)
* [np.minimum](https://numpy.org/doc/stable/reference/generated/numpy.minimum.html)
* [np.multiply](https://numpy.org/doc/stable/reference/generated/numpy.multiply.html)
* [np.negative](https://numpy.org/doc/stable/reference/generated/numpy.negative.html)

View File

@@ -16,7 +16,7 @@ ignore = [
"**/__init__.py" = ["F401"]
"concrete/fhe/compilation/configuration.py" = ["ARG002"]
"concrete/fhe/mlir/processors/all.py" = ["F401"]
"concrete/fhe/mlir/processors/assign_bit_widths.py" = ["ARG002"]
"concrete/fhe/mlir/processors/assign_bit_widths.py" = ["ARG002", "RUF012"]
"concrete/fhe/mlir/converter.py" = ["ARG002", "B011", "F403", "F405"]
"concrete/**" = ["RUF010"]
"examples/**" = ["PLR2004", "RUF010"]

View File

@@ -1,21 +1,21 @@
PYTHON=python
PIP=$(PYTHON) -m pip
COMPILER_BUILD_DIRECTORY ?= $(PWD)/../../compilers/concrete-compiler/compiler/build
BINDINGS_DIRECTORY=${COMPILER_BUILD_DIRECTORY}/tools/concretelang/python_packages/concretelang_core/
TFHERS_UTILS_DIRECTORY ?= $(PWD)/tests/tfhers-utils/
OS=undefined
COVERAGE_OPT=""
ifeq ($(shell uname), Linux)
OS=linux
COVERAGE_OPT="--cov=concrete.fhe --cov-fail-under=100 --cov-report=term-missing:skip-covered"
RUNTIME_LIBRARY?=${COMPILER_BUILD_DIRECTORY}/lib/libConcretelangRuntime.so
else ifeq ($(shell uname), Darwin)
OS=darwin
RUNTIME_LIBRARY?=${COMPILER_BUILD_DIRECTORY}/lib/libConcretelangRuntime.dylib
endif
COMPILER_BUILD_DIRECTORY ?= $(PWD)/../../compilers/concrete-compiler/compiler/build
BINDINGS_DIRECTORY=${COMPILER_BUILD_DIRECTORY}/tools/concretelang/python_packages/concretelang_core/
RUNTIME_LIBRARY?=${COMPILER_BUILD_DIRECTORY}/lib/libConcretelangRuntime.so
TFHERS_UTILS_DIRECTORY ?= $(PWD)/tests/tfhers-utils/
CONCRETE_VERSION?="" # empty means latest
# E.g. to use a previous version: `make CONCRETE_VERSION="<2.7.0" venv`
# E.g. to use a nightly: `make CONCRETE_VERSION="==2.7.0dev20240801"`
@@ -76,6 +76,11 @@ pytest-default: tfhers-utils
--key-cache "${KEY_CACHE_DIRECTORY}" \
-m "${PYTEST_MARKERS}"
pytest-macos:
pytest tests -svv -n auto \
--key-cache "${KEY_CACHE_DIRECTORY}" \
-m "${PYTEST_MARKERS}"
pytest-single: tfhers-utils
eval $(shell make silent_cp_activate)
# test single precision, mono params

View File

@@ -2197,7 +2197,11 @@ class Context:
if cached_conversion is None:
cached_conversion = Conversion(
self.converting,
arith.ConstantOp(resulting_type, attribute, loc=self.location()),
arith.ConstantOp( # pylint: disable=too-many-function-args
resulting_type,
attribute,
loc=self.location(),
),
)
try:
@@ -2813,6 +2817,25 @@ class Context:
return self.to_signedness(result, of=resulting_type)
def min_max(
self,
resulting_type: ConversionType,
x: Conversion,
axes: Union[int, np.integer, Sequence[Union[int, np.integer]]] = (),
keep_dims: bool = False,
*,
operation: str,
):
# This import needs to happen here to avoid circular imports.
# pylint: disable=import-outside-toplevel
from .operations.min_max import min_max
return min_max(self, resulting_type, x, axes, keep_dims, operation=operation)
# pylint: enable=import-outside-toplevel
def minimum(
self,
resulting_type: ConversionType,
@@ -3988,7 +4011,7 @@ class Context:
def get_partition_name(self, partition: tfhers.CryptoParams) -> str:
if partition not in self.tfhers_partition.keys():
self.tfhers_partition[partition] = f"tfhers_{randint(0, 2**32)}" # noqa: S311
self.tfhers_partition[partition] = f"tfhers_{randint(0, 2 ** 32)}" # noqa: S311
return self.tfhers_partition[partition]
def change_partition(
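For reference, a small end-to-end sketch of what the new `Context.min_max` entry point added in this file enables through the public frontend (the function name and inputset are illustrative, not part of the change):

```python
import numpy as np

from concrete import fhe

@fhe.compiler({"x": "encrypted"})
def f(x):
    # reduces the whole encrypted tensor to its minimum value
    return np.min(x)

inputset = [np.random.randint(0, 16, size=(3, 2)) for _ in range(50)]
circuit = f.compile(inputset)

sample = np.random.randint(0, 16, size=(3, 2))
assert circuit.encrypt_run_decrypt(sample) == np.min(sample)
```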

View File

@@ -2,7 +2,7 @@
Declaration of `ConversionType` and `Conversion` classes.
"""
# pylint: disable=import-error,
# pylint: disable=import-error,no-name-in-module
import re
from typing import Optional, Tuple
@@ -12,7 +12,7 @@ from mlir.ir import Type as MlirType
from ..representation import Node
# pylint: enable=import-error
# pylint: enable=import-error,no-name-in-module
SCALAR_INT_SEARCH_REGEX = re.compile(r"^i([0-9]+)$")

View File

@@ -326,6 +326,12 @@ class Converter:
assert len(preds) == 2
return ctx.add(ctx.typeof(node), preds[0], preds[1])
def amax(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
return self.max(ctx, node, preds)
def amin(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
return self.min(ctx, node, preds)
def array(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) > 0
return ctx.array(ctx.typeof(node), elements=preds)
@@ -530,6 +536,20 @@ class Converter:
assert len(preds) == 2
return ctx.matmul(ctx.typeof(node), preds[0], preds[1])
def max(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1
if all(pred.is_encrypted for pred in preds):
return ctx.min_max(
ctx.typeof(node),
preds[0],
axes=node.properties["kwargs"].get("axis", ()),
keep_dims=node.properties["kwargs"].get("keepdims", False),
operation="max",
)
return self.tlu(ctx, node, preds)
def maximum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 2
@@ -556,6 +576,20 @@ class Converter:
ctx.error({node: "3-dimensional maxpooling is not supported at the moment"})
assert False, "unreachable" # pragma: no cover
def min(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1
if all(pred.is_encrypted for pred in preds):
return ctx.min_max(
ctx.typeof(node),
preds[0],
axes=node.properties["kwargs"].get("axis", ()),
keep_dims=node.properties["kwargs"].get("keepdims", False),
operation="min",
)
return self.tlu(ctx, node, preds)
def minimum(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 2
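Since the converter reads `axis` and `keepdims` from the node's kwargs with defaults `()` and `False`, the axis/keepdims variants should flow through the same `ctx.min_max` call; a usage sketch under that assumption:

```python
import numpy as np

from concrete import fhe

@fhe.compiler({"x": "encrypted"})
def f(x):
    # reduce along axis 1 while keeping the reduced dimension
    return np.max(x, axis=1, keepdims=True)

inputset = [np.random.randint(0, 16, size=(2, 3)) for _ in range(50)]
circuit = f.compile(inputset)

sample = np.random.randint(0, 16, size=(2, 3))
assert np.array_equal(
    circuit.encrypt_run_decrypt(sample),
    np.max(sample, axis=1, keepdims=True),
)
```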

View File

@@ -490,6 +490,7 @@ def fancy_indexing(
resulting_type,
x.result,
indices.result,
original_bit_width=x.original_bit_width,
)
@@ -640,6 +641,7 @@ def indexing(
MlirDenseI64ArrayAttr.get(static_offsets),
MlirDenseI64ArrayAttr.get(static_sizes),
MlirDenseI64ArrayAttr.get(static_strides),
original_bit_width=x.original_bit_width,
)
reassociaton = []
@@ -669,4 +671,5 @@ def indexing(
for indices in reassociaton
],
),
original_bit_width=x.original_bit_width,
)

View File

@@ -0,0 +1,428 @@
"""
Conversion of min and max operations.
"""
from copy import deepcopy
from typing import Sequence, Set, Tuple, Union
import numpy as np
from ..context import Context
from ..conversion import Conversion, ConversionType
def min_max(
ctx: Context,
resulting_type: ConversionType,
x: Conversion,
axes: Union[int, np.integer, Sequence[Union[int, np.integer]]] = (),
keep_dims: bool = False,
*,
operation: str,
) -> Conversion:
"""
Convert min or max operation.
Args:
ctx (Context):
conversion context
resulting_type (ConversionType):
resulting type of the operation
x (Conversion):
input of the operation
axes (Union[int, np.integer, Sequence[Union[int, np.integer]]], default = ()):
axes to reduce over
keep_dims (bool, default = False):
whether to keep the reduced axes
operation (str):
"min" or "max"
Returns:
Conversion:
np.min or np.max on x depending on operation
"""
# if the input is clear
if x.is_clear:
# raise an error, as computing the min/max of clear values is not supported
highlights = {
x.origin: "value is clear",
ctx.converting: f"but computing {operation} of clear values is not supported",
}
ctx.error(highlights)
# if the value is scalar
if x.is_scalar:
# return it as it's the min/max
return x
# if axes is not specified, use all axes
if axes is None:
axes = []
# compute the list of unique axes to reduce
# empty list means reduce all axes
axes = list(
set(
# if axes was a single integer, only it will be reduced
[int(axes)]
if isinstance(axes, (int, np.integer))
# if axes was a sequence, every axis in it will be reduced
else [int(axis) for axis in axes]
)
)
# sanitize negative axes
# `axis=-1` is the same as `axis=(-1 + x.ndim)`
input_dimensions = len(x.shape)
for i, axis in enumerate(axes):
if axis < 0:
axes[i] += input_dimensions
assert 0 <= axes[i] < input_dimensions
# if all axes are reduced
if len(axes) == 0 or len(axes) == len(x.shape):
# we flatten the input to use the `reduce` implementation
x = ctx.flatten(x)
# if the input is a vector
if len(x.shape) == 1:
# we reduce the vector to its min/max value
result = reduce(ctx, ctx.element_typeof(resulting_type), x, operation=operation)
# if the user wants to keep the reduced dimensions
if keep_dims:
# reshape the result into the resulting shape
result = ctx.reshape(result, shape=resulting_type.shape)
# return the result
return result
# if the `reduce` implementation is not used,
# the mock implementation will be used instead
#
# the idea is to let numpy compute the indices that will be compared to obtain the result
#
# for example, if input.shape == (2, 3) and axis == 1
# we'll be computing
# [
# [
# {(0, 2, 0), (0, 0, 0), (0, 3, 0), (0, 1, 0)}
# {(0, 2, 1), (0, 0, 1), (0, 3, 1), (0, 1, 1)}
# {(0, 1, 2), (0, 2, 2), (0, 3, 2), (0, 0, 2)}
# ]
# [
# {(1, 0, 0), (1, 3, 0), (1, 1, 0), (1, 2, 0)}
# {(1, 2, 1), (1, 0, 1), (1, 1, 1), (1, 3, 1)}
# {(1, 3, 2), (1, 2, 2), (1, 0, 2), (1, 1, 2)}
# ]
# ]
#
# this means to compute output[0, 0] we need to compare
# - input[0, 2, 0]
# - input[0, 0, 0]
# - input[0, 3, 0]
# - input[0, 1, 0]
#
# or to compute output[1, 2] we need to compare
# - input[1, 3, 2]
# - input[1, 2, 2]
# - input[1, 0, 2]
# - input[1, 1, 2]
#
# notice that the number of comparisons (say n) is always the same for each output cell
# this opens up the possibility to use fancy indexing to extract n slices from the input
# and compare the slices in a tree-like fashion to obtain the result
#
# in the end, we'll be performing
#
# slice1 = [
# [input[0, 2, 0], input[0, 2, 1], input[0, 1, 2]],
# [input[1, 0, 0], input[1, 2, 1], input[1, 3, 2]],
# ]
# slice2 = [
# [input[0, 0, 0], input[0, 0, 1], input[0, 2, 2]],
# [input[1, 3, 0], input[1, 0, 1], input[1, 2, 2]],
# ]
# slice3 = [
# [input[0, 3, 0], input[0, 3, 1], input[0, 3, 2]],
# [input[1, 1, 0], input[1, 1, 1], input[1, 0, 2]],
# ]
# slice4 = [
# [input[0, 1, 0], input[0, 1, 1], input[0, 0, 2]],
# [input[1, 2, 0], input[1, 3, 1], input[1, 1, 2]],
# ]
#
# minimum_slice1_slice2 = np.minimum(slice1, slice2)
# minimum_slice3_slice4 = np.minimum(slice3, slice4)
#
# result = np.minimum(minimum_slice1_slice2, minimum_slice3_slice4)
#
# notice that all slices have the same shape as the result
class Mock:
"""
Class to track accumulation of the operation.
"""
# list of indices that have been accumulated
indices: Set[Tuple[int, ...]]
# initialize the mock with a starting index
def __init__(self, index: Tuple[int, ...]):
self.indices = {index}
# get the representation of the mock
def __repr__(self) -> str:
return f"{self.indices}"
# combine the indices of the mock with another mock into a new mock
def combine(self, other: "Mock") -> "Mock":
result = deepcopy(self)
for index in other.indices:
result.indices.add(index)
return result
# create the mock input
#
# [[[{(0, 0, 0)} {(0, 0, 1)} {(0, 0, 2)}]
# [{(0, 1, 0)} {(0, 1, 1)} {(0, 1, 2)}]
# [{(0, 2, 0)} {(0, 2, 1)} {(0, 2, 2)}]
# [{(0, 3, 0)} {(0, 3, 1)} {(0, 3, 2)}]]
#
# [[{(1, 0, 0)} {(1, 0, 1)} {(1, 0, 2)}]
# [{(1, 1, 0)} {(1, 1, 1)} {(1, 1, 2)}]
# [{(1, 2, 0)} {(1, 2, 1)} {(1, 2, 2)}]
# [{(1, 3, 0)} {(1, 3, 1)} {(1, 3, 2)}]]]
mock_input = []
for index in np.ndindex(x.shape):
mock_input.append(Mock(index))
mock_input = np.array(mock_input).reshape(x.shape)
# use numpy reduction to compute the mock output
#
# [[{(0, 2, 0), (0, 0, 0), (0, 3, 0), (0, 1, 0)}
# {(0, 2, 1), (0, 0, 1), (0, 3, 1), (0, 1, 1)}
# {(0, 1, 2), (0, 2, 2), (0, 3, 2), (0, 0, 2)}]
# [{(1, 0, 0), (1, 3, 0), (1, 1, 0), (1, 2, 0)}
# {(1, 2, 1), (1, 0, 1), (1, 1, 1), (1, 3, 1)}
# {(1, 3, 2), (1, 2, 2), (1, 0, 2), (1, 1, 2)}]]
mock_output = np.frompyfunc(lambda mock1, mock2: mock1.combine(mock2), 2, 1).reduce(
mock_input,
axis=tuple(axes),
keepdims=keep_dims,
)
# extract a sample mock from the mock output
sample_mock = mock_output.flat[0]
# extract the indices of the sample mock
sample_mock_indices = sample_mock.indices
# compute number of comparisons
number_of_comparisons = len(sample_mock_indices)
# extract a sample index from the sample mock indices
sample_mock_index = next(iter(sample_mock_indices))
# compute the number of indices
number_of_indices = len(sample_mock_index)
# compute the shape of fancy indexing indices
index_shape = resulting_type.shape
# compute the fancy indices to extract for comparison
to_compare = []
for _ in range(number_of_comparisons):
indices = []
for _ in range(number_of_indices):
index = np.zeros(index_shape, dtype=np.int64) # type: ignore
indices.append(index)
to_compare.append(tuple(indices))
for position in np.ndindex(mock_output.shape):
mock_indices = list(mock_output[position].indices)
for i in range(number_of_comparisons):
for j in range(number_of_indices):
to_compare[i][j][position] = mock_indices[i][j] # type: ignore
# to_compare will look like
# [
# # for the first slice
# (
# [[0, 0, 0], [1, 1, 1]], # i
# [[2, 2, 1], [0, 2, 3]], # j
# [[0, 1, 2], [0, 1, 2]], # k
# ),
# # for the second slice
# (
# [[0, 0, 0], [1, 1, 1]], # i
# [[0, 0, 2], [3, 0, 2]], # j
# [[0, 1, 2], [0, 1, 2]], # k
# ),
# # for the third slice
# (
# [[0, 0, 0], [1, 1, 1]], # i
# [[3, 3, 3], [1, 1, 0]], # j
# [[0, 1, 2], [0, 1, 2]], # k
# ),
# # for the fourth slice
# (
# [[0, 0, 0], [1, 1, 1]], # i
# [[1, 1, 0], [2, 3, 1]], # j
# [[0, 1, 2], [0, 1, 2]], # k
# ),
# ]
# find the type of the slices
slices_type = ctx.tensor(ctx.element_typeof(x), shape=resulting_type.shape)
# extract the slices
slices = []
for index in to_compare:
slices.append(ctx.index(slices_type, x, index)) # type: ignore
# while there is more than one slice
while len(slices) > 1:
# pop the last two slices
a = slices.pop()
b = slices.pop()
# compare the last two slices
if operation == "min":
c = ctx.minimum(resulting_type, a, b)
else:
c = ctx.maximum(resulting_type, a, b)
# we need to set the original bit width manually
# as minimum/maximum don't constrain their output bit width.
c.set_original_bit_width(x.original_bit_width)
# insert the slice back at the beginning of the slice queue
slices.insert(0, c)
# return the result
return slices[0]
def reduce(
ctx: Context,
resulting_type: ConversionType,
values: Conversion,
*,
operation: str,
) -> Conversion:
"""
Reduce a vector of values to its min/max value.
"""
# make sure the operation is valid
assert operation in {"min", "max"}
# make sure the value is valid
assert values.is_tensor
assert len(values.shape) == 1
assert values.is_encrypted
# make sure the resulting type is valid
assert resulting_type.is_scalar
assert resulting_type.is_encrypted
# let's say the vector was [1, 4, 2, 3, 0]
# and we're computing np.min(vector)
# find the element type of the vector = fhe.uint3
values_element_type = ctx.element_typeof(values)
# find the middle of the vector = 2
middle = values.shape[0] // 2
# we'll be splitting the array into two halves
# [1, 4] and [2, 3] in our case
# then we'll compute np.minimum(first_half, second_half)
# [1, 3] in our case
# if the reduction is a scalar, we'll return it
# otherwise, we'll reduce recursively until we obtain a scalar
# 1 in our case
# find the half type of the vector which is
# fhe.tensor[fhe.uint3, 2] in the first iteration
# fhe.uint3 in the last iteration
half_type = (
ctx.tensor(values_element_type, shape=(middle,)) if middle != 1 else values_element_type
)
# find the accumulated type of the vector which is
# fhe.tensor[fhe.uint3, 2] in the first iteration
# fhe.uint3 in the last iteration
accumulated_type = (
ctx.tensor(resulting_type, shape=(middle,)) if middle != 1 else resulting_type
)
# if there is only one element in each half (e.g., vector = [1, 3], halves = [1], [3])
if middle == 1:
# extract the first element in the vector
first_half = ctx.index(half_type, values, index=[0])
# extract the second element in the vector
second_half = ctx.index(half_type, values, index=[1])
else:
# extract the elements from 0 to middle as the first half
first_half = ctx.index(half_type, values, index=[slice(0, middle)])
# extract the elements from middle to 2*middle as the second half
second_half = ctx.index(half_type, values, index=[slice(middle, 2 * middle)])
# compare the halves
if operation == "min":
# [1, 3] in the first iteration
# 1 in the last iteration
reduced = ctx.minimum(accumulated_type, first_half, second_half)
else:
reduced = ctx.maximum(accumulated_type, first_half, second_half)
# set the original bit width of the reduced value so the following operations work as intended
# this is required here since ctx.minimum and ctx.maximum do not constrain the output bit width
reduced.set_original_bit_width(values.original_bit_width)
result = (
# if reduced value is a scalar, we end the recursion
reduced
if reduced.is_scalar
# otherwise, we reduce the result of comparing the halves
else reduce(ctx, resulting_type, reduced, operation=operation)
)
# if we have one more element that wasn't in the halves
if values.shape[0] % 2 == 1:
# we extract it
last_value = ctx.index(values_element_type, values, index=[-1])
# and compare it with the result we obtained from the halves
result = (
ctx.minimum(resulting_type, result, last_value)
if operation == "min"
else ctx.maximum(resulting_type, result, last_value)
)
# again, we need to set the original bit width
result.set_original_bit_width(values.original_bit_width)
# here is the visualization of the algorithm
#
# [ 1, 4, 2, 3, 0 ]
#
# [1, 4][2, 3],0
# \ / |
# [1, 3] |
# \ /
# 1 /
# \ /
# 0
#
# it performs O(log(n)) tensor comparisons (of sizes n/2, n/4, ...)
# and up to O(log(n)) scalar comparisons (depending on the parity of n/2, n/4, ...)
return result
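A standalone numpy sketch of the mock-based index computation described in the comments above, outside any FHE context; the `Mock` class mirrors the one in this file, while the shapes and the deterministic `sorted` pick are illustrative:

```python
import numpy as np

class Mock:
    """Tracks the set of input indices accumulated into an output cell."""

    def __init__(self, index):
        self.indices = {index}

    def combine(self, other):
        result = Mock(next(iter(self.indices)))
        result.indices = self.indices | other.indices
        return result

shape, axes = (2, 4, 3), (1,)
mock_input = np.array([Mock(i) for i in np.ndindex(shape)], dtype=object).reshape(shape)
mock_output = np.frompyfunc(lambda a, b: a.combine(b), 2, 1).reduce(mock_input, axis=axes)

# every output cell accumulates exactly shape[1] == 4 input indices,
# so the same number of fancy-indexed slices covers all output cells
assert all(len(mock.indices) == 4 for mock in mock_output.flat)

# replaying the slice comparisons with plain numpy reproduces np.min(x, axis=1)
x = np.random.randint(0, 16, size=shape)
slices = []
for i in range(4):
    index = tuple(
        np.array([sorted(mock.indices)[i][dim] for mock in mock_output.flat]).reshape(
            mock_output.shape
        )
        for dim in range(len(shape))
    )
    slices.append(x[index])
result = slices[0]
for other in slices[1:]:
    result = np.minimum(result, other)
assert np.array_equal(result, np.min(x, axis=1))
```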

View File

@@ -351,7 +351,7 @@ class AdditionalConstraints:
node.properties["strategy"] = strategy
break
def min_max(self, node: Node, preds: List[Node]):
def minimum_maximum(self, node: Node, preds: List[Node]):
assert len(preds) == 2
x = preds[0]
@@ -392,69 +392,83 @@ class AdditionalConstraints:
node.properties["strategy"] = strategy
break
def min_max(self, node: Node, preds: List[Node]):
assert len(preds) == 1
self.minimum_maximum(node, [preds[0], preds[0]])
# ==========
# Operations
# ==========
add = { # noqa: RUF012
add = {
inputs_and_output_share_precision,
}
array = { # noqa: RUF012
amax = {
min_max,
inputs_and_output_share_precision,
}
assign_dynamic = { # noqa: RUF012
amin = {
min_max,
inputs_and_output_share_precision,
}
assign_static = { # noqa: RUF012
array = {
inputs_and_output_share_precision,
}
bitwise_and = { # noqa: RUF012
assign_dynamic = {
inputs_and_output_share_precision,
}
assign_static = {
inputs_and_output_share_precision,
}
bitwise_and = {
all_inputs_are_encrypted: {
bitwise,
},
}
bitwise_or = { # noqa: RUF012
bitwise_or = {
all_inputs_are_encrypted: {
bitwise,
},
}
bitwise_xor = { # noqa: RUF012
bitwise_xor = {
all_inputs_are_encrypted: {
bitwise,
},
}
broadcast_to = { # noqa: RUF012
broadcast_to = {
inputs_and_output_share_precision,
}
concatenate = { # noqa: RUF012
concatenate = {
inputs_and_output_share_precision,
}
conv1d = { # noqa: RUF012
conv1d = {
inputs_and_output_share_precision,
}
conv2d = { # noqa: RUF012
conv2d = {
inputs_and_output_share_precision,
}
conv3d = { # noqa: RUF012
conv3d = {
inputs_and_output_share_precision,
}
copy = { # noqa: RUF012
copy = {
inputs_and_output_share_precision,
}
dot = { # noqa: RUF012
dot = {
all_inputs_are_encrypted: {
inputs_share_precision,
inputs_require_one_more_bit,
@@ -464,51 +478,51 @@ class AdditionalConstraints:
},
}
equal = { # noqa: RUF012
equal = {
all_inputs_are_encrypted: {
comparison,
},
}
expand_dims = { # noqa: RUF012
expand_dims = {
inputs_and_output_share_precision,
}
greater = { # noqa: RUF012
greater = {
all_inputs_are_encrypted: {
comparison,
},
}
greater_equal = { # noqa: RUF012
greater_equal = {
all_inputs_are_encrypted: {
comparison,
},
}
index_static = { # noqa: RUF012
index_static = {
inputs_and_output_share_precision,
}
left_shift = { # noqa: RUF012
left_shift = {
all_inputs_are_encrypted: {
bitwise,
},
}
less = { # noqa: RUF012
less = {
all_inputs_are_encrypted: {
comparison,
},
}
less_equal = { # noqa: RUF012
less_equal = {
all_inputs_are_encrypted: {
comparison,
},
}
matmul = { # noqa: RUF012
matmul = {
all_inputs_are_encrypted: {
inputs_share_precision,
inputs_require_one_more_bit,
@@ -518,34 +532,44 @@ class AdditionalConstraints:
},
}
maximum = { # noqa: RUF012
max = {
min_max,
inputs_and_output_share_precision,
}
maximum = {
all_inputs_are_encrypted: {
min_max,
minimum_maximum,
},
}
maxpool1d = { # noqa: RUF012
maxpool1d = {
inputs_and_output_share_precision,
inputs_require_one_more_bit,
}
maxpool2d = { # noqa: RUF012
maxpool2d = {
inputs_and_output_share_precision,
inputs_require_one_more_bit,
}
maxpool3d = { # noqa: RUF012
maxpool3d = {
inputs_and_output_share_precision,
inputs_require_one_more_bit,
}
minimum = { # noqa: RUF012
min = {
min_max,
inputs_and_output_share_precision,
}
minimum = {
all_inputs_are_encrypted: {
min_max,
minimum_maximum,
},
}
multiply = { # noqa: RUF012
multiply = {
all_inputs_are_encrypted: {
inputs_share_precision,
inputs_require_one_more_bit,
@@ -555,48 +579,48 @@ class AdditionalConstraints:
},
}
negative = { # noqa: RUF012
negative = {
inputs_and_output_share_precision,
}
not_equal = { # noqa: RUF012
not_equal = {
all_inputs_are_encrypted: {
comparison,
},
}
reshape = { # noqa: RUF012
reshape = {
inputs_and_output_share_precision,
}
right_shift = { # noqa: RUF012
right_shift = {
all_inputs_are_encrypted: {
bitwise,
},
}
round_bit_pattern = { # noqa: RUF012
round_bit_pattern = {
has_overflow_protection: {
inputs_and_output_share_precision,
},
}
subtract = { # noqa: RUF012
subtract = {
inputs_and_output_share_precision,
}
sum = { # noqa: RUF012
sum = {
inputs_and_output_share_precision,
}
squeeze = { # noqa: RUF012
squeeze = {
inputs_and_output_share_precision,
}
transpose = { # noqa: RUF012
transpose = {
inputs_and_output_share_precision,
}
truncate_bit_pattern = { # noqa: RUF012
truncate_bit_pattern = {
inputs_and_output_share_precision,
}

View File

@@ -2,7 +2,7 @@
Declaration of various functions and constants related to MLIR conversion.
"""
# pylint: disable=import-error
# pylint: disable=import-error,no-name-in-module
from collections import defaultdict, deque
from copy import deepcopy
@@ -26,7 +26,7 @@ from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Node, Operation
# pylint: enable=import-error
# pylint: enable=import-error,no-name-in-module
class HashableNdarray:

View File

@@ -705,18 +705,27 @@ class Graph:
bounds of each node in the `Graph`
"""
for node in self.graph.nodes():
for node in self.query_nodes(ordered=True):
if node in bounds:
min_bound = bounds[node]["min"]
max_bound = bounds[node]["max"]
node.bounds = (min_bound, max_bound)
node.bounds = (min_bound, max_bound) # type: ignore
new_value = deepcopy(node.output)
if isinstance(min_bound, (np.integer, int)):
assert isinstance(new_value.dtype, Integer)
new_value.dtype.update_to_represent(np.array([min_bound, max_bound]))
if node.operation == Operation.Generic and node.properties["name"] in {
"amin",
"amax",
"min",
"max",
}:
assert isinstance(node.inputs[0].dtype, Integer)
new_value.dtype.is_signed = node.inputs[0].dtype.is_signed
else:
new_value.dtype = {
np.bool_: UnsignedInteger(1),
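A hypothetical scenario this signedness fix appears to guard against: the input is signed, but every observed output of `np.max` in the inputset happens to be non-negative, so inferring the output type from bounds alone would make it unsigned even though the comparisons still run on signed values (function name and inputset are illustrative):

```python
import numpy as np

from concrete import fhe

@fhe.compiler({"x": "encrypted"})
def f(x):
    return np.max(x)

# signed 4-bit inputs; with 8 elements per sample, the sample maximum is
# almost always non-negative, so bounds alone would suggest an unsigned output
inputset = [np.random.randint(-8, 8, size=8) for _ in range(100)]
circuit = f.compile(inputset)

sample = np.array([3, -2, 5, 1, -7, 0, 4, -1])
assert circuit.encrypt_run_decrypt(sample) == 5
```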

View File

@@ -281,7 +281,9 @@ class Tracer:
np.logical_or,
np.logical_xor,
np.matmul,
np.max,
np.maximum,
np.min,
np.minimum,
np.mod,
np.multiply,
@@ -331,6 +333,14 @@ class Tracer:
np.expand_dims: {
"axis",
},
np.max: {
"axis",
"keepdims",
},
np.min: {
"axis",
"keepdims",
},
np.ones_like: {
"dtype",
},
@@ -471,6 +481,12 @@ class Tracer:
sanitized_args = [self.sanitize(args[0])]
if len(args) > 1:
kwargs["shape"] = args[1]
elif func in {np.min, np.max}:
sanitized_args = [self.sanitize(args[0])]
for i, keyword in enumerate(["axis", "out", "keepdims", "initial", "where"]):
position = i + 1
if len(args) > position:
kwargs[keyword] = args[position]
elif func is np.reshape:
sanitized_args = [self.sanitize(args[0])]
if len(args) > 1:

View File

@@ -19,7 +19,7 @@ def ast_iterator(root):
while nodes:
current_node = nodes.pop(0)
yield current_node
if hasattr(current_node, "children"):
if hasattr(current_node, "children") and current_node.children is not None:
nodes += current_node.children

View File

@@ -318,6 +318,10 @@ class Helpers:
if i == retries - 1:
message = f"""
Sample
===============
{sample}
Expected Output
===============
{expected}

View File

@@ -1,5 +1,5 @@
"""
Tests of execution of minimum and maximum operations.
Tests of execution of min and max operations.
"""
import random
@@ -11,219 +11,145 @@ from concrete import fhe
from concrete.fhe.dtypes import Integer
from concrete.fhe.values import ValueDescription
cases = [
[
# operation
(
"minimum_optimized_x",
lambda x, y: np.minimum(fhe.hint(x, bit_width=5), y), # type: ignore
),
# bit widths
4,
4,
# signednesses
False,
True,
# shapes
(),
(),
# strategy
fhe.MinMaxStrategy.CHUNKED,
],
[
# operation
(
"minimum_optimized_x",
lambda x, y: np.minimum(fhe.hint(x, bit_width=5), y), # type: ignore
),
# bit widths
4,
4,
# signednesses
True,
False,
# shapes
(2,),
(),
# strategy
fhe.MinMaxStrategy.CHUNKED,
],
[
# operation
(
"maximum_optimized_y",
lambda x, y: np.maximum(x, fhe.hint(y, bit_width=4)), # type: ignore
),
# bit widths
4,
3,
# signednesses
True,
False,
# shapes
(),
(2, 3),
# strategy
fhe.MinMaxStrategy.CHUNKED,
],
[
# operation
(
"maximum_optimized_y",
lambda x, y: np.maximum(x, fhe.hint(y, bit_width=4)), # type: ignore
),
# bit widths
4,
3,
# signednesses
False,
True,
# shapes
(),
(),
# strategy
fhe.MinMaxStrategy.CHUNKED,
],
]
cases += [
[
# operation
operation,
# bit widths
1,
1,
# signednesses
lhs_is_signed,
rhs_is_signed,
# shapes
(),
(),
# strategy
fhe.MinMaxStrategy.CHUNKED,
]
for lhs_is_signed in [False, True]
for rhs_is_signed in [False, True]
for operation in [
(
"maximum",
lambda x, y: np.maximum(x, y),
),
]
]
cases += [
[
# operation
("maximum_increased_bit_widths", lambda x, y: (np.maximum(x, y), x + 100, y + 100)),
# bit widths
7,
7,
# signednesses
True,
False,
# shapes
(),
(),
# strategy
fhe.MinMaxStrategy.CHUNKED,
],
[
# operation
("maximum_increased_bit_widths", lambda x, y: (np.maximum(x, y), x + 100, y + 100)),
# bit widths
7,
7,
# signednesses
False,
True,
# shapes
(),
(),
# strategy
fhe.MinMaxStrategy.CHUNKED,
],
]
for lhs_bit_width in range(1, 5):
for rhs_bit_width in range(1, 5):
strategies = []
if lhs_bit_width <= 3 and rhs_bit_width <= 3:
strategies += [
fhe.MinMaxStrategy.ONE_TLU_PROMOTED,
fhe.MinMaxStrategy.THREE_TLU_CASTED,
]
else:
strategies += [
fhe.MinMaxStrategy.CHUNKED,
]
for lhs_is_signed in [False, True]:
for rhs_is_signed in [False, True]:
cases += [
[
# operation
operation,
# bit widths
lhs_bit_width,
rhs_bit_width,
# signednesses
lhs_is_signed,
rhs_is_signed,
# shapes
random.choice([(), (2,), (3, 2)]),
random.choice([(), (2,), (3, 2)]),
# strategy
strategy,
]
for operation in [
("minimum", lambda x, y: np.minimum(x, y)),
("maximum", lambda x, y: np.maximum(x, y)),
]
for strategy in strategies
]
cases = []
for operation in [("max", lambda x: np.max(x)), ("min", lambda x: np.min(x))]:
for bit_width in range(1, 5):
for is_signed in [False, True]:
for shape in [(), (4,), (3, 3)]:
for keepdims in [False, True]:
for strategy in [
fhe.MinMaxStrategy.ONE_TLU_PROMOTED,
fhe.MinMaxStrategy.THREE_TLU_CASTED,
fhe.MinMaxStrategy.CHUNKED,
]:
cases.append(
[
operation,
bit_width,
is_signed,
shape,
None,
keepdims,
strategy,
],
)
for axis in range(len(shape)):
cases.append(
[
operation,
bit_width,
is_signed,
shape,
axis,
keepdims,
strategy,
],
)
cases.append(
[
operation,
bit_width,
is_signed,
shape,
-1,
keepdims,
strategy,
],
)
if len(shape) == 2:
cases.append(
[
operation,
bit_width,
is_signed,
shape,
(0, 1),
keepdims,
strategy,
],
)
cases.append(
[
operation,
bit_width,
is_signed,
shape,
-2,
keepdims,
strategy,
],
)
if len(shape) == 3:
cases.append(
[
operation,
bit_width,
is_signed,
shape,
(0, 1),
keepdims,
strategy,
],
)
cases.append(
[
operation,
bit_width,
is_signed,
shape,
(0, 2),
keepdims,
strategy,
],
)
cases.append(
[
operation,
bit_width,
is_signed,
shape,
(1, 2),
keepdims,
strategy,
],
)
# pylint: disable=redefined-outer-name
@pytest.mark.parametrize(
"operation,"
"lhs_bit_width,rhs_bit_width,"
"lhs_is_signed,rhs_is_signed,"
"lhs_shape,rhs_shape,"
"strategy",
cases,
"operation,bit_width,is_signed,shape,axis,keepdims,strategy",
random.sample(cases, 100),
)
def test_minimum_maximum(
def test_min_max(
operation,
lhs_bit_width,
rhs_bit_width,
lhs_is_signed,
rhs_is_signed,
lhs_shape,
rhs_shape,
bit_width,
is_signed,
shape,
axis,
keepdims,
strategy,
helpers,
):
"""
Test comparison operations between encrypted integers.
Test np.min/np.max on encrypted values.
"""
name, function = operation
lhs_dtype = Integer(is_signed=lhs_is_signed, bit_width=lhs_bit_width)
rhs_dtype = Integer(is_signed=rhs_is_signed, bit_width=rhs_bit_width)
lhs_description = ValueDescription(lhs_dtype, shape=lhs_shape, is_encrypted=True)
rhs_description = ValueDescription(rhs_dtype, shape=rhs_shape, is_encrypted=True)
dtype = Integer(is_signed=is_signed, bit_width=bit_width)
description = ValueDescription(dtype, shape=shape, is_encrypted=True)
print()
print()
print(
f"{name}({lhs_description}, {rhs_description})"
f"np.{name}({description}, axis={axis}, keepdims={keepdims})"
+ (f" {{{strategy}}}" if strategy is not None else "")
)
print()
print()
parameter_encryption_statuses = {"x": "encrypted", "y": "encrypted"}
parameter_encryption_statuses = {"x": "encrypted"}
configuration = helpers.configuration()
if strategy is not None:
@@ -231,59 +157,17 @@ def test_minimum_maximum(
compiler = fhe.Compiler(function, parameter_encryption_statuses)
inputset = [
(
np.random.randint(lhs_dtype.min(), lhs_dtype.max() + 1, size=lhs_shape),
np.random.randint(rhs_dtype.min(), rhs_dtype.max() + 1, size=rhs_shape),
)
for _ in range(100)
]
inputset = [np.random.randint(dtype.min(), dtype.max() + 1, size=shape) for _ in range(100)]
circuit = compiler.compile(inputset, configuration)
samples = [
[
np.zeros(lhs_shape, dtype=np.int64),
np.zeros(rhs_shape, dtype=np.int64),
],
[
np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.min(),
np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.min(),
],
[
np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.max(),
np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.min(),
],
[
np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.max(),
np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.max(),
],
[
np.random.randint(lhs_dtype.min(), lhs_dtype.max() + 1, size=lhs_shape),
np.random.randint(rhs_dtype.min(), rhs_dtype.max() + 1, size=rhs_shape),
],
np.zeros(shape, dtype=np.int64),
np.ones(shape, dtype=np.int64) * dtype.min(),
np.ones(shape, dtype=np.int64) * dtype.max(),
np.random.randint(dtype.min(), dtype.max() + 1, size=shape),
np.random.randint(dtype.min(), dtype.max() + 1, size=shape),
np.random.randint(dtype.min(), dtype.max() + 1, size=shape),
]
for sample in samples:
helpers.check_execution(circuit, function, sample, retries=5)
def test_internal_signed_tlu_padding(helpers):
"""Test that the signed input LUT is correctly padded in the case of substraction trick."""
inputset = [(i, j) for i in [0, 1] for j in [0, 1]]
@fhe.compiler({"a": "encrypted", "b": "encrypted"})
def min2(a, b):
min_12 = np.minimum(a, b)
return (min_12, a + 3, b + 3)
c = min2.compile(inputset, helpers.configuration())
min_0_1, _, _ = c.encrypt_run_decrypt(0, 1)
assert min_0_1 == 0
# Some extra checks to verify that the test is relevant (subtraction trick).
assert c.mlir.count("to_signed") == 2 # check the subtraction trick is used
assert c.mlir.count("sub_eint") == 1 # check the subtraction trick is used
assert c.mlir.count("<[0, 0, -2, -1, 0, 0, 0, 0]>") == 0 # lut wrongly padded at the end
assert c.mlir.count("<[0, 0, 0, 0, 0, 0, -2, -1]>") == 1 # lut correctly padded in the middle

View File

@@ -0,0 +1,289 @@
"""
Tests of execution of minimum and maximum operations.
"""
import random
import numpy as np
import pytest
from concrete import fhe
from concrete.fhe.dtypes import Integer
from concrete.fhe.values import ValueDescription
cases = [
[
# operation
(
"minimum_optimized_x",
lambda x, y: np.minimum(fhe.hint(x, bit_width=5), y), # type: ignore
),
# bit widths
4,
4,
# signednesses
False,
True,
# shapes
(),
(),
# strategy
fhe.MinMaxStrategy.CHUNKED,
],
[
# operation
(
"minimum_optimized_x",
lambda x, y: np.minimum(fhe.hint(x, bit_width=5), y), # type: ignore
),
# bit widths
4,
4,
# signednesses
True,
False,
# shapes
(2,),
(),
# strategy
fhe.MinMaxStrategy.CHUNKED,
],
[
# operation
(
"maximum_optimized_y",
lambda x, y: np.maximum(x, fhe.hint(y, bit_width=4)), # type: ignore
),
# bit widths
4,
3,
# signednesses
True,
False,
# shapes
(),
(2, 3),
# strategy
fhe.MinMaxStrategy.CHUNKED,
],
[
# operation
(
"maximum_optimized_y",
lambda x, y: np.maximum(x, fhe.hint(y, bit_width=4)), # type: ignore
),
# bit widths
4,
3,
# signednesses
False,
True,
# shapes
(),
(),
# strategy
fhe.MinMaxStrategy.CHUNKED,
],
]
cases += [
[
# operation
operation,
# bit widths
1,
1,
# signednesses
lhs_is_signed,
rhs_is_signed,
# shapes
(),
(),
# strategy
fhe.MinMaxStrategy.CHUNKED,
]
for lhs_is_signed in [False, True]
for rhs_is_signed in [False, True]
for operation in [
(
"maximum",
lambda x, y: np.maximum(x, y),
),
]
]
cases += [
[
# operation
("maximum_increased_bit_widths", lambda x, y: (np.maximum(x, y), x + 100, y + 100)),
# bit widths
7,
7,
# signednesses
True,
False,
# shapes
(),
(),
# strategy
fhe.MinMaxStrategy.CHUNKED,
],
[
# operation
("maximum_increased_bit_widths", lambda x, y: (np.maximum(x, y), x + 100, y + 100)),
# bit widths
7,
7,
# signednesses
False,
True,
# shapes
(),
(),
# strategy
fhe.MinMaxStrategy.CHUNKED,
],
]
for lhs_bit_width in range(1, 5):
for rhs_bit_width in range(1, 5):
strategies = []
if lhs_bit_width <= 3 and rhs_bit_width <= 3:
strategies += [
fhe.MinMaxStrategy.ONE_TLU_PROMOTED,
fhe.MinMaxStrategy.THREE_TLU_CASTED,
]
else:
strategies += [
fhe.MinMaxStrategy.CHUNKED,
]
for lhs_is_signed in [False, True]:
for rhs_is_signed in [False, True]:
cases += [
[
# operation
operation,
# bit widths
lhs_bit_width,
rhs_bit_width,
# signednesses
lhs_is_signed,
rhs_is_signed,
# shapes
random.choice([(), (2,), (3, 2)]),
random.choice([(), (2,), (3, 2)]),
# strategy
strategy,
]
for operation in [
("minimum", lambda x, y: np.minimum(x, y)),
("maximum", lambda x, y: np.maximum(x, y)),
]
for strategy in strategies
]
# pylint: disable=redefined-outer-name
@pytest.mark.parametrize(
"operation,"
"lhs_bit_width,rhs_bit_width,"
"lhs_is_signed,rhs_is_signed,"
"lhs_shape,rhs_shape,"
"strategy",
cases,
)
def test_minimum_maximum(
operation,
lhs_bit_width,
rhs_bit_width,
lhs_is_signed,
rhs_is_signed,
lhs_shape,
rhs_shape,
strategy,
helpers,
):
"""
Test comparison operations between encrypted integers.
"""
name, function = operation
lhs_dtype = Integer(is_signed=lhs_is_signed, bit_width=lhs_bit_width)
rhs_dtype = Integer(is_signed=rhs_is_signed, bit_width=rhs_bit_width)
lhs_description = ValueDescription(lhs_dtype, shape=lhs_shape, is_encrypted=True)
rhs_description = ValueDescription(rhs_dtype, shape=rhs_shape, is_encrypted=True)
print()
print()
print(
f"{name}({lhs_description}, {rhs_description})"
+ (f" {{{strategy}}}" if strategy is not None else "")
)
print()
print()
parameter_encryption_statuses = {"x": "encrypted", "y": "encrypted"}
configuration = helpers.configuration()
if strategy is not None:
configuration = configuration.fork(min_max_strategy_preference=[strategy])
compiler = fhe.Compiler(function, parameter_encryption_statuses)
inputset = [
(
np.random.randint(lhs_dtype.min(), lhs_dtype.max() + 1, size=lhs_shape),
np.random.randint(rhs_dtype.min(), rhs_dtype.max() + 1, size=rhs_shape),
)
for _ in range(100)
]
circuit = compiler.compile(inputset, configuration)
samples = [
[
np.zeros(lhs_shape, dtype=np.int64),
np.zeros(rhs_shape, dtype=np.int64),
],
[
np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.min(),
np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.min(),
],
[
np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.max(),
np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.min(),
],
[
np.ones(lhs_shape, dtype=np.int64) * lhs_dtype.max(),
np.ones(rhs_shape, dtype=np.int64) * rhs_dtype.max(),
],
[
np.random.randint(lhs_dtype.min(), lhs_dtype.max() + 1, size=lhs_shape),
np.random.randint(rhs_dtype.min(), rhs_dtype.max() + 1, size=rhs_shape),
],
]
for sample in samples:
helpers.check_execution(circuit, function, sample, retries=5)
def test_internal_signed_tlu_padding(helpers):
"""Test that the signed input LUT is correctly padded in the case of substraction trick."""
inputset = [(i, j) for i in [0, 1] for j in [0, 1]]
@fhe.compiler({"a": "encrypted", "b": "encrypted"})
def min2(a, b):
min_12 = np.minimum(a, b)
return (min_12, a + 3, b + 3)
c = min2.compile(inputset, helpers.configuration())
min_0_1, _, _ = c.encrypt_run_decrypt(0, 1)
assert min_0_1 == 0
# Some extra checks to verify that the test is relevant (subtraction trick).
assert c.mlir.count("to_signed") == 2 # check the subtraction trick is used
assert c.mlir.count("sub_eint") == 1 # check the subtraction trick is used
assert c.mlir.count("<[0, 0, -2, -1, 0, 0, 0, 0]>") == 0 # lut wrongly padded at the end
assert c.mlir.count("<[0, 0, 0, 0, 0, 0, -2, -1]>") == 1 # lut correctly padded in the middle