"""
Tests execution of tfhers conversion operations.
"""

import numpy as np
import pytest

from concrete import fhe
from concrete.fhe import tfhers


def binary_tfhers(x, y, binary_op, tfhers_type):
    """
    Wrap a binary op in tfhers conversion (2 tfhers inputs).

    Both inputs are converted with `tfhers.to_native`, `binary_op` is applied to the
    native values, and the result is converted back with `tfhers.from_native` using
    `tfhers_type`.
    """
    x = tfhers.to_native(x)
    y = tfhers.to_native(y)
    return tfhers.from_native(binary_op(x, y), tfhers_type)


def one_tfhers_one_native(x, y, binary_op, tfhers_type):
    """
    Wrap a binary op in tfhers conversion (1 tfhers, 1 native input).

    Only `x` is converted with `tfhers.to_native`; `y` is used as-is. The result of
    `binary_op` is converted back with `tfhers.from_native` using `tfhers_type`.
    """
    x = tfhers.to_native(x)
    return tfhers.from_native(binary_op(x, y), tfhers_type)


@pytest.mark.parametrize(
    "function, parameters, dtype",
    [
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [0, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**14], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x + y",
        ),
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
                "y": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x + y big values",
        ),
        pytest.param(
            lambda x, y: x - y,
            {
                "x": {"range": [2**10, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**10], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x - y",
        ),
        pytest.param(
            lambda x, y: x * y,
            {
                "x": {"range": [0, 2**3], "status": "encrypted"},
                "y": {"range": [0, 2**3], "status": "encrypted"},
            },
            tfhers.uint8_2_2,
            id="x * y",
        ),
    ],
)
def test_tfhers_conversion_binary_encrypted(
    function, parameters, dtype: tfhers.TFHERSIntegerType, helpers
):
    """
    Test different operations wrapped by tfhers conversion (2 tfhers inputs).
    """

    parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
    configuration = helpers.configuration()

    compiler = fhe.Compiler(
        lambda x, y: binary_tfhers(x, y, function, dtype),
        parameter_encryption_statuses,
    )

    # every input of the inputset is wrapped as a TFHERSInteger of the tested dtype
    inputset = [
        tuple(tfhers.TFHERSInteger(dtype, arg) for arg in inpt)
        for inpt in helpers.generate_inputset(parameters)
    ]
    circuit = compiler.compile(inputset, configuration)

    sample = helpers.generate_sample(parameters)
    # inputs are fed in tfhers-encoded form, and the output is decoded back before
    # being compared with the plain computation
    encoded_sample = (dtype.encode(v) for v in sample)
    encoded_result = circuit.encrypt_run_decrypt(*encoded_sample)

    assert (dtype.decode(encoded_result) == function(*sample)).all()


@pytest.mark.parametrize(
    "function, parameters, dtype",
    [
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [0, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**14], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x + y",
        ),
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [0, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**14], "status": "clear"},
            },
            tfhers.uint16_2_2,
            id="x + clear(y)",
        ),
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
                "y": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x + y big values",
        ),
        pytest.param(
            lambda x, y: x + y,
            {
                "x": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
                "y": {"range": [2**14, 2**15 - 1], "status": "clear"},
            },
            tfhers.uint16_2_2,
            id="x + clear(y) big values",
        ),
        pytest.param(
            lambda x, y: x - y,
            {
                "x": {"range": [2**10, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**10], "status": "encrypted"},
            },
            tfhers.uint16_2_2,
            id="x - y",
        ),
        pytest.param(
            lambda x, y: x - y,
            {
                "x": {"range": [2**10, 2**14], "status": "encrypted"},
                "y": {"range": [0, 2**10], "status": "clear"},
            },
            tfhers.uint16_2_2,
            id="x - clear(y)",
        ),
        pytest.param(
            lambda x, y: x * y,
            {
                "x": {"range": [0, 2**3], "status": "encrypted"},
                "y": {"range": [0, 2**3], "status": "encrypted"},
            },
            tfhers.uint8_2_2,
            id="x * y",
        ),
        pytest.param(
            lambda x, y: x * y,
            {
                "x": {"range": [0, 2**3], "status": "encrypted"},
                "y": {"range": [0, 2**3], "status": "clear"},
            },
            tfhers.uint8_2_2,
            id="x * clear(y)",
        ),
    ],
)
def test_tfhers_conversion_one_encrypted_one_native(
    function, parameters, dtype: tfhers.TFHERSIntegerType, helpers
):
    """
    Test different operations wrapped by tfhers conversion (1 tfhers, 1 native input).
    """

    parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
    configuration = helpers.configuration()

    compiler = fhe.Compiler(
        lambda x, y: one_tfhers_one_native(x, y, function, dtype),
        parameter_encryption_statuses,
    )

    # only the first input is a TFHERSInteger; the second stays a native value
    inputset = [
        (tfhers.TFHERSInteger(dtype, inpt[0]), inpt[1])
        for inpt in helpers.generate_inputset(parameters)
    ]
    circuit = compiler.compile(inputset, configuration)

    sample = helpers.generate_sample(parameters)
    encoded_sample = (dtype.encode(sample[0]), sample[1])
    encoded_result = circuit.encrypt_run_decrypt(*encoded_sample)

    assert (dtype.decode(encoded_result) == function(*sample)).all()


@pytest.mark.parametrize(
    "lhs,rhs,is_equal",
    [
        pytest.param(
            tfhers.uint16_2_2,
            tfhers.uint16_2_2,
            True,
        ),
        pytest.param(
            tfhers.uint16_2_2,
            tfhers.uint8_2_2,
            False,
        ),
    ],
)
def test_tfhers_integer_eq(lhs, rhs, is_equal):
    """
    Test TFHERSIntegerType equality.
    """
    assert is_equal == (lhs == rhs)


@pytest.mark.parametrize(
    "dtype,value,encoded",
    [
        pytest.param(
            tfhers.uint16_2_2,
            [10, 20, 30],
            [
                [0, 0, 0, 0, 0, 0, 2, 2],
                [0, 0, 0, 0, 0, 1, 1, 0],
                [0, 0, 0, 0, 0, 1, 3, 2],
            ],
        ),
    ],
)
def test_tfhers_integer_encode(dtype, value, encoded):
    """
    Test TFHERSIntegerType encode.
    """

    assert np.array_equal(dtype.encode(value), encoded)


@pytest.mark.parametrize(
    "dtype,value,expected_error,expected_message",
    [
        pytest.param(
            tfhers.uint16_2_2,
            "foo",
            TypeError,
            "can only encode int, np.integer, list or ndarray, but got <class 'str'>",
        ),
    ],
)
def test_tfhers_integer_bad_encode(dtype, value, expected_error, expected_message):
    """
    Test TFHERSIntegerType encode with bad arguments.
    """

    with pytest.raises(expected_error) as excinfo:
        dtype.encode(value)

    assert str(excinfo.value) == expected_message


@pytest.mark.parametrize(
    "dtype,encoded,decoded",
    [
        pytest.param(
            tfhers.uint16_2_2,
            [
                [0, 0, 0, 0, 0, 0, 2, 2],
                [0, 0, 0, 0, 0, 1, 1, 0],
                [0, 0, 0, 0, 0, 1, 3, 2],
            ],
            [10, 20, 30],
        ),
    ],
)
def test_tfhers_integer_decode(dtype, encoded, decoded):
    """
    Test TFHERSIntegerType decode.
    """

    assert np.array_equal(dtype.decode(encoded), decoded)


@pytest.mark.parametrize(
    "dtype,value,expected_error,expected_message",
    [
        pytest.param(
            tfhers.uint16_2_2,
            "foo",
            TypeError,
            "can only decode list of integers or ndarray of integers, but got <class 'str'>",
        ),
        pytest.param(
            tfhers.uint16_2_2,
            [1, 2, 3],
            ValueError,
            "expected the last dimension of encoded value to be 8 but it's 3",
        ),
    ],
)
def test_tfhers_integer_bad_decode(dtype, value, expected_error, expected_message):
    """
    Test TFHERSIntegerType decode with bad arguments.
    """

    with pytest.raises(expected_error) as excinfo:
        dtype.decode(value)

    assert str(excinfo.value) == expected_message


@pytest.mark.parametrize(
    "dtype,value,encoded",
    [
        pytest.param(
            tfhers.uint16_2_2,
            10,
            10,
        ),
        pytest.param(
            tfhers.uint16_2_2,
            20,
            20,
        ),
    ],
)
def test_tfhers_from_native_outside_tracing(dtype, value, encoded):
    """
    Test tfhers.from_native outside tracing.

    The parameters expect the value to come back unchanged when not tracing.
    """

    assert np.array_equal(tfhers.from_native(value, dtype), encoded)


@pytest.mark.parametrize(
    "value,decoded",
    [
        pytest.param(
            tfhers.TFHERSInteger(tfhers.uint8_2_2, [0, 0, 2, 2]),
            [0, 0, 2, 2],
        ),
    ],
)
def test_tfhers_to_native_outside_tracing(value, decoded):
    """
    Test tfhers.to_native outside tracing.
    """

    assert np.array_equal(tfhers.to_native(value), decoded)


@pytest.mark.parametrize(
    "value,expected_error,expected_message",
    [
        pytest.param(
            "foo",
            ValueError,
            "tfhers.to_native should be called with a TFHERSInteger",
        ),
    ],
)
def test_tfhers_bad_to_native(value, expected_error, expected_message):
    """
    Test tfhers.to_native with bad arguments.
    """

    with pytest.raises(expected_error) as excinfo:
        tfhers.to_native(value)

    assert str(excinfo.value) == expected_message


@pytest.mark.parametrize(
    "dtype,value,expected_error,expected_message",
    [
        pytest.param(
            tfhers.uint16_2_2,
            [[1], [2, 3, 4], [5, 6]],
            ValueError,
            "got error while trying to convert list value into a numpy array: "
            "setting an array element with a sequence. The requested array has "
            "an inhomogeneous shape after 1 dimensions. The detected shape was "
            "(3,) + inhomogeneous part.",
        ),
        pytest.param(
            tfhers.uint16_2_2,
            [100_000],
            ValueError,
            "ndarray value has bigger elements than what the dtype can support",
        ),
        pytest.param(
            tfhers.uint16_2_2,
            [-100_000],
            ValueError,
            "ndarray value has smaller elements than what the dtype can support",
        ),
        pytest.param(
            tfhers.uint16_2_2,
            "foo",
            TypeError,
            "value can either be an int or ndarray, not a <class 'str'>",
        ),
    ],
)
def test_tfhers_integer_bad_init(dtype, value, expected_error, expected_message):
    """
    Test __init__ of TFHERSInteger with bad arguments.
    """

    with pytest.raises(expected_error) as excinfo:
        tfhers.TFHERSInteger(dtype, value)

    assert str(excinfo.value) == expected_message