mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-13 06:48:02 -05:00
190 lines
5.2 KiB
Python
190 lines
5.2 KiB
Python
"""
|
|
Tests of execution of the multivariate extension.
|
|
"""
|
|
|
|
import inspect
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from concrete import fhe
|
|
from concrete.fhe.dtypes import Integer
|
|
from concrete.fhe.values import ValueDescription
|
|
|
|
|
|
def x_if_y_else_zero(x, y):
    """
    Multivariate function that selects x wherever y is truthy
    and yields 0 everywhere else.
    """
    # np.where takes care of broadcasting x and y against each other.
    selected = np.where(y, x, 0)
    return selected
|
|
|
|
|
|
def multi(x, y, z):
    """
    Multivariate function whose entries follow different formulas,
    which results in multiple tables in the table lookup.
    """
    # The product is stored in a fresh local, which the updates below modify.
    product = x * y

    # Entry 0 becomes x * y - z, entry 2 becomes x * y + z,
    # and every other entry stays x * y.
    product[0] -= z
    product[2] += z

    return product
|
|
|
|
|
|
def _encrypted(is_signed, bit_width, shape):
    """
    Describe an encrypted integer input with the given signedness, bit width and shape.
    """
    return ValueDescription(
        Integer(is_signed=is_signed, bit_width=bit_width),
        shape=shape,
        is_encrypted=True,
    )


# Each case is [(name, function), input descriptions, multivariate strategy].
cases = [
    [
        ("x_if_y_else_zero", lambda x, y: fhe.multivariate(x_if_y_else_zero)(x, y)),
        [
            _encrypted(is_signed=False, bit_width=3, shape=()),
            _encrypted(is_signed=False, bit_width=1, shape=()),
        ],
        fhe.MultivariateStrategy.CASTED,
    ],
    [
        ("x_if_y_else_zero", lambda x, y: fhe.multivariate(x_if_y_else_zero)(x, y)),
        [
            _encrypted(is_signed=False, bit_width=3, shape=(2,)),
            _encrypted(is_signed=False, bit_width=1, shape=()),
        ],
        fhe.MultivariateStrategy.PROMOTED,
    ],
    [
        ("x_if_y_else_zero", lambda x, y: fhe.multivariate(x_if_y_else_zero)(x, y)),
        [
            _encrypted(is_signed=True, bit_width=3, shape=()),
            _encrypted(is_signed=False, bit_width=1, shape=(2,)),
        ],
        fhe.MultivariateStrategy.CASTED,
    ],
    [
        ("x_if_y_else_zero", lambda x, y: fhe.multivariate(x_if_y_else_zero)(x, y)),
        [
            _encrypted(is_signed=True, bit_width=3, shape=(2,)),
            _encrypted(is_signed=False, bit_width=1, shape=(2,)),
        ],
        fhe.MultivariateStrategy.PROMOTED,
    ],
    [
        ("multi", lambda x, y, z: fhe.multivariate(multi)(x, y, z)),
        [
            _encrypted(is_signed=False, bit_width=2, shape=(3,)),
            _encrypted(is_signed=True, bit_width=2, shape=(3,)),
            _encrypted(is_signed=False, bit_width=2, shape=()),
        ],
        fhe.MultivariateStrategy.CASTED,
    ],
    [
        ("multi", lambda x, y, z: fhe.multivariate(multi)(x, y, z)),
        [
            _encrypted(is_signed=True, bit_width=2, shape=(3, 2)),
            _encrypted(is_signed=False, bit_width=2, shape=(1, 2)),
            _encrypted(is_signed=True, bit_width=2, shape=(2,)),
        ],
        fhe.MultivariateStrategy.PROMOTED,
    ],
]
|
|
|
|
|
|
@pytest.mark.parametrize("operation,values,strategy", cases)
def test_multivariate(operation, values, strategy, helpers):
    """
    Test multivariate extension.

    The operation is compiled with every parameter encrypted, using an inputset
    of 100 random samples, and then its execution is checked on an all-zeros
    sample, an all-maximum sample, and one random sample.
    """

    name, function = operation

    # Banner identifying the case that is about to run.
    banner = f"{name}({', '.join(str(value) for value in values)})"
    if strategy is not None:
        banner += f" {{{strategy}}}"

    print()
    print()
    print(banner)
    print()
    print()

    # Every parameter of the function under test is marked as encrypted.
    parameter_encryption_statuses = dict.fromkeys(
        inspect.signature(function).parameters,
        "encrypted",
    )

    configuration = helpers.configuration()
    if strategy is not None:
        configuration = configuration.fork(multivariate_strategy_preference=strategy)

    compiler = fhe.Compiler(function, parameter_encryption_statuses)

    def random_arguments():
        # One random array per described value, covering the full range of its dtype.
        return [
            np.random.randint(value.dtype.min(), value.dtype.max() + 1, size=value.shape)
            for value in values
        ]

    inputset = [tuple(random_arguments()) for _ in range(100)]
    circuit = compiler.compile(inputset, configuration)

    all_zeros = [np.zeros(value.shape, dtype=np.int64) for value in values]
    all_maximums = [
        np.ones(value.shape, dtype=np.int64) * value.dtype.max() for value in values
    ]

    for sample in [all_zeros, all_maximums, random_arguments()]:
        helpers.check_execution(circuit, function, sample, retries=5)
|