fix: upcasting of signed integer when block decomposing

Some parts of the code did not use the correct way to
decompose a clear integer into blocks to be encrypted
or used in scalar ops.

The sign extension was not always done properly; for example,
encrypting a negative integer stored in an i8 into a
SignedRadixCiphertext whose num_blocks holds more bits than an i8 produced an incorrect result:

```
let ct = cks.encrypt_signed(-1i8, 16); // 2_2 parameters
let d: i32 = cks.decrypt_signed(&ct);
assert_eq!(d, i32::from(-1i8)); // Fails
```

To fix this, a BlockDecomposer::with_block_count function is added and used.
This function properly performs the sign extension when needed.
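
For illustration, here is a minimal standalone sketch of the intended behaviour on clear values only (plain Rust, not the crate's API; `decompose_sign_extended` is a hypothetical helper, and 2-bit blocks are assumed to correspond to the 2_2 parameters above). Padding -1i8 with sign-extended blocks rather than zeros is what makes the wider decryption come back as -1:

```
fn decompose_sign_extended(value: i8, bits_per_block: u32, block_count: usize) -> Vec<u64> {
    // An arithmetic right shift on the sign-extended i64 keeps replicating the
    // sign bit, so blocks past the i8's own 8 bits come out sign extended.
    let mask = (1u64 << bits_per_block) - 1;
    (0..block_count)
        .map(|i| {
            let shift = (i as u32 * bits_per_block).min(i64::BITS - 1);
            ((i64::from(value) >> shift) as u64) & mask
        })
        .collect()
}

fn main() {
    // 16 blocks of 2 bits each: -1i8 must decompose to all 3s...
    let blocks = decompose_sign_extended(-1i8, 2, 16);
    assert!(blocks.iter().all(|&b| b == 3));
    // ...so recomposing the blocks into an i32 gives back -1, which is what the
    // example above now asserts on the encrypted side.
    let recomposed = blocks
        .iter()
        .enumerate()
        .fold(0u32, |acc, (i, &b)| acc | ((b as u32) << (2 * i))) as i32;
    assert_eq!(recomposed, i32::from(-1i8));
}
```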
tmontaigu
2025-03-11 15:09:14 +01:00
parent 9886256242
commit 0fdda14495
7 changed files with 162 additions and 80 deletions


@@ -120,6 +120,22 @@ where
Self::new_(value, bits_per_block, None, Some(padding_bit))
}
/// Creates a block decomposer that will return `block_count` blocks
///
/// * If T is signed, extra blocks will be sign extended
pub fn with_block_count(value: T, bits_per_block: u32, block_count: usize) -> Self {
let mut decomposer = Self::new(value, bits_per_block);
let block_count: u32 = block_count.try_into().unwrap();
// If the new number of bits is less than the actual number of bits, it means
// data will be truncated
//
// If the new number of bits is greater than the actual number of bits, it means
// the right shift used internally will correctly sign extend for us
let num_bits_valid = block_count * bits_per_block;
decomposer.num_bits_valid = num_bits_valid;
decomposer
}
pub fn new(value: T, bits_per_block: u32) -> Self {
Self::new_(value, bits_per_block, None, None)
}
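
As a standalone check of the comment above (not part of the diff): the arithmetic right shift used internally replicates the sign bit of a signed value, so blocks read past the value's own width come out sign extended, matching the i8::MIN test vector `[0, 0, 6, 7, ...]` added below.

```
fn main() {
    // i8::MIN is 0b1000_0000; read it as 3-bit blocks.
    let v = i8::MIN as i16;
    assert_eq!(v & 0b111, 0b000);        // block 0
    assert_eq!((v >> 3) & 0b111, 0b000); // block 1
    assert_eq!((v >> 6) & 0b111, 0b110); // block 2: sign bit plus one sign-extended bit
    assert_eq!((v >> 9) & 0b111, 0b111); // blocks 3..: pure sign extension
    // The same shifts on an unsigned value pad with zeros instead.
    assert_eq!((128u16 >> 9) & 0b111, 0);
}
```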
@@ -245,7 +261,8 @@ where
T: Recomposable,
{
pub fn value(&self) -> T {
if self.bit_pos >= T::BITS as u32 {
let is_signed = (T::ONE << (T::BITS as u32 - 1)) < T::ZERO;
if self.bit_pos >= (T::BITS as u32 - u32::from(is_signed)) {
self.data
} else {
let valid_mask = (T::ONE << self.bit_pos) - T::ONE;
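
For reference, the signedness probe on the new line above works because shifting 1 into the top bit yields a negative value only for a signed type; a quick standalone check (not part of the diff):

```
fn main() {
    assert!((1i8 << 7) < 0); // signed: the shift wraps to i8::MIN
    assert!((1u8 << 7) > 0); // unsigned: the shift simply produces 128
}
```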
@@ -359,6 +376,49 @@ mod tests {
assert_eq!(expected_blocks, blocks);
}
#[test]
fn test_bit_block_decomposer_with_block_count() {
let bits_per_block = 3;
let expected_blocks = [0, 0, 6, 7, 7, 7, 7, 7, 7];
let value = i8::MIN;
for block_count in 1..expected_blocks.len() {
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
.iter_as::<u64>()
.collect::<Vec<_>>();
assert_eq!(expected_blocks[..block_count], blocks);
}
let bits_per_block = 3;
let expected_blocks = [7, 7, 1, 0, 0, 0, 0, 0, 0];
let value = i8::MAX;
for block_count in 1..expected_blocks.len() {
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
.iter_as::<u64>()
.collect::<Vec<_>>();
assert_eq!(expected_blocks[..block_count], blocks);
}
let bits_per_block = 2;
let expected_blocks = [0, 0, 0, 2, 3, 3, 3, 3, 3];
let value = i8::MIN;
for block_count in 1..expected_blocks.len() {
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
.iter_as::<u64>()
.collect::<Vec<_>>();
assert_eq!(expected_blocks[..block_count], blocks);
}
let bits_per_block = 2;
let expected_blocks = [3, 3, 3, 1, 0, 0, 0, 0, 0, 0];
let value = i8::MAX;
for block_count in 1..expected_blocks.len() {
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
.iter_as::<u64>()
.collect::<Vec<_>>();
assert_eq!(expected_blocks[..block_count], blocks);
}
}
#[test]
fn test_bit_block_decomposer_recomposer_carry_handling_in_between() {
let value = u16::MAX as u32;


@@ -92,9 +92,7 @@ where
// We need to concretize the iterator type to be able to pass callbacks consuming the iterator,
// having an opaque `impl Iterator` return does not allow taking such callbacks at this moment; not sure
// the Fn(impl Trait) syntax can be made to work nicely with the rest of the language
pub(crate) type ClearRadixBlockIterator<T> = std::iter::Take<
std::iter::Chain<std::iter::Map<BlockDecomposer<T>, fn(T) -> u64>, std::iter::Repeat<u64>>,
>;
pub(crate) type ClearRadixBlockIterator<T> = std::iter::Map<BlockDecomposer<T>, fn(T) -> u64>;
pub(crate) fn create_clear_radix_block_iterator<T>(
message: T,
@@ -105,12 +103,7 @@ where
T: DecomposableInto<u64>,
{
let bits_in_block = message_modulus.0.ilog2();
let decomposer = BlockDecomposer::new(message, bits_in_block);
decomposer
.iter_as::<u64>()
.chain(std::iter::repeat(0u64))
.take(num_blocks)
BlockDecomposer::with_block_count(message, bits_in_block, num_blocks).iter_as::<u64>()
}
pub(crate) fn encrypt_crt<BlockKey, Block, CrtCiphertextType, F>(
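
As a standalone illustration of the comment at the top of this hunk (not the crate's code; `Blocks`, `make_blocks` and `with_blocks` are made-up stand-ins): a callback that consumes the iterator needs a nameable iterator type in its parameter position, which an opaque `-> impl Iterator` return would not provide, hence the concrete `Map<..., fn(T) -> u64>` alias.

```
// Hypothetical stand-ins for ClearRadixBlockIterator and its producer.
type Blocks = std::iter::Map<std::ops::Range<u64>, fn(u64) -> u64>;

fn make_blocks(num_blocks: u64) -> Blocks {
    // A non-capturing closure coerced to a `fn` pointer keeps the alias concrete.
    (0..num_blocks).map((|x: u64| x & 3) as fn(u64) -> u64)
}

// The callback parameter has to name the iterator type it receives.
fn with_blocks<R>(blocks: Blocks, callback: impl FnOnce(Blocks) -> R) -> R {
    callback(blocks)
}

fn main() {
    let sum: u64 = with_blocks(make_blocks(8), |it| it.sum());
    assert_eq!(sum, 12); // 0+1+2+3 twice, since each block is masked to 2 bits
}
```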


@@ -187,10 +187,9 @@ impl CudaServerKey {
PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_size(),
};
let decomposer = BlockDecomposer::new(scalar, self.message_modulus.0.ilog2())
.iter_as::<u64>()
.chain(std::iter::repeat(0))
.take(num_blocks);
let decomposer =
BlockDecomposer::with_block_count(scalar, self.message_modulus.0.ilog2(), num_blocks)
.iter_as::<u64>();
let mut cpu_lwe_list = LweCiphertextList::new(
0,
lwe_size,


@@ -48,6 +48,7 @@ create_parameterized_test_classical_params!(integer_encrypt_decrypt_128_bits);
create_parameterized_test_classical_params!(integer_encrypt_decrypt_128_bits_specific_values);
create_parameterized_test_classical_params!(integer_encrypt_decrypt_256_bits_specific_values);
create_parameterized_test_classical_params!(integer_encrypt_decrypt_256_bits);
create_parameterized_test_classical_params!(integer_encrypt_auto_cast);
create_parameterized_test_classical_params!(integer_unchecked_add);
create_parameterized_test_classical_params!(integer_smart_add);
create_parameterized_test!(
@@ -157,7 +158,7 @@ fn integer_encrypt_decrypt_128_bits(param: ClassicPBSParameters) {
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let mut rng = rand::thread_rng();
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
for _ in 0..10 {
let clear = rng.gen::<u128>();
@@ -172,7 +173,7 @@ fn integer_encrypt_decrypt_128_bits(param: ClassicPBSParameters) {
fn integer_encrypt_decrypt_128_bits_specific_values(param: ClassicPBSParameters) {
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
{
let a = u64::MAX as u128;
let ct = cks.encrypt_radix(a, num_block);
@@ -220,7 +221,7 @@ fn integer_encrypt_decrypt_128_bits_specific_values(param: ClassicPBSParameters)
fn integer_encrypt_decrypt_256_bits_specific_values(param: ClassicPBSParameters) {
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let num_block = (256f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 256u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
{
let a = (u64::MAX as u128) << 64;
let b = 0;
@@ -245,7 +246,7 @@ fn integer_encrypt_decrypt_256_bits(param: ClassicPBSParameters) {
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let mut rng = rand::thread_rng();
let num_block = (256f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 256u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
for _ in 0..10 {
let clear0 = rng.gen::<u128>();
@@ -261,11 +262,52 @@ fn integer_encrypt_decrypt_256_bits(param: ClassicPBSParameters) {
}
}
fn integer_encrypt_auto_cast(param: ClassicPBSParameters) {
// The goal is to test that encrypting a value stored in a type
// whose bit count does not match the target block count of the encrypted
// radix properly applies upcasting/downcasting
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let mut rng = rand::thread_rng();
let num_blocks = 32u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
// Positive signed value
let value = rng.gen_range(0..=i32::MAX);
let ct = cks.encrypt_signed_radix(value, num_blocks * 2);
let d: i64 = cks.decrypt_signed_radix(&ct);
assert_eq!(i64::from(value), d);
let ct = cks.encrypt_signed_radix(value, num_blocks.div_ceil(2));
let d: i16 = cks.decrypt_signed_radix(&ct);
assert_eq!(value as i16, d);
// Negative signed value
let value = rng.gen_range(i8::MIN..0);
let ct = cks.encrypt_signed_radix(value, num_blocks * 2);
let d: i64 = cks.decrypt_signed_radix(&ct);
assert_eq!(i64::from(value), d);
let ct = cks.encrypt_signed_radix(value, num_blocks.div_ceil(2));
let d: i16 = cks.decrypt_signed_radix(&ct);
assert_eq!(value as i16, d);
// Unsigned value
let value = rng.gen::<u32>();
let ct = cks.encrypt_radix(value, num_blocks * 2);
let d: u64 = cks.decrypt_radix(&ct);
assert_eq!(u64::from(value), d);
let ct = cks.encrypt_radix(value, num_blocks.div_ceil(2));
let d: u16 = cks.decrypt_radix(&ct);
assert_eq!(value as u16, d);
}
fn integer_smart_add_128_bits(param: ClassicPBSParameters) {
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let mut rng = rand::thread_rng();
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
for _ in 0..100 {
let clear_0 = rng.gen::<u128>();


@@ -276,15 +276,13 @@ impl ServerKey {
self.full_propagate_parallelized(ct);
}
let scalar_blocks = BlockDecomposer::new(scalar, self.message_modulus().0.ilog2())
.iter_as::<u8>()
.chain(std::iter::repeat(if scalar < Scalar::ZERO {
(self.message_modulus().0 - 1) as u8
} else {
0
}))
.take(ct.blocks().len())
.collect();
let scalar_blocks = BlockDecomposer::with_block_count(
scalar,
self.message_modulus().0.ilog2(),
ct.blocks().len(),
)
.iter_as::<u8>()
.collect();
const COMPUTE_OVERFLOW: bool = false;
const INPUT_CARRY: bool = false;


@@ -757,11 +757,8 @@ impl ServerKey {
.map(|chunk_of_two| self.pack_block_chunk(chunk_of_two))
.collect::<Vec<_>>();
let padding_value = (packed_modulus - 1) * u64::from(b < Scalar::ZERO);
let mut b_blocks = BlockDecomposer::new(b, packed_modulus.ilog2())
let mut b_blocks = BlockDecomposer::with_block_count(b, packed_modulus.ilog2(), a.len())
.iter_as::<u64>()
.chain(std::iter::repeat(padding_value))
.take(a.len())
.collect::<Vec<_>>();
if !num_block_is_even && b < Scalar::ZERO {
@@ -1058,25 +1055,23 @@ impl ServerKey {
Scalar: DecomposableInto<u64>,
{
let is_superior = self.unchecked_scalar_gt_parallelized(lhs, rhs);
let luts = BlockDecomposer::new(rhs, self.message_modulus().0.ilog2())
.iter_as::<u64>()
.chain(std::iter::repeat(if rhs >= Scalar::ZERO {
0u64
} else {
self.message_modulus().0 - 1
}))
.take(lhs.blocks().len())
.map(|scalar_block| {
self.key
.generate_lookup_table_bivariate(|is_superior, block| {
if is_superior == 1 {
block
} else {
scalar_block
}
})
})
.collect::<Vec<_>>();
let luts = BlockDecomposer::with_block_count(
rhs,
self.message_modulus().0.ilog2(),
lhs.blocks().len(),
)
.iter_as::<u64>()
.map(|scalar_block| {
self.key
.generate_lookup_table_bivariate(|is_superior, block| {
if is_superior == 1 {
block
} else {
scalar_block
}
})
})
.collect::<Vec<_>>();
let new_blocks = lhs
.blocks()
@@ -1097,25 +1092,23 @@ impl ServerKey {
Scalar: DecomposableInto<u64>,
{
let is_inferior = self.unchecked_scalar_lt_parallelized(lhs, rhs);
let luts = BlockDecomposer::new(rhs, self.message_modulus().0.ilog2())
.iter_as::<u64>()
.chain(std::iter::repeat(if rhs >= Scalar::ZERO {
0u64
} else {
self.message_modulus().0 - 1
}))
.take(lhs.blocks().len())
.map(|scalar_block| {
self.key
.generate_lookup_table_bivariate(|is_inferior, block| {
if is_inferior == 1 {
block
} else {
scalar_block
}
})
})
.collect::<Vec<_>>();
let luts = BlockDecomposer::with_block_count(
rhs,
self.message_modulus().0.ilog2(),
lhs.blocks().len(),
)
.iter_as::<u64>()
.map(|scalar_block| {
self.key
.generate_lookup_table_bivariate(|is_inferior, block| {
if is_inferior == 1 {
block
} else {
scalar_block
}
})
})
.collect::<Vec<_>>();
let new_blocks = lhs
.blocks()


@@ -640,16 +640,13 @@ impl ServerKey {
const INPUT_CARRY: bool = true;
let flipped_scalar = !scalar;
let decomposed_flipped_scalar =
BlockDecomposer::new(flipped_scalar, self.message_modulus().0.ilog2())
.iter_as::<u8>()
.chain(std::iter::repeat(if scalar < Scalar::ZERO {
0
} else {
(self.message_modulus().0 - 1) as u8
}))
.take(lhs.blocks.len())
.collect::<Vec<_>>();
let decomposed_flipped_scalar = BlockDecomposer::with_block_count(
flipped_scalar,
self.message_modulus().0.ilog2(),
lhs.blocks.len(),
)
.iter_as::<u8>()
.collect::<Vec<_>>();
let maybe_overflow = self.add_assign_scalar_blocks_parallelized(
lhs,
decomposed_flipped_scalar,