Files
concrete/frontends/concrete-python/tests/execution/test_tfhers.py
2024-05-28 14:44:20 +01:00

196 lines
5.9 KiB
Python

"""
Tests execution of tfhers conversion operations.
"""
import pytest
from concrete import fhe
from concrete.fhe import tfhers
def binary_tfhers(x, y, binary_op, tfhers_type):
    """Run `binary_op` on two tfhers inputs, converting in and out around it.

    Both `x` and `y` are passed through `tfhers.to_native` before `binary_op`
    is called on them; the value it returns is handed to `tfhers.from_native`
    together with `tfhers_type`, and that converted value is returned.
    """
    native_x = tfhers.to_native(x)
    native_y = tfhers.to_native(y)
    native_result = binary_op(native_x, native_y)
    return tfhers.from_native(native_result, tfhers_type)
def one_tfhers_one_native(x, y, binary_op, tfhers_type):
    """Run `binary_op` on one tfhers input and one input used as-is.

    Only `x` is passed through `tfhers.to_native`; `y` is given to `binary_op`
    unchanged. The value `binary_op` returns is handed to `tfhers.from_native`
    together with `tfhers_type`, and that converted value is returned.
    """
    native_x = tfhers.to_native(x)
    native_result = binary_op(native_x, y)
    return tfhers.from_native(native_result, tfhers_type)
@pytest.mark.parametrize(
    "function, parameters, dtype",
    [
        # addition, both operands in [0, 2**14]
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [0, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**14], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x + y",
        ),
        # addition, both operands in [2**14, 2**15 - 1]
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
                "y": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x + y big values",
        ),
        # subtraction; the range of x starts where the range of y ends
        pytest.param(
            lambda x, y: x - y,
            {
                "x": {"range": [2**10, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**10], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x - y",
        ),
        # multiplication, both operands in [0, 2**3]
        pytest.param(
            lambda x, y: x * y,
            {
                "x": {"range": [0, 2**3], "status": "encrypted"},
                "y": {"range": [0, 2**3], "status": "encrypted"},
            },
            tfhers.uint8_2_2,
            id="x * y",
        ),
    ],
)
def test_tfhers_conversion_binary_encrypted(
    function, parameters, dtype: tfhers.TFHERSIntegerType, helpers
):
    """
    Check binary operations whose two operands both go through tfhers conversion.

    The circuit is compiled from an inputset whose values are all wrapped as
    `tfhers.TFHERSInteger` of `dtype`. It is then run on one sample encoded with
    `dtype.encode`, and the decoded output must equal `function` applied to the
    raw sample.
    """
    statuses = helpers.generate_encryption_statuses(parameters)
    config = helpers.configuration()

    compiler = fhe.Compiler(
        lambda x, y: binary_tfhers(x, y, function, dtype),
        statuses,
    )

    # Wrap every value of every inputset entry as a TFHERSInteger of `dtype`.
    inputset = []
    for entry in helpers.generate_inputset(parameters):
        inputset.append(tuple(tfhers.TFHERSInteger(dtype, value) for value in entry))
    circuit = compiler.compile(inputset, config)

    sample = helpers.generate_sample(parameters)
    encoded_args = [dtype.encode(value) for value in sample]
    encoded_result = circuit.encrypt_run_decrypt(*encoded_args)

    # Compare in the decoded domain against the plain evaluation of `function`.
    decoded = dtype.decode(encoded_result)
    expected = function(*sample)
    assert (decoded == expected).all()
@pytest.mark.parametrize(
    "function, parameters, dtype",
    [
        # addition, operands in [0, 2**14]
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [0, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**14], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x + y",
        ),
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [0, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**14], "status": "clear"},
            },
            tfhers.uint16_2_2,
            id="x + clear(y)",
        ),
        # addition, operands in [2**14, 2**15 - 1]
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
                "y": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x + y big values",
        ),
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
                "y": {"range": [2**14, 2**15 - 1], "status": "clear"},
            },
            tfhers.uint16_2_2,
            id="x + clear(y) big values",
        ),
        # subtraction; the range of x starts where the range of y ends
        pytest.param(
            lambda x, y: x - y,
            {
                "x": {"range": [2**10, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**10], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x - y",
        ),
        pytest.param(
            lambda x, y: x - y,
            {
                "x": {"range": [2**10, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**10], "status": "clear"},
            },
            tfhers.uint16_2_2,
            id="x - clear(y)",
        ),
        # multiplication, operands in [0, 2**3]
        pytest.param(
            lambda x, y: x * y,
            {
                "x": {"range": [0, 2**3], "status": "encrypted"},
                "y": {"range": [0, 2**3], "status": "encrypted"},
            },
            tfhers.uint8_2_2,
            id="x * y",
        ),
        pytest.param(
            lambda x, y: x * y,
            {
                "x": {"range": [0, 2**3], "status": "encrypted"},
                "y": {"range": [0, 2**3], "status": "clear"},
            },
            tfhers.uint8_2_2,
            id="x * clear(y)",
        ),
    ],
)
def test_tfhers_conversion_one_encrypted_one_native(
    function, parameters, dtype: tfhers.TFHERSIntegerType, helpers
):
    """
    Check binary operations where only the first operand goes through tfhers conversion.

    In the inputset only the first value of each entry is wrapped as a
    `tfhers.TFHERSInteger` of `dtype`; the second is used as generated. Likewise
    only the first sample value is encoded with `dtype.encode`. The decoded
    output must equal `function` applied to the raw sample.
    """
    statuses = helpers.generate_encryption_statuses(parameters)
    config = helpers.configuration()

    compiler = fhe.Compiler(
        lambda x, y: one_tfhers_one_native(x, y, function, dtype),
        statuses,
    )

    # Wrap only the first value of each entry; the second one is left untouched.
    inputset = []
    for entry in helpers.generate_inputset(parameters):
        inputset.append((tfhers.TFHERSInteger(dtype, entry[0]), entry[1]))
    circuit = compiler.compile(inputset, config)

    sample = helpers.generate_sample(parameters)
    encoded_first = dtype.encode(sample[0])
    encoded_result = circuit.encrypt_run_decrypt(encoded_first, sample[1])

    # Compare in the decoded domain against the plain evaluation of `function`.
    decoded = dtype.decode(encoded_result)
    expected = function(*sample)
    assert (decoded == expected).all()