# Mirror of https://github.com/zama-ai/concrete.git
# (synced 2026-01-14 07:17:57 -05:00; 133 lines, 3.7 KiB, Python)
"""
|
|
Tests of execution of `if_then_else` extension.
|
|
"""
|
|
|
|
import itertools
import random

import numpy as np
import pytest

from concrete import fhe
from concrete.fhe.dtypes import Integer
from concrete.fhe.values import EncryptedScalar, EncryptedTensor
|
|
|
|
# pylint: disable=redefined-outer-name
|
|
|
|
# Functions under test. Each takes (condition, when_true, when_false); these
# parameter names match the keys of the encryption-status dict passed to
# ``fhe.Compiler`` in ``test_if_then_else``.
functions = [
    # plain element-wise selection
    lambda condition, when_true, when_false: np.where(
        condition,
        when_true,
        when_false,
    ),
    # selection followed by a constant offset
    lambda condition, when_true, when_false: (
        np.where(condition, when_true, when_false) + 100
    ),
]
|
|
# Encrypted 1-bit unsigned conditions, in scalar, vector and matrix shapes.
condition_descriptions = [
    EncryptedTensor(Integer(is_signed=False, bit_width=1), shape=condition_shape)
    for condition_shape in ((), (2,), (3, 2))
]
|
|
# Encrypted values selected when the condition holds:
# signedness (2) x bit width (3) x shape (3) = 18 descriptions.
when_true_descriptions = [
    EncryptedTensor(Integer(is_signed=signed, bit_width=width), shape=value_shape)
    for signed in (False, True)
    for width in (3, 4, 5)
    for value_shape in ((), (2,), (3, 2))
]
|
|
# Encrypted values selected when the condition does not hold:
# signedness (2) x bit width (3) x shape (3) = 18 descriptions.
when_false_descriptions = [
    EncryptedTensor(Integer(is_signed=signed, bit_width=width), shape=value_shape)
    for signed in (False, True)
    for width in (3, 4, 5)
    for value_shape in ((), (2,), (3, 2))
]
|
|
# Values passed to the ``if_then_else_chunk_size`` configuration option.
chunk_sizes = [2, 3]
|
|
|
|
# Full grid of (function, condition, when_true, when_false, chunk_size)
# combinations. ``itertools.product`` yields them in the same order as the
# equivalent nested loops (last factor varies fastest), so sampling below
# picks the same cases for a given random seed.
cases = list(
    itertools.product(
        functions,
        condition_descriptions,
        when_true_descriptions,
        when_false_descriptions,
        chunk_sizes,
    )
)

# The full grid is large, so only a random subset is compiled and executed.
# ``min`` keeps ``random.sample`` from raising ``ValueError`` if the grid is
# ever trimmed below the sample size.
cases = random.sample(cases, min(100, len(cases)))

cases.append(
    (
        # special case of increased bit-width for condition: the condition is
        # also added to the selected value, not only used for selection
        lambda condition, when_true, when_false: (
            np.where(condition, when_true, when_false) + (condition + 100)
        ),
        EncryptedScalar(Integer(is_signed=False, bit_width=1)),
        EncryptedScalar(Integer(is_signed=False, bit_width=4)),
        EncryptedScalar(Integer(is_signed=False, bit_width=4)),
        2,
    )
)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "function,condition_description,when_true_description,when_false_description,chunk_size",
    cases,
)
def test_if_then_else(
    function,
    condition_description,
    when_true_description,
    when_false_description,
    chunk_size,
    helpers,
):
    """
    Compile ``function`` with all three inputs encrypted, using the given
    ``if_then_else_chunk_size``, then run ``helpers.check_execution`` on
    8 random samples drawn from the inputset.
    """

    def random_value(description):
        # Uniform draw over the description's dtype range (randint's upper
        # bound is exclusive, hence the +1), shaped like the description.
        return np.random.randint(
            description.dtype.min(),
            description.dtype.max() + 1,
            size=description.shape,
        )

    banner = (
        f"[{when_true_description}] "
        f"if [{condition_description}] "
        f"else [{when_false_description}] "
        f"{{{chunk_size=}}}"
    )
    print()
    print()
    print(banner)
    print()
    print()

    # Draw order per sample is condition, when_true, when_false.
    descriptions = (condition_description, when_true_description, when_false_description)
    inputset = [
        tuple(random_value(description) for description in descriptions)
        for _ in range(100)
    ]

    configuration = helpers.configuration().fork(if_then_else_chunk_size=chunk_size)

    encryption_statuses = {
        parameter: "encrypted" for parameter in ("condition", "when_true", "when_false")
    }
    circuit = fhe.Compiler(function, encryption_statuses).compile(inputset, configuration)

    for sample in random.sample(inputset, 8):
        helpers.check_execution(circuit, function, list(sample), retries=3)
|