Files
concrete/frontends/concrete-python/tests/execution/test_if_then_else.py

133 lines
3.7 KiB
Python

"""
Tests of execution of `if_then_else` extension.
"""
import random
import numpy as np
import pytest
from concrete import fhe
from concrete.fhe.dtypes import Integer
from concrete.fhe.values import EncryptedScalar, EncryptedTensor
# pylint: disable=redefined-outer-name
functions = [
lambda condition, when_true, when_false: np.where(condition, when_true, when_false),
lambda condition, when_true, when_false: np.where(condition, when_true, when_false) + 100,
]
# Encrypted 1-bit unsigned conditions, one per shape: (), (2,) and (3, 2).
condition_descriptions = [
    EncryptedTensor(Integer(is_signed=False, bit_width=1), shape=condition_shape)
    for condition_shape in ((), (2,), (3, 2))
]
def _branch_descriptions():
    """
    Build encrypted integer descriptions for one branch of the selection.

    Covers every combination of signedness (unsigned first), bit-width (3, 4, 5)
    and shape ((), (2,), (3, 2)), with the shape varying fastest.
    """

    return [
        EncryptedTensor(Integer(signed, width), shape=branch_shape)
        for signed in (False, True)
        for width in (3, 4, 5)
        for branch_shape in ((), (2,), (3, 2))
    ]


# Both branches use the same grid; each call builds its own fresh objects.
when_true_descriptions = _branch_descriptions()
when_false_descriptions = _branch_descriptions()
# Values forwarded to the `if_then_else_chunk_size` configuration option.
chunk_sizes = [2, 3]
# Every combination of function, operand descriptions and chunk size, in the
# same order as nested loops would produce. A comprehension keeps the loop
# variables out of module scope, where the parameters of `test_if_then_else`
# would otherwise redefine them.
cases = [
    (
        function,
        condition_description,
        when_true_description,
        when_false_description,
        chunk_size,
    )
    for function in functions
    for condition_description in condition_descriptions
    for when_true_description in when_true_descriptions
    for when_false_description in when_false_descriptions
    for chunk_size in chunk_sizes
]

# The full product has thousands of combinations, so only a random subset of
# 100 is exercised.
# NOTE(review): this draws from the global `random` state at import time; unless
# it is seeded elsewhere (e.g. conftest/plugin), the selected cases vary per run.
cases = random.sample(cases, 100)

cases.append(
    (
        # special case of increased bit-width for condition
        lambda condition, when_true, when_false: (
            np.where(condition, when_true, when_false) + (condition + 100)
        ),
        EncryptedScalar(Integer(is_signed=False, bit_width=1)),
        EncryptedScalar(Integer(is_signed=False, bit_width=4)),
        EncryptedScalar(Integer(is_signed=False, bit_width=4)),
        2,
    )
)
@pytest.mark.parametrize(
    "function,condition_description,when_true_description,when_false_description,chunk_size",
    cases,
)
def test_if_then_else(
    function,
    condition_description,
    when_true_description,
    when_false_description,
    chunk_size,
    helpers,
):
    """
    Test encrypted evaluation of `if_then_else` extension.

    Compiles `function` with all three arguments encrypted and with the given
    `if_then_else_chunk_size`. The inputset has 100 random samples, each drawn
    within the full dtype range and shape of its description. Then 8 of those
    samples are passed to `helpers.check_execution` (with `retries=3`).
    """

    banner = (
        f"[{when_true_description}] "
        f"if [{condition_description}] "
        f"else [{when_false_description}] "
        f"{{{chunk_size=}}}"
    )
    # print("\n") emits two newlines, matching two bare print() calls
    print("\n")
    print(banner)
    print("\n")

    def draw(description):
        # uniform random value(s) in [dtype.min(), dtype.max()] with the description's shape
        dtype = description.dtype
        return np.random.randint(dtype.min(), dtype.max() + 1, size=description.shape)

    # draw order per sample is condition, when_true, when_false
    operand_descriptions = (condition_description, when_true_description, when_false_description)
    inputset = [
        tuple(draw(description) for description in operand_descriptions)
        for _ in range(100)
    ]

    configuration = helpers.configuration().fork(if_then_else_chunk_size=chunk_size)
    encryption_statuses = {
        parameter: "encrypted" for parameter in ("condition", "when_true", "when_false")
    }
    compiler = fhe.Compiler(function, encryption_statuses)
    circuit = compiler.compile(inputset, configuration)

    checked_samples = random.sample(inputset, 8)
    for sample in checked_samples:
        helpers.check_execution(circuit, function, list(sample), retries=3)