mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-19 09:48:30 -05:00
418 lines
11 KiB
Python
418 lines
11 KiB
Python
"""
|
|
Tests execution of tfhers conversion operations.
|
|
"""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from concrete import fhe
|
|
from concrete.fhe import tfhers
|
|
|
|
|
|
def binary_tfhers(x, y, binary_op, tfhers_type):
    """Apply ``binary_op`` to two TFHE-rs inputs and convert the result back.

    Both operands go through ``tfhers.to_native`` (``x`` first, then ``y``);
    the operation's result is converted with ``tfhers.from_native`` using
    ``tfhers_type``.
    """
    native_lhs, native_rhs = tfhers.to_native(x), tfhers.to_native(y)
    native_result = binary_op(native_lhs, native_rhs)
    return tfhers.from_native(native_result, tfhers_type)
|
|
|
|
|
|
def one_tfhers_one_native(x, y, binary_op, tfhers_type):
    """Apply ``binary_op`` to one TFHE-rs input and one native input.

    Only ``x`` is converted with ``tfhers.to_native``; ``y`` is handed to
    ``binary_op`` unchanged. The result is converted with
    ``tfhers.from_native`` using ``tfhers_type``.
    """
    native_x = tfhers.to_native(x)
    native_result = binary_op(native_x, y)
    return tfhers.from_native(native_result, tfhers_type)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "function, parameters, dtype",
    [
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [0, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**14], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x + y",
        ),
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
                "y": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x + y big values",
        ),
        pytest.param(
            lambda x, y: x - y,
            {
                "x": {"range": [2**10, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**10], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x - y",
        ),
        pytest.param(
            lambda x, y: x * y,
            {
                "x": {"range": [0, 2**3], "status": "encrypted"},
                "y": {"range": [0, 2**3], "status": "encrypted"},
            },
            tfhers.uint8_2_2,
            id="x * y",
        ),
    ],
)
def test_tfhers_conversion_binary_encrypted(
    function, parameters, dtype: tfhers.TFHERSIntegerType, helpers
):
    """
    Compile and run an operation whose two inputs are both TFHE-rs integers,
    then compare the decoded output with ``function`` applied to the plain sample.
    """

    statuses = helpers.generate_encryption_statuses(parameters)
    config = helpers.configuration()

    compiler = fhe.Compiler(
        lambda x, y: binary_tfhers(x, y, function, dtype),
        statuses,
    )

    # every argument of every inputset entry is wrapped as a TFHERSInteger
    inputset = [
        tuple(tfhers.TFHERSInteger(dtype, value) for value in raw_args)
        for raw_args in helpers.generate_inputset(parameters)
    ]
    circuit = compiler.compile(inputset, config)

    sample = helpers.generate_sample(parameters)
    encoded_args = [dtype.encode(value) for value in sample]
    result = circuit.encrypt_run_decrypt(*encoded_args)

    decoded = dtype.decode(result)
    assert (decoded == function(*sample)).all()
|
|
|
|
|
|
@pytest.mark.parametrize(
    "function, parameters, dtype",
    [
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [0, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**14], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x + y",
        ),
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [0, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**14], "status": "clear"},
            },
            tfhers.uint16_2_2,
            id="x + clear(y)",
        ),
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
                "y": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x + y big values",
        ),
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
                "y": {"range": [2**14, 2**15 - 1], "status": "clear"},
            },
            tfhers.uint16_2_2,
            id="x + clear(y) big values",
        ),
        pytest.param(
            lambda x, y: x - y,
            {
                "x": {"range": [2**10, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**10], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x - y",
        ),
        pytest.param(
            lambda x, y: x - y,
            {
                "x": {"range": [2**10, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**10], "status": "clear"},
            },
            tfhers.uint16_2_2,
            id="x - clear(y)",
        ),
        pytest.param(
            lambda x, y: x * y,
            {
                "x": {"range": [0, 2**3], "status": "encrypted"},
                "y": {"range": [0, 2**3], "status": "encrypted"},
            },
            tfhers.uint8_2_2,
            id="x * y",
        ),
        pytest.param(
            lambda x, y: x * y,
            {
                "x": {"range": [0, 2**3], "status": "encrypted"},
                "y": {"range": [0, 2**3], "status": "clear"},
            },
            tfhers.uint8_2_2,
            id="x * clear(y)",
        ),
    ],
)
def test_tfhers_conversion_one_encrypted_one_native(
    function, parameters, dtype: tfhers.TFHERSIntegerType, helpers
):
    """
    Compile and run an operation whose first input is a TFHE-rs integer and whose
    second input is a native value (encrypted or clear, as set in ``parameters``),
    then compare the decoded output with ``function`` applied to the plain sample.
    """

    statuses = helpers.generate_encryption_statuses(parameters)
    config = helpers.configuration()

    compiler = fhe.Compiler(
        lambda x, y: one_tfhers_one_native(x, y, function, dtype),
        statuses,
    )

    # only the first argument is wrapped as a TFHERSInteger; the second stays native
    inputset = [
        (tfhers.TFHERSInteger(dtype, raw_args[0]), raw_args[1])
        for raw_args in helpers.generate_inputset(parameters)
    ]
    circuit = compiler.compile(inputset, config)

    sample = helpers.generate_sample(parameters)
    result = circuit.encrypt_run_decrypt(dtype.encode(sample[0]), sample[1])

    decoded = dtype.decode(result)
    assert (decoded == function(*sample)).all()
|
|
|
|
|
|
@pytest.mark.parametrize(
    "lhs,rhs,is_equal",
    [
        pytest.param(tfhers.uint16_2_2, tfhers.uint16_2_2, True),
        pytest.param(tfhers.uint16_2_2, tfhers.uint8_2_2, False),
    ],
)
def test_tfhers_integer_eq(lhs, rhs, is_equal):
    """
    Check that TFHERSIntegerType values compare equal exactly when expected.
    """
    outcome = lhs == rhs
    assert is_equal == outcome
|
|
|
|
|
|
@pytest.mark.parametrize(
    "dtype,value,encoded",
    [
        pytest.param(
            tfhers.uint16_2_2,
            [10, 20, 30],
            # each value becomes 8 base-4 digits, most significant first
            # (e.g. 30 = 1*16 + 3*4 + 2 -> [..., 1, 3, 2])
            [
                [0, 0, 0, 0, 0, 0, 2, 2],
                [0, 0, 0, 0, 0, 1, 1, 0],
                [0, 0, 0, 0, 0, 1, 3, 2],
            ],
        ),
    ],
)
def test_tfhers_integer_encode(dtype, value, encoded):
    """
    Check the digit layout produced by TFHERSIntegerType.encode.
    """
    actual = dtype.encode(value)
    assert np.array_equal(actual, encoded)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "dtype,value,expected_error,expected_message",
    [
        pytest.param(
            tfhers.uint16_2_2,
            "foo",
            TypeError,
            "can only encode int, np.integer, list or ndarray, but got <class 'str'>",
        ),
    ],
)
def test_tfhers_integer_bad_encode(dtype, value, expected_error, expected_message):
    """
    Test TFHERSIntegerType encode with bad arguments.

    Both the raised exception type and its exact message are checked.
    """

    with pytest.raises(expected_error) as excinfo:
        dtype.encode(value)

    assert str(excinfo.value) == expected_message
|
|
|
|
|
|
@pytest.mark.parametrize(
    "dtype,encoded,decoded",
    [
        pytest.param(
            tfhers.uint16_2_2,
            # rows of 8 base-4 digits, most significant first
            [
                [0, 0, 0, 0, 0, 0, 2, 2],
                [0, 0, 0, 0, 0, 1, 1, 0],
                [0, 0, 0, 0, 0, 1, 3, 2],
            ],
            [10, 20, 30],
        ),
    ],
)
def test_tfhers_integer_decode(dtype, encoded, decoded):
    """
    Check that TFHERSIntegerType.decode recombines digits into the original values.
    """
    actual = dtype.decode(encoded)
    assert np.array_equal(actual, decoded)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "dtype,value,expected_error,expected_message",
    [
        pytest.param(
            tfhers.uint16_2_2,
            "foo",
            TypeError,
            "can only decode list of integers or ndarray of integers, but got <class 'str'>",
        ),
        pytest.param(
            tfhers.uint16_2_2,
            [1, 2, 3],
            ValueError,
            "expected the last dimension of encoded value to be 8 but it's 3",
        ),
    ],
)
def test_tfhers_integer_bad_decode(dtype, value, expected_error, expected_message):
    """
    Test TFHERSIntegerType decode with bad arguments.

    Covers a value of the wrong type and a value whose last dimension does not
    match the expected digit count. Both the exception type and its exact
    message are checked.
    """

    with pytest.raises(expected_error) as excinfo:
        dtype.decode(value)

    assert str(excinfo.value) == expected_message
|
|
|
|
|
|
@pytest.mark.parametrize(
    "dtype,value,encoded",
    [
        pytest.param(tfhers.uint16_2_2, 10, 10),
        pytest.param(tfhers.uint16_2_2, 20, 20),
    ],
)
def test_tfhers_from_native_outside_tracing(dtype, value, encoded):
    """
    Check tfhers.from_native when it is called outside of tracing.

    The parametrized cases expect the plain value back unchanged.
    """
    converted = tfhers.from_native(value, dtype)
    assert np.array_equal(converted, encoded)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "value,decoded",
    [
        pytest.param(
            tfhers.TFHERSInteger(tfhers.uint8_2_2, [0, 0, 2, 2]),
            [0, 0, 2, 2],
        ),
    ],
)
def test_tfhers_to_native_outside_tracing(value, decoded):
    """
    Test tfhers.to_native outside tracing.

    The parametrized case expects the value the TFHERSInteger was built from
    to be returned unchanged.
    """

    # leftover debug print removed: it polluted test output and ran the
    # conversion a second time for no reason
    assert np.array_equal(tfhers.to_native(value), decoded)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "value,expected_error,expected_message",
    [
        pytest.param(
            "foo",
            ValueError,
            "tfhers.to_native should be called with a TFHERSInteger",
        ),
    ],
)
def test_tfhers_bad_to_native(value, expected_error, expected_message):
    """
    Check that tfhers.to_native rejects arguments that are not TFHERSInteger.
    """
    with pytest.raises(expected_error) as raised:
        tfhers.to_native(value)

    message = str(raised.value)
    assert message == expected_message
|
|
|
|
|
|
@pytest.mark.parametrize(
    "dtype,value,expected_error,expected_message",
    [
        pytest.param(
            tfhers.uint16_2_2,
            [[1], [2, 3, 4], [5, 6]],
            ValueError,
            # NOTE(review): the text after the colon presumably comes from NumPy's
            # own error and may change between NumPy versions — confirm if this breaks
            "got error while trying to convert list value into a numpy array: "
            "setting an array element with a sequence. The requested array has "
            "an inhomogeneous shape after 1 dimensions. The detected shape was "
            "(3,) + inhomogeneous part.",
        ),
        pytest.param(
            tfhers.uint16_2_2,
            [100_000],
            ValueError,
            "ndarray value has bigger elements than what the dtype can support",
        ),
        pytest.param(
            tfhers.uint16_2_2,
            [-100_000],
            ValueError,
            "ndarray value has smaller elements than what the dtype can support",
        ),
        pytest.param(
            tfhers.uint16_2_2,
            "foo",
            TypeError,
            "value can either be an int or ndarray, not a <class 'str'>",
        ),
    ],
)
def test_tfhers_integer_bad_init(dtype, value, expected_error, expected_message):
    """
    Check that constructing a TFHERSInteger from invalid data raises the
    expected exception with the expected message.

    The cases cover a ragged list, out-of-range values on both sides, and a
    value of the wrong type.
    """
    with pytest.raises(expected_error) as raised:
        tfhers.TFHERSInteger(dtype, value)

    message = str(raised.value)
    assert message == expected_message
|