diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index 5c48dd372..8711d58f3 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -1018,23 +1018,23 @@ class Converter: num_cts = output_tfhers_bit_width // msg_width # number of ciphertexts representing the actual part of the message num_full_cts_msg = native_int.bit_width // msg_width - num_half_cts_msg = 0 if native_int.bit_width % msg_width == 0 else 1 + remaining_bits = native_int.bit_width % msg_width + num_half_cts_msg = 0 if remaining_bits == 0 else 1 num_cts_msg = num_full_cts_msg + num_half_cts_msg # adds a dimension of ciphertexts for the result result_shape_msg = native_int.shape + (num_cts_msg,) - result_type_msg = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape_msg) # we reshape so that we can concatenate later over the last dim (ciphertext dim) reshaped_native_int = ctx.reshape(native_int, native_int.shape + (1,)) # we want to extract `msg_width` bits at a time, and store them # in a `msg_width + carry_width` bits eint - bits_shape = ctx.tensor(ctx.eint(msg_width + carry_width), reshaped_native_int.shape) + bits_type = ctx.tensor(ctx.eint(msg_width + carry_width), reshaped_native_int.shape) # we extract lsb first extracted_bits = [ ctx.extract_bits( - bits_shape, + bits_type, reshaped_native_int, bits=slice(i * msg_width, (i + 1) * msg_width, 1), ) @@ -1042,32 +1042,97 @@ class Converter: ] # we need to extract the remaining bits if any if num_half_cts_msg: + # in the unsigned case, we won't do anything to this half full ciphertext: so we target + # the final bitwidth directly + target_bitwidth = remaining_bits if dtype.is_signed else (msg_width + carry_width) + remaining_bits_type = ctx.tensor( + ctx.eint(target_bitwidth), + reshaped_native_int.shape, + ) extracted_bits.append( ctx.extract_bits( - bits_shape, + remaining_bits_type, reshaped_native_int, bits=slice( num_full_cts_msg * msg_width, - num_full_cts_msg * msg_width + native_int.bit_width % msg_width, + num_full_cts_msg * msg_width + remaining_bits, 1, ), ) ) - result = ctx.concatenate(result_type_msg, extracted_bits, axis=-1) + # if we expect more ciphertexts than we have in the message: we pad the result + # if we have one non-full ct (when signed): we extend the sign bit + if num_cts_msg != num_cts or (num_half_cts_msg and dtype.is_signed): + # if it's signed then we have to move the sign bit to the msb + if dtype.is_signed: + # special case where the sign_bit is already isolated into a single ciphertext + if remaining_bits == 1: + sign_bits = extracted_bits.pop() + else: # remaining_bits either > 1 or == 0 + msbs = extracted_bits.pop() + sign_bit_idx = (remaining_bits - 1) % msg_width + sign_bits = ctx.extract_bits( + ctx.tensor(ctx.eint(1), reshaped_native_int.shape), + msbs, + bits=sign_bit_idx, + ) + if remaining_bits > 1: + # extend sign bit in msbs + # For example: we have this message |0000|0110| where sign_bit_idx is 2 + # we want the result to be |0000|1110| extending the sign bit (1) to the + # remaining MSBs + sign_extension = 2**msg_width - 1 - (2 ** (sign_bit_idx + 1) - 1) + extend_sign_bit_table = [ + i + sign_extension if i & 2**sign_bit_idx else i + for i in range(2**remaining_bits) + ] + msbs_with_extended_sign = ctx.tlu(bits_type, msbs, extend_sign_bit_table) + extracted_bits.append(msbs_with_extended_sign) + else: # remaining_bits == 0 + # no need to extend the sign bit + extracted_bits.append(msbs) - # if we expect more ciphertexts than we have in the message: we pad the result with zeros - if num_cts_msg != num_cts: - result_shape = native_int.shape + (num_cts,) - result_type = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape) - padding = ctx.zeros( - ctx.tensor( - ctx.eint(msg_width + carry_width), - result_shape_msg[:-1] + (num_cts - num_cts_msg,), + # padding will contain the sign + padding_length = num_cts - len(extracted_bits) + if padding_length: + extend_sign_bit_table = [0, (2**msg_width) - 1] + extended_sign_bits = ctx.tlu(bits_type, sign_bits, extend_sign_bit_table) + padding_type = ctx.tensor( + ctx.eint(msg_width + carry_width), + result_shape_msg[:-1] + (padding_length,), + ) + padding = ctx.zeros(padding_type) + padding = ctx.add(padding_type, extended_sign_bits, padding) + to_concat = extracted_bits + [ + padding, + ] + else: + # no need to pad: we only extended the sign bit + to_concat = extracted_bits + else: + padding_length = num_cts - len(extracted_bits) + assert padding_length > 0 + padding = ctx.zeros( + ctx.tensor( + ctx.eint(msg_width + carry_width), + result_shape_msg[:-1] + (padding_length,), + ) ) - ) - result = ctx.concatenate(result_type, [result, padding], axis=-1) + to_concat = extracted_bits + [ + padding, + ] + else: + # no need to pad: we have the right number of ciphertexts + to_concat = extracted_bits + result_shape = native_int.shape + (num_cts,) + result_type = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape) + result = ctx.concatenate( + result_type, + to_concat, + axis=-1, + ) return ctx.change_partition(result, dest_partition=dtype.params) # pylint: enable=missing-function-docstring,unused-argument diff --git a/frontends/concrete-python/tests/execution/test_tfhers.py b/frontends/concrete-python/tests/execution/test_tfhers.py index e08f3c2d9..c99eed7c0 100644 --- a/frontends/concrete-python/tests/execution/test_tfhers.py +++ b/frontends/concrete-python/tests/execution/test_tfhers.py @@ -5,6 +5,7 @@ Tests execution of tfhers conversion operations. import json import os import tempfile +from functools import partial from typing import List, Union import numpy as np @@ -68,8 +69,11 @@ def is_input_and_output_tfhers( return True +tfhers_int6_2_3 = partial(tfhers.TFHERSIntegerType, True, 6, 2, 3) + + @pytest.mark.parametrize( - "function, parameters, dtype", + "function, parameters, input_dtype, output_dtype", [ pytest.param( lambda x, y: x + y, @@ -78,6 +82,7 @@ def is_input_and_output_tfhers( "y": {"range": [0, 2**14], "status": "encrypted"}, }, tfhers.uint16_2_2, + tfhers.uint16_2_2, id="x + y", ), pytest.param( @@ -87,6 +92,7 @@ def is_input_and_output_tfhers( "y": {"range": [2**14, 2**15 - 1], "status": "encrypted"}, }, tfhers.uint16_2_2, + tfhers.uint16_2_2, id="x + y big values", ), pytest.param( @@ -96,6 +102,7 @@ def is_input_and_output_tfhers( "y": {"range": [0, 2**10], "status": "encrypted"}, }, tfhers.uint16_2_2, + tfhers.uint16_2_2, id="x - y", ), pytest.param( @@ -105,12 +112,78 @@ def is_input_and_output_tfhers( "y": {"range": [0, 2**3], "status": "encrypted"}, }, tfhers.uint8_2_2, + tfhers.uint8_2_2, id="x * y", ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [0, 2**6], "status": "encrypted"}, + "y": {"range": [0, 2**6], "status": "encrypted"}, + }, + tfhers.int8_2_2, + tfhers.int16_2_2, + id="signed x + y diff in/out", + ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [-(2**6), -1], "status": "encrypted"}, + "y": {"range": [-(2**6), -1], "status": "encrypted"}, + }, + tfhers.int8_2_2, + tfhers.int16_2_2, + id="negative x + y diff in/out", + ), + pytest.param( + lambda x, y: x * y, + { + "x": {"range": [0, 2**3], "status": "encrypted"}, + "y": {"range": [0, 2**3], "status": "encrypted"}, + }, + tfhers.uint8_2_2, + tfhers.uint16_2_2, + id="x * y diff in/out", + ), + pytest.param( + lambda x, y: x * y, + { + "x": {"range": [-(2**3), 0], "status": "encrypted"}, + "y": {"range": [0, 2**3], "status": "encrypted"}, + }, + tfhers.int8_2_2, + tfhers.int16_2_2, + id="negative x * y diff in/out", + ), + pytest.param( + lambda x, y: x * y, + { + "x": {"range": [0, 2**3], "status": "encrypted"}, + "y": {"range": [0, 2**3], "status": "encrypted"}, + }, + tfhers.int8_2_2, + tfhers.int16_2_2, + id="signed x * y diff in/out", + ), + pytest.param( + lambda x, y: x * y, + { + "x": {"range": [-4, -4], "status": "encrypted"}, + "y": {"range": [3, 3], "status": "encrypted"}, + }, + tfhers_int6_2_3, + tfhers_int6_2_3, + # tfhers.int16_2_2, + id="sign extension without padding", + ), ], ) def test_tfhers_conversion_binary_encrypted( - function, parameters, dtype: tfhers.TFHERSIntegerType, helpers + function, + parameters, + input_dtype: tfhers.TFHERSIntegerType, + output_dtype: tfhers.TFHERSIntegerType, + helpers, ): """ Test different operations wrapped by tfhers conversion (2 tfhers inputs). @@ -122,22 +195,23 @@ def test_tfhers_conversion_binary_encrypted( if helpers.configuration().parameter_selection_strategy != fhe.ParameterSelectionStrategy.MULTI: return - dtype = parameterize_partial_dtype(dtype) + input_dtype = parameterize_partial_dtype(input_dtype) + output_dtype = parameterize_partial_dtype(output_dtype) compiler = fhe.Compiler( - lambda x, y: binary_tfhers(x, y, function, dtype), + lambda x, y: binary_tfhers(x, y, function, output_dtype), parameter_encryption_statuses, ) inputset = [ - tuple(tfhers.TFHERSInteger(dtype, arg) for arg in inpt) + tuple(tfhers.TFHERSInteger(input_dtype, arg) for arg in inpt) for inpt in helpers.generate_inputset(parameters) ] circuit = compiler.compile(inputset, helpers.configuration()) assert is_input_and_output_tfhers( circuit, - dtype.params.polynomial_size, + input_dtype.params.polynomial_size, [0, 1], [ 0, @@ -145,10 +219,10 @@ def test_tfhers_conversion_binary_encrypted( ) sample = helpers.generate_sample(parameters) - encoded_sample = (dtype.encode(v) for v in sample) + encoded_sample = (input_dtype.encode(v) for v in sample) encoded_result = circuit.encrypt_run_decrypt(*encoded_sample) - assert (dtype.decode(encoded_result) == function(*sample)).all() + assert (output_dtype.decode(encoded_result) == function(*sample)).all() @pytest.mark.parametrize(