feat(frontend): support signed integers with arbitrary bitwidth output

This commit is contained in:
youben11
2025-02-20 11:30:12 +01:00
parent f776e32d24
commit ff62daaf57
2 changed files with 164 additions and 25 deletions

View File

@@ -1018,23 +1018,23 @@ class Converter:
num_cts = output_tfhers_bit_width // msg_width
# number of ciphertexts representing the actual part of the message
num_full_cts_msg = native_int.bit_width // msg_width
num_half_cts_msg = 0 if native_int.bit_width % msg_width == 0 else 1
remaining_bits = native_int.bit_width % msg_width
num_half_cts_msg = 0 if remaining_bits == 0 else 1
num_cts_msg = num_full_cts_msg + num_half_cts_msg
# adds a dimension of ciphertexts for the result
result_shape_msg = native_int.shape + (num_cts_msg,)
result_type_msg = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape_msg)
# we reshape so that we can concatenate later over the last dim (ciphertext dim)
reshaped_native_int = ctx.reshape(native_int, native_int.shape + (1,))
# we want to extract `msg_width` bits at a time, and store them
# in a `msg_width + carry_width` bits eint
bits_shape = ctx.tensor(ctx.eint(msg_width + carry_width), reshaped_native_int.shape)
bits_type = ctx.tensor(ctx.eint(msg_width + carry_width), reshaped_native_int.shape)
# we extract lsb first
extracted_bits = [
ctx.extract_bits(
bits_shape,
bits_type,
reshaped_native_int,
bits=slice(i * msg_width, (i + 1) * msg_width, 1),
)
@@ -1042,32 +1042,97 @@ class Converter:
]
# we need to extract the remaining bits if any
if num_half_cts_msg:
# in the unsigned case, we won't do anything to this half full ciphertext: so we target
# the final bitwidth directly
target_bitwidth = remaining_bits if dtype.is_signed else (msg_width + carry_width)
remaining_bits_type = ctx.tensor(
ctx.eint(target_bitwidth),
reshaped_native_int.shape,
)
extracted_bits.append(
ctx.extract_bits(
bits_shape,
remaining_bits_type,
reshaped_native_int,
bits=slice(
num_full_cts_msg * msg_width,
num_full_cts_msg * msg_width + native_int.bit_width % msg_width,
num_full_cts_msg * msg_width + remaining_bits,
1,
),
)
)
result = ctx.concatenate(result_type_msg, extracted_bits, axis=-1)
# if we expect more ciphertexts than we have in the message: we pad the result
# if we have one non-full ct (when signed): we extend the sign bit
if num_cts_msg != num_cts or (num_half_cts_msg and dtype.is_signed):
# if it's signed then we have to move the sign bit to the msb
if dtype.is_signed:
# special case where the sign_bit is already isolated into a single ciphertext
if remaining_bits == 1:
sign_bits = extracted_bits.pop()
else: # remaining_bits either > 1 or == 0
msbs = extracted_bits.pop()
sign_bit_idx = (remaining_bits - 1) % msg_width
sign_bits = ctx.extract_bits(
ctx.tensor(ctx.eint(1), reshaped_native_int.shape),
msbs,
bits=sign_bit_idx,
)
if remaining_bits > 1:
# extend sign bit in msbs
# For example: we have this message |0000|0110| where sign_bit_idx is 2
# we want the result to be |0000|1110| extending the sign bit (1) to the
# remaining MSBs
sign_extension = 2**msg_width - 1 - (2 ** (sign_bit_idx + 1) - 1)
extend_sign_bit_table = [
i + sign_extension if i & 2**sign_bit_idx else i
for i in range(2**remaining_bits)
]
msbs_with_extended_sign = ctx.tlu(bits_type, msbs, extend_sign_bit_table)
extracted_bits.append(msbs_with_extended_sign)
else: # remaining_bits == 0
# no need to extend the sign bit
extracted_bits.append(msbs)
# if we expect more ciphertexts than we have in the message: we pad the result with zeros
if num_cts_msg != num_cts:
result_shape = native_int.shape + (num_cts,)
result_type = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape)
padding = ctx.zeros(
ctx.tensor(
ctx.eint(msg_width + carry_width),
result_shape_msg[:-1] + (num_cts - num_cts_msg,),
# padding will contain the sign
padding_length = num_cts - len(extracted_bits)
if padding_length:
extend_sign_bit_table = [0, (2**msg_width) - 1]
extended_sign_bits = ctx.tlu(bits_type, sign_bits, extend_sign_bit_table)
padding_type = ctx.tensor(
ctx.eint(msg_width + carry_width),
result_shape_msg[:-1] + (padding_length,),
)
padding = ctx.zeros(padding_type)
padding = ctx.add(padding_type, extended_sign_bits, padding)
to_concat = extracted_bits + [
padding,
]
else:
# no need to pad: we only extended the sign bit
to_concat = extracted_bits
else:
padding_length = num_cts - len(extracted_bits)
assert padding_length > 0
padding = ctx.zeros(
ctx.tensor(
ctx.eint(msg_width + carry_width),
result_shape_msg[:-1] + (padding_length,),
)
)
)
result = ctx.concatenate(result_type, [result, padding], axis=-1)
to_concat = extracted_bits + [
padding,
]
else:
# no need to pad: we have the right number of ciphertexts
to_concat = extracted_bits
result_shape = native_int.shape + (num_cts,)
result_type = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape)
result = ctx.concatenate(
result_type,
to_concat,
axis=-1,
)
return ctx.change_partition(result, dest_partition=dtype.params)
# pylint: enable=missing-function-docstring,unused-argument

View File

@@ -5,6 +5,7 @@ Tests execution of tfhers conversion operations.
import json
import os
import tempfile
from functools import partial
from typing import List, Union
import numpy as np
@@ -68,8 +69,11 @@ def is_input_and_output_tfhers(
return True
tfhers_int6_2_3 = partial(tfhers.TFHERSIntegerType, True, 6, 2, 3)
@pytest.mark.parametrize(
"function, parameters, dtype",
"function, parameters, input_dtype, output_dtype",
[
pytest.param(
lambda x, y: x + y,
@@ -78,6 +82,7 @@ def is_input_and_output_tfhers(
"y": {"range": [0, 2**14], "status": "encrypted"},
},
tfhers.uint16_2_2,
tfhers.uint16_2_2,
id="x + y",
),
pytest.param(
@@ -87,6 +92,7 @@ def is_input_and_output_tfhers(
"y": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
},
tfhers.uint16_2_2,
tfhers.uint16_2_2,
id="x + y big values",
),
pytest.param(
@@ -96,6 +102,7 @@ def is_input_and_output_tfhers(
"y": {"range": [0, 2**10], "status": "encrypted"},
},
tfhers.uint16_2_2,
tfhers.uint16_2_2,
id="x - y",
),
pytest.param(
@@ -105,12 +112,78 @@ def is_input_and_output_tfhers(
"y": {"range": [0, 2**3], "status": "encrypted"},
},
tfhers.uint8_2_2,
tfhers.uint8_2_2,
id="x * y",
),
pytest.param(
lambda x, y: x + y,
{
"x": {"range": [0, 2**6], "status": "encrypted"},
"y": {"range": [0, 2**6], "status": "encrypted"},
},
tfhers.int8_2_2,
tfhers.int16_2_2,
id="signed x + y diff in/out",
),
pytest.param(
lambda x, y: x + y,
{
"x": {"range": [-(2**6), -1], "status": "encrypted"},
"y": {"range": [-(2**6), -1], "status": "encrypted"},
},
tfhers.int8_2_2,
tfhers.int16_2_2,
id="negative x + y diff in/out",
),
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 2**3], "status": "encrypted"},
"y": {"range": [0, 2**3], "status": "encrypted"},
},
tfhers.uint8_2_2,
tfhers.uint16_2_2,
id="x * y diff in/out",
),
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [-(2**3), 0], "status": "encrypted"},
"y": {"range": [0, 2**3], "status": "encrypted"},
},
tfhers.int8_2_2,
tfhers.int16_2_2,
id="negative x * y diff in/out",
),
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 2**3], "status": "encrypted"},
"y": {"range": [0, 2**3], "status": "encrypted"},
},
tfhers.int8_2_2,
tfhers.int16_2_2,
id="signed x * y diff in/out",
),
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [-4, -4], "status": "encrypted"},
"y": {"range": [3, 3], "status": "encrypted"},
},
tfhers_int6_2_3,
tfhers_int6_2_3,
# tfhers.int16_2_2,
id="sign extension without padding",
),
],
)
def test_tfhers_conversion_binary_encrypted(
function, parameters, dtype: tfhers.TFHERSIntegerType, helpers
function,
parameters,
input_dtype: tfhers.TFHERSIntegerType,
output_dtype: tfhers.TFHERSIntegerType,
helpers,
):
"""
Test different operations wrapped by tfhers conversion (2 tfhers inputs).
@@ -122,22 +195,23 @@ def test_tfhers_conversion_binary_encrypted(
if helpers.configuration().parameter_selection_strategy != fhe.ParameterSelectionStrategy.MULTI:
return
dtype = parameterize_partial_dtype(dtype)
input_dtype = parameterize_partial_dtype(input_dtype)
output_dtype = parameterize_partial_dtype(output_dtype)
compiler = fhe.Compiler(
lambda x, y: binary_tfhers(x, y, function, dtype),
lambda x, y: binary_tfhers(x, y, function, output_dtype),
parameter_encryption_statuses,
)
inputset = [
tuple(tfhers.TFHERSInteger(dtype, arg) for arg in inpt)
tuple(tfhers.TFHERSInteger(input_dtype, arg) for arg in inpt)
for inpt in helpers.generate_inputset(parameters)
]
circuit = compiler.compile(inputset, helpers.configuration())
assert is_input_and_output_tfhers(
circuit,
dtype.params.polynomial_size,
input_dtype.params.polynomial_size,
[0, 1],
[
0,
@@ -145,10 +219,10 @@ def test_tfhers_conversion_binary_encrypted(
)
sample = helpers.generate_sample(parameters)
encoded_sample = (dtype.encode(v) for v in sample)
encoded_sample = (input_dtype.encode(v) for v in sample)
encoded_result = circuit.encrypt_run_decrypt(*encoded_sample)
assert (dtype.decode(encoded_result) == function(*sample)).all()
assert (output_dtype.decode(encoded_result) == function(*sample)).all()
@pytest.mark.parametrize(