mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 15:18:33 -05:00
fix: upcasting of signed integer when block decomposing
Some parts of the code did not use the correct way to decompose a clear integer into blocks which could be encrypted or used in scalar ops. The sign extension was not always properly done, leading for example in the encryption of a negative integer stored on a i8 to a SignedRadixCiphertext with a num_blocks greater than i8 to be incorrect: ``` let ct = cks.encrypt_signed(-1i8, 16) // 2_2 parameters let d: i32 = cks.decrypt_signed(&ct); assert_eq!(d, i32::from(-1i8)); // Fails ``` To fix, a BlockDecomposer::with_block_count function is added and used This function will properly do the sign extension when needed
This commit is contained in:
@@ -120,6 +120,22 @@ where
|
||||
Self::new_(value, bits_per_block, None, Some(padding_bit))
|
||||
}
|
||||
|
||||
/// Creates a block decomposer that will return `block_count` blocks
|
||||
///
|
||||
/// * If T is signed, extra block will be sign extended
|
||||
pub fn with_block_count(value: T, bits_per_block: u32, block_count: usize) -> Self {
|
||||
let mut decomposer = Self::new(value, bits_per_block);
|
||||
let block_count: u32 = block_count.try_into().unwrap();
|
||||
// If the new number of bits is less than the actual number of bits, it means
|
||||
// data will be truncated
|
||||
//
|
||||
// If the new number of bits is greater than the actual number of bits, it means
|
||||
// the right shift used internally will correctly sign extend for us
|
||||
let num_bits_valid = block_count * bits_per_block;
|
||||
decomposer.num_bits_valid = num_bits_valid;
|
||||
decomposer
|
||||
}
|
||||
|
||||
pub fn new(value: T, bits_per_block: u32) -> Self {
|
||||
Self::new_(value, bits_per_block, None, None)
|
||||
}
|
||||
@@ -245,7 +261,8 @@ where
|
||||
T: Recomposable,
|
||||
{
|
||||
pub fn value(&self) -> T {
|
||||
if self.bit_pos >= T::BITS as u32 {
|
||||
let is_signed = (T::ONE << (T::BITS as u32 - 1)) < T::ZERO;
|
||||
if self.bit_pos >= (T::BITS as u32 - u32::from(is_signed)) {
|
||||
self.data
|
||||
} else {
|
||||
let valid_mask = (T::ONE << self.bit_pos) - T::ONE;
|
||||
@@ -359,6 +376,49 @@ mod tests {
|
||||
assert_eq!(expected_blocks, blocks);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bit_block_decomposer_with_block_count() {
|
||||
let bits_per_block = 3;
|
||||
let expected_blocks = [0, 0, 6, 7, 7, 7, 7, 7, 7];
|
||||
let value = i8::MIN;
|
||||
for block_count in 1..expected_blocks.len() {
|
||||
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
|
||||
.iter_as::<u64>()
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(expected_blocks[..block_count], blocks);
|
||||
}
|
||||
|
||||
let bits_per_block = 3;
|
||||
let expected_blocks = [7, 7, 1, 0, 0, 0, 0, 0, 0];
|
||||
let value = i8::MAX;
|
||||
for block_count in 1..expected_blocks.len() {
|
||||
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
|
||||
.iter_as::<u64>()
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(expected_blocks[..block_count], blocks);
|
||||
}
|
||||
|
||||
let bits_per_block = 2;
|
||||
let expected_blocks = [0, 0, 0, 2, 3, 3, 3, 3, 3];
|
||||
let value = i8::MIN;
|
||||
for block_count in 1..expected_blocks.len() {
|
||||
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
|
||||
.iter_as::<u64>()
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(expected_blocks[..block_count], blocks);
|
||||
}
|
||||
|
||||
let bits_per_block = 2;
|
||||
let expected_blocks = [3, 3, 3, 1, 0, 0, 0, 0, 0, 0];
|
||||
let value = i8::MAX;
|
||||
for block_count in 1..expected_blocks.len() {
|
||||
let blocks = BlockDecomposer::with_block_count(value, bits_per_block, block_count)
|
||||
.iter_as::<u64>()
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(expected_blocks[..block_count], blocks);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bit_block_decomposer_recomposer_carry_handling_in_between() {
|
||||
let value = u16::MAX as u32;
|
||||
|
||||
@@ -92,9 +92,7 @@ where
|
||||
// We need to concretize the iterator type to be able to pass callbacks consuming the iterator,
|
||||
// having an opaque return impl Iterator does not allow to take callbacks at this moment, not sure
|
||||
// the Fn(impl Trait) syntax can be made to work nicely with the rest of the language
|
||||
pub(crate) type ClearRadixBlockIterator<T> = std::iter::Take<
|
||||
std::iter::Chain<std::iter::Map<BlockDecomposer<T>, fn(T) -> u64>, std::iter::Repeat<u64>>,
|
||||
>;
|
||||
pub(crate) type ClearRadixBlockIterator<T> = std::iter::Map<BlockDecomposer<T>, fn(T) -> u64>;
|
||||
|
||||
pub(crate) fn create_clear_radix_block_iterator<T>(
|
||||
message: T,
|
||||
@@ -105,12 +103,7 @@ where
|
||||
T: DecomposableInto<u64>,
|
||||
{
|
||||
let bits_in_block = message_modulus.0.ilog2();
|
||||
let decomposer = BlockDecomposer::new(message, bits_in_block);
|
||||
|
||||
decomposer
|
||||
.iter_as::<u64>()
|
||||
.chain(std::iter::repeat(0u64))
|
||||
.take(num_blocks)
|
||||
BlockDecomposer::with_block_count(message, bits_in_block, num_blocks).iter_as::<u64>()
|
||||
}
|
||||
|
||||
pub(crate) fn encrypt_crt<BlockKey, Block, CrtCiphertextType, F>(
|
||||
|
||||
@@ -187,10 +187,9 @@ impl CudaServerKey {
|
||||
PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_size(),
|
||||
};
|
||||
|
||||
let decomposer = BlockDecomposer::new(scalar, self.message_modulus.0.ilog2())
|
||||
.iter_as::<u64>()
|
||||
.chain(std::iter::repeat(0))
|
||||
.take(num_blocks);
|
||||
let decomposer =
|
||||
BlockDecomposer::with_block_count(scalar, self.message_modulus.0.ilog2(), num_blocks)
|
||||
.iter_as::<u64>();
|
||||
let mut cpu_lwe_list = LweCiphertextList::new(
|
||||
0,
|
||||
lwe_size,
|
||||
|
||||
@@ -48,6 +48,7 @@ create_parameterized_test_classical_params!(integer_encrypt_decrypt_128_bits);
|
||||
create_parameterized_test_classical_params!(integer_encrypt_decrypt_128_bits_specific_values);
|
||||
create_parameterized_test_classical_params!(integer_encrypt_decrypt_256_bits_specific_values);
|
||||
create_parameterized_test_classical_params!(integer_encrypt_decrypt_256_bits);
|
||||
create_parameterized_test_classical_params!(integer_encrypt_auto_cast);
|
||||
create_parameterized_test_classical_params!(integer_unchecked_add);
|
||||
create_parameterized_test_classical_params!(integer_smart_add);
|
||||
create_parameterized_test!(
|
||||
@@ -157,7 +158,7 @@ fn integer_encrypt_decrypt_128_bits(param: ClassicPBSParameters) {
|
||||
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
|
||||
for _ in 0..10 {
|
||||
let clear = rng.gen::<u128>();
|
||||
|
||||
@@ -172,7 +173,7 @@ fn integer_encrypt_decrypt_128_bits(param: ClassicPBSParameters) {
|
||||
fn integer_encrypt_decrypt_128_bits_specific_values(param: ClassicPBSParameters) {
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
|
||||
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
|
||||
{
|
||||
let a = u64::MAX as u128;
|
||||
let ct = cks.encrypt_radix(a, num_block);
|
||||
@@ -220,7 +221,7 @@ fn integer_encrypt_decrypt_128_bits_specific_values(param: ClassicPBSParameters)
|
||||
fn integer_encrypt_decrypt_256_bits_specific_values(param: ClassicPBSParameters) {
|
||||
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
|
||||
let num_block = (256f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
let num_block = 256u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
|
||||
{
|
||||
let a = (u64::MAX as u128) << 64;
|
||||
let b = 0;
|
||||
@@ -245,7 +246,7 @@ fn integer_encrypt_decrypt_256_bits(param: ClassicPBSParameters) {
|
||||
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let num_block = (256f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
let num_block = 256u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
|
||||
|
||||
for _ in 0..10 {
|
||||
let clear0 = rng.gen::<u128>();
|
||||
@@ -261,11 +262,52 @@ fn integer_encrypt_decrypt_256_bits(param: ClassicPBSParameters) {
|
||||
}
|
||||
}
|
||||
|
||||
fn integer_encrypt_auto_cast(param: ClassicPBSParameters) {
|
||||
// The goal is to test that encrypting a value stored in a type
|
||||
// for which the bit count does not match the target block count of the encrypted
|
||||
// radix properly applies upcasting/downcasting
|
||||
|
||||
let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let num_blocks = 32u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
|
||||
|
||||
// Positive signed value
|
||||
let value = rng.gen_range(0..=i32::MAX);
|
||||
let ct = cks.encrypt_signed_radix(value, num_blocks * 2);
|
||||
let d: i64 = cks.decrypt_signed_radix(&ct);
|
||||
assert_eq!(i64::from(value), d);
|
||||
|
||||
let ct = cks.encrypt_signed_radix(value, num_blocks.div_ceil(2));
|
||||
let d: i16 = cks.decrypt_signed_radix(&ct);
|
||||
assert_eq!(value as i16, d);
|
||||
|
||||
// Negative signed value
|
||||
let value = rng.gen_range(i8::MIN..0);
|
||||
let ct = cks.encrypt_signed_radix(value, num_blocks * 2);
|
||||
let d: i64 = cks.decrypt_signed_radix(&ct);
|
||||
assert_eq!(i64::from(value), d);
|
||||
|
||||
let ct = cks.encrypt_signed_radix(value, num_blocks.div_ceil(2));
|
||||
let d: i16 = cks.decrypt_signed_radix(&ct);
|
||||
assert_eq!(value as i16, d);
|
||||
|
||||
// Unsigned value
|
||||
let value = rng.gen::<u32>();
|
||||
let ct = cks.encrypt_radix(value, num_blocks * 2);
|
||||
let d: u64 = cks.decrypt_radix(&ct);
|
||||
assert_eq!(u64::from(value), d);
|
||||
|
||||
let ct = cks.encrypt_radix(value, num_blocks.div_ceil(2));
|
||||
let d: u16 = cks.decrypt_radix(&ct);
|
||||
assert_eq!(value as u16, d);
|
||||
}
|
||||
|
||||
fn integer_smart_add_128_bits(param: ClassicPBSParameters) {
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
let num_block = 128u32.div_ceil(param.message_modulus.0.ilog2()) as usize;
|
||||
|
||||
for _ in 0..100 {
|
||||
let clear_0 = rng.gen::<u128>();
|
||||
|
||||
@@ -276,15 +276,13 @@ impl ServerKey {
|
||||
self.full_propagate_parallelized(ct);
|
||||
}
|
||||
|
||||
let scalar_blocks = BlockDecomposer::new(scalar, self.message_modulus().0.ilog2())
|
||||
.iter_as::<u8>()
|
||||
.chain(std::iter::repeat(if scalar < Scalar::ZERO {
|
||||
(self.message_modulus().0 - 1) as u8
|
||||
} else {
|
||||
0
|
||||
}))
|
||||
.take(ct.blocks().len())
|
||||
.collect();
|
||||
let scalar_blocks = BlockDecomposer::with_block_count(
|
||||
scalar,
|
||||
self.message_modulus().0.ilog2(),
|
||||
ct.blocks().len(),
|
||||
)
|
||||
.iter_as::<u8>()
|
||||
.collect();
|
||||
|
||||
const COMPUTE_OVERFLOW: bool = false;
|
||||
const INPUT_CARRY: bool = false;
|
||||
|
||||
@@ -757,11 +757,8 @@ impl ServerKey {
|
||||
.map(|chunk_of_two| self.pack_block_chunk(chunk_of_two))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let padding_value = (packed_modulus - 1) * u64::from(b < Scalar::ZERO);
|
||||
let mut b_blocks = BlockDecomposer::new(b, packed_modulus.ilog2())
|
||||
let mut b_blocks = BlockDecomposer::with_block_count(b, packed_modulus.ilog2(), a.len())
|
||||
.iter_as::<u64>()
|
||||
.chain(std::iter::repeat(padding_value))
|
||||
.take(a.len())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if !num_block_is_even && b < Scalar::ZERO {
|
||||
@@ -1058,25 +1055,23 @@ impl ServerKey {
|
||||
Scalar: DecomposableInto<u64>,
|
||||
{
|
||||
let is_superior = self.unchecked_scalar_gt_parallelized(lhs, rhs);
|
||||
let luts = BlockDecomposer::new(rhs, self.message_modulus().0.ilog2())
|
||||
.iter_as::<u64>()
|
||||
.chain(std::iter::repeat(if rhs >= Scalar::ZERO {
|
||||
0u64
|
||||
} else {
|
||||
self.message_modulus().0 - 1
|
||||
}))
|
||||
.take(lhs.blocks().len())
|
||||
.map(|scalar_block| {
|
||||
self.key
|
||||
.generate_lookup_table_bivariate(|is_superior, block| {
|
||||
if is_superior == 1 {
|
||||
block
|
||||
} else {
|
||||
scalar_block
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let luts = BlockDecomposer::with_block_count(
|
||||
rhs,
|
||||
self.message_modulus().0.ilog2(),
|
||||
lhs.blocks().len(),
|
||||
)
|
||||
.iter_as::<u64>()
|
||||
.map(|scalar_block| {
|
||||
self.key
|
||||
.generate_lookup_table_bivariate(|is_superior, block| {
|
||||
if is_superior == 1 {
|
||||
block
|
||||
} else {
|
||||
scalar_block
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let new_blocks = lhs
|
||||
.blocks()
|
||||
@@ -1097,25 +1092,23 @@ impl ServerKey {
|
||||
Scalar: DecomposableInto<u64>,
|
||||
{
|
||||
let is_inferior = self.unchecked_scalar_lt_parallelized(lhs, rhs);
|
||||
let luts = BlockDecomposer::new(rhs, self.message_modulus().0.ilog2())
|
||||
.iter_as::<u64>()
|
||||
.chain(std::iter::repeat(if rhs >= Scalar::ZERO {
|
||||
0u64
|
||||
} else {
|
||||
self.message_modulus().0 - 1
|
||||
}))
|
||||
.take(lhs.blocks().len())
|
||||
.map(|scalar_block| {
|
||||
self.key
|
||||
.generate_lookup_table_bivariate(|is_inferior, block| {
|
||||
if is_inferior == 1 {
|
||||
block
|
||||
} else {
|
||||
scalar_block
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let luts = BlockDecomposer::with_block_count(
|
||||
rhs,
|
||||
self.message_modulus().0.ilog2(),
|
||||
lhs.blocks().len(),
|
||||
)
|
||||
.iter_as::<u64>()
|
||||
.map(|scalar_block| {
|
||||
self.key
|
||||
.generate_lookup_table_bivariate(|is_inferior, block| {
|
||||
if is_inferior == 1 {
|
||||
block
|
||||
} else {
|
||||
scalar_block
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let new_blocks = lhs
|
||||
.blocks()
|
||||
|
||||
@@ -640,16 +640,13 @@ impl ServerKey {
|
||||
|
||||
const INPUT_CARRY: bool = true;
|
||||
let flipped_scalar = !scalar;
|
||||
let decomposed_flipped_scalar =
|
||||
BlockDecomposer::new(flipped_scalar, self.message_modulus().0.ilog2())
|
||||
.iter_as::<u8>()
|
||||
.chain(std::iter::repeat(if scalar < Scalar::ZERO {
|
||||
0
|
||||
} else {
|
||||
(self.message_modulus().0 - 1) as u8
|
||||
}))
|
||||
.take(lhs.blocks.len())
|
||||
.collect::<Vec<_>>();
|
||||
let decomposed_flipped_scalar = BlockDecomposer::with_block_count(
|
||||
flipped_scalar,
|
||||
self.message_modulus().0.ilog2(),
|
||||
lhs.blocks.len(),
|
||||
)
|
||||
.iter_as::<u8>()
|
||||
.collect::<Vec<_>>();
|
||||
let maybe_overflow = self.add_assign_scalar_blocks_parallelized(
|
||||
lhs,
|
||||
decomposed_flipped_scalar,
|
||||
|
||||
Reference in New Issue
Block a user