mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-08 20:38:06 -05:00
feat(frontend): support signed integers with arbitrary bitwidth output
This commit is contained in:
@@ -1018,23 +1018,23 @@ class Converter:
|
||||
num_cts = output_tfhers_bit_width // msg_width
|
||||
# number of ciphertexts representing the actual part of the message
|
||||
num_full_cts_msg = native_int.bit_width // msg_width
|
||||
num_half_cts_msg = 0 if native_int.bit_width % msg_width == 0 else 1
|
||||
remaining_bits = native_int.bit_width % msg_width
|
||||
num_half_cts_msg = 0 if remaining_bits == 0 else 1
|
||||
num_cts_msg = num_full_cts_msg + num_half_cts_msg
|
||||
|
||||
# adds a dimension of ciphertexts for the result
|
||||
result_shape_msg = native_int.shape + (num_cts_msg,)
|
||||
result_type_msg = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape_msg)
|
||||
|
||||
# we reshape so that we can concatenate later over the last dim (ciphertext dim)
|
||||
reshaped_native_int = ctx.reshape(native_int, native_int.shape + (1,))
|
||||
|
||||
# we want to extract `msg_width` bits at a time, and store them
|
||||
# in a `msg_width + carry_width` bits eint
|
||||
bits_shape = ctx.tensor(ctx.eint(msg_width + carry_width), reshaped_native_int.shape)
|
||||
bits_type = ctx.tensor(ctx.eint(msg_width + carry_width), reshaped_native_int.shape)
|
||||
# we extract lsb first
|
||||
extracted_bits = [
|
||||
ctx.extract_bits(
|
||||
bits_shape,
|
||||
bits_type,
|
||||
reshaped_native_int,
|
||||
bits=slice(i * msg_width, (i + 1) * msg_width, 1),
|
||||
)
|
||||
@@ -1042,32 +1042,97 @@ class Converter:
|
||||
]
|
||||
# we need to extract the remaining bits if any
|
||||
if num_half_cts_msg:
|
||||
# in the unsigned case, we won't do anything to this half full ciphertext: so we target
|
||||
# the final bitwidth directly
|
||||
target_bitwidth = remaining_bits if dtype.is_signed else (msg_width + carry_width)
|
||||
remaining_bits_type = ctx.tensor(
|
||||
ctx.eint(target_bitwidth),
|
||||
reshaped_native_int.shape,
|
||||
)
|
||||
extracted_bits.append(
|
||||
ctx.extract_bits(
|
||||
bits_shape,
|
||||
remaining_bits_type,
|
||||
reshaped_native_int,
|
||||
bits=slice(
|
||||
num_full_cts_msg * msg_width,
|
||||
num_full_cts_msg * msg_width + native_int.bit_width % msg_width,
|
||||
num_full_cts_msg * msg_width + remaining_bits,
|
||||
1,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
result = ctx.concatenate(result_type_msg, extracted_bits, axis=-1)
|
||||
# if we expect more ciphertexts than we have in the message: we pad the result
|
||||
# if we have one non-full ct (when signed): we extend the sign bit
|
||||
if num_cts_msg != num_cts or (num_half_cts_msg and dtype.is_signed):
|
||||
# if it's signed then we have to move the sign bit to the msb
|
||||
if dtype.is_signed:
|
||||
# special case where the sign_bit is already isolated into a single ciphertext
|
||||
if remaining_bits == 1:
|
||||
sign_bits = extracted_bits.pop()
|
||||
else: # remaining_bits either > 1 or == 0
|
||||
msbs = extracted_bits.pop()
|
||||
sign_bit_idx = (remaining_bits - 1) % msg_width
|
||||
sign_bits = ctx.extract_bits(
|
||||
ctx.tensor(ctx.eint(1), reshaped_native_int.shape),
|
||||
msbs,
|
||||
bits=sign_bit_idx,
|
||||
)
|
||||
if remaining_bits > 1:
|
||||
# extend sign bit in msbs
|
||||
# For example: we have this message |0000|0110| where sign_bit_idx is 2
|
||||
# we want the result to be |0000|1110| extending the sign bit (1) to the
|
||||
# remaining MSBs
|
||||
sign_extension = 2**msg_width - 1 - (2 ** (sign_bit_idx + 1) - 1)
|
||||
extend_sign_bit_table = [
|
||||
i + sign_extension if i & 2**sign_bit_idx else i
|
||||
for i in range(2**remaining_bits)
|
||||
]
|
||||
msbs_with_extended_sign = ctx.tlu(bits_type, msbs, extend_sign_bit_table)
|
||||
extracted_bits.append(msbs_with_extended_sign)
|
||||
else: # remaining_bits == 0
|
||||
# no need to extend the sign bit
|
||||
extracted_bits.append(msbs)
|
||||
|
||||
# if we expect more ciphertexts than we have in the message: we pad the result with zeros
|
||||
if num_cts_msg != num_cts:
|
||||
result_shape = native_int.shape + (num_cts,)
|
||||
result_type = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape)
|
||||
padding = ctx.zeros(
|
||||
ctx.tensor(
|
||||
ctx.eint(msg_width + carry_width),
|
||||
result_shape_msg[:-1] + (num_cts - num_cts_msg,),
|
||||
# padding will contain the sign
|
||||
padding_length = num_cts - len(extracted_bits)
|
||||
if padding_length:
|
||||
extend_sign_bit_table = [0, (2**msg_width) - 1]
|
||||
extended_sign_bits = ctx.tlu(bits_type, sign_bits, extend_sign_bit_table)
|
||||
padding_type = ctx.tensor(
|
||||
ctx.eint(msg_width + carry_width),
|
||||
result_shape_msg[:-1] + (padding_length,),
|
||||
)
|
||||
padding = ctx.zeros(padding_type)
|
||||
padding = ctx.add(padding_type, extended_sign_bits, padding)
|
||||
to_concat = extracted_bits + [
|
||||
padding,
|
||||
]
|
||||
else:
|
||||
# no need to pad: we only extended the sign bit
|
||||
to_concat = extracted_bits
|
||||
else:
|
||||
padding_length = num_cts - len(extracted_bits)
|
||||
assert padding_length > 0
|
||||
padding = ctx.zeros(
|
||||
ctx.tensor(
|
||||
ctx.eint(msg_width + carry_width),
|
||||
result_shape_msg[:-1] + (padding_length,),
|
||||
)
|
||||
)
|
||||
)
|
||||
result = ctx.concatenate(result_type, [result, padding], axis=-1)
|
||||
to_concat = extracted_bits + [
|
||||
padding,
|
||||
]
|
||||
else:
|
||||
# no need to pad: we have the right number of ciphertexts
|
||||
to_concat = extracted_bits
|
||||
|
||||
result_shape = native_int.shape + (num_cts,)
|
||||
result_type = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape)
|
||||
result = ctx.concatenate(
|
||||
result_type,
|
||||
to_concat,
|
||||
axis=-1,
|
||||
)
|
||||
return ctx.change_partition(result, dest_partition=dtype.params)
|
||||
|
||||
# pylint: enable=missing-function-docstring,unused-argument
|
||||
|
||||
@@ -5,6 +5,7 @@ Tests execution of tfhers conversion operations.
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from functools import partial
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -68,8 +69,11 @@ def is_input_and_output_tfhers(
|
||||
return True
|
||||
|
||||
|
||||
tfhers_int6_2_3 = partial(tfhers.TFHERSIntegerType, True, 6, 2, 3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function, parameters, dtype",
|
||||
"function, parameters, input_dtype, output_dtype",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
@@ -78,6 +82,7 @@ def is_input_and_output_tfhers(
|
||||
"y": {"range": [0, 2**14], "status": "encrypted"},
|
||||
},
|
||||
tfhers.uint16_2_2,
|
||||
tfhers.uint16_2_2,
|
||||
id="x + y",
|
||||
),
|
||||
pytest.param(
|
||||
@@ -87,6 +92,7 @@ def is_input_and_output_tfhers(
|
||||
"y": {"range": [2**14, 2**15 - 1], "status": "encrypted"},
|
||||
},
|
||||
tfhers.uint16_2_2,
|
||||
tfhers.uint16_2_2,
|
||||
id="x + y big values",
|
||||
),
|
||||
pytest.param(
|
||||
@@ -96,6 +102,7 @@ def is_input_and_output_tfhers(
|
||||
"y": {"range": [0, 2**10], "status": "encrypted"},
|
||||
},
|
||||
tfhers.uint16_2_2,
|
||||
tfhers.uint16_2_2,
|
||||
id="x - y",
|
||||
),
|
||||
pytest.param(
|
||||
@@ -105,12 +112,78 @@ def is_input_and_output_tfhers(
|
||||
"y": {"range": [0, 2**3], "status": "encrypted"},
|
||||
},
|
||||
tfhers.uint8_2_2,
|
||||
tfhers.uint8_2_2,
|
||||
id="x * y",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
{
|
||||
"x": {"range": [0, 2**6], "status": "encrypted"},
|
||||
"y": {"range": [0, 2**6], "status": "encrypted"},
|
||||
},
|
||||
tfhers.int8_2_2,
|
||||
tfhers.int16_2_2,
|
||||
id="signed x + y diff in/out",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
{
|
||||
"x": {"range": [-(2**6), -1], "status": "encrypted"},
|
||||
"y": {"range": [-(2**6), -1], "status": "encrypted"},
|
||||
},
|
||||
tfhers.int8_2_2,
|
||||
tfhers.int16_2_2,
|
||||
id="negative x + y diff in/out",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{
|
||||
"x": {"range": [0, 2**3], "status": "encrypted"},
|
||||
"y": {"range": [0, 2**3], "status": "encrypted"},
|
||||
},
|
||||
tfhers.uint8_2_2,
|
||||
tfhers.uint16_2_2,
|
||||
id="x * y diff in/out",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{
|
||||
"x": {"range": [-(2**3), 0], "status": "encrypted"},
|
||||
"y": {"range": [0, 2**3], "status": "encrypted"},
|
||||
},
|
||||
tfhers.int8_2_2,
|
||||
tfhers.int16_2_2,
|
||||
id="negative x * y diff in/out",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{
|
||||
"x": {"range": [0, 2**3], "status": "encrypted"},
|
||||
"y": {"range": [0, 2**3], "status": "encrypted"},
|
||||
},
|
||||
tfhers.int8_2_2,
|
||||
tfhers.int16_2_2,
|
||||
id="signed x * y diff in/out",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{
|
||||
"x": {"range": [-4, -4], "status": "encrypted"},
|
||||
"y": {"range": [3, 3], "status": "encrypted"},
|
||||
},
|
||||
tfhers_int6_2_3,
|
||||
tfhers_int6_2_3,
|
||||
# tfhers.int16_2_2,
|
||||
id="sign extension without padding",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tfhers_conversion_binary_encrypted(
|
||||
function, parameters, dtype: tfhers.TFHERSIntegerType, helpers
|
||||
function,
|
||||
parameters,
|
||||
input_dtype: tfhers.TFHERSIntegerType,
|
||||
output_dtype: tfhers.TFHERSIntegerType,
|
||||
helpers,
|
||||
):
|
||||
"""
|
||||
Test different operations wrapped by tfhers conversion (2 tfhers inputs).
|
||||
@@ -122,22 +195,23 @@ def test_tfhers_conversion_binary_encrypted(
|
||||
if helpers.configuration().parameter_selection_strategy != fhe.ParameterSelectionStrategy.MULTI:
|
||||
return
|
||||
|
||||
dtype = parameterize_partial_dtype(dtype)
|
||||
input_dtype = parameterize_partial_dtype(input_dtype)
|
||||
output_dtype = parameterize_partial_dtype(output_dtype)
|
||||
|
||||
compiler = fhe.Compiler(
|
||||
lambda x, y: binary_tfhers(x, y, function, dtype),
|
||||
lambda x, y: binary_tfhers(x, y, function, output_dtype),
|
||||
parameter_encryption_statuses,
|
||||
)
|
||||
|
||||
inputset = [
|
||||
tuple(tfhers.TFHERSInteger(dtype, arg) for arg in inpt)
|
||||
tuple(tfhers.TFHERSInteger(input_dtype, arg) for arg in inpt)
|
||||
for inpt in helpers.generate_inputset(parameters)
|
||||
]
|
||||
circuit = compiler.compile(inputset, helpers.configuration())
|
||||
|
||||
assert is_input_and_output_tfhers(
|
||||
circuit,
|
||||
dtype.params.polynomial_size,
|
||||
input_dtype.params.polynomial_size,
|
||||
[0, 1],
|
||||
[
|
||||
0,
|
||||
@@ -145,10 +219,10 @@ def test_tfhers_conversion_binary_encrypted(
|
||||
)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
encoded_sample = (dtype.encode(v) for v in sample)
|
||||
encoded_sample = (input_dtype.encode(v) for v in sample)
|
||||
encoded_result = circuit.encrypt_run_decrypt(*encoded_sample)
|
||||
|
||||
assert (dtype.decode(encoded_result) == function(*sample)).all()
|
||||
assert (output_dtype.decode(encoded_result) == function(*sample)).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user