feat(integer): add oprf over any range

This commit is contained in:
Mayeul@Zama
2025-09-01 18:48:47 +02:00
committed by mayeul-zama
parent aefec1fe64
commit 9b5596ca66
3 changed files with 242 additions and 45 deletions

View File

@@ -29,7 +29,7 @@ pub use uniform_binary::*;
pub use uniform_ternary::*;
#[cfg(test)]
mod tests;
pub(crate) mod tests;
mod gaussian;
mod generator;

View File

@@ -4,6 +4,8 @@ use crate::core_crypto::commons::math::random::{Distribution, RandomGenerable, T
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::numeric::{CastFrom, CastInto, Numeric, UnsignedInteger};
use crate::core_crypto::commons::test_tools::*;
use itertools::Itertools;
use std::ops::AddAssign;
fn test_normal_random_three_sigma<T: UnsignedTorus>() {
//! test if the normal random generation with std_dev is below 3*std_dev (99.7%)
@@ -432,14 +434,20 @@ fn dkw_cdf_bands_width(number_of_samples: usize, confidence_interval: f64) -> f6
// https://en.wikipedia.org/wiki/Dvoretzky%E2%80%93Kiefer%E2%80%93Wolfowitz_inequality#Building_CDF_bands
// the true CDF is between the empirical CDF +/- this band width with probability 1 - alpha
// Said otherwise, the abs diff should be less than that value with high probability
fn dkw_cdf_bands_width_formula(sample_size: f64, alpha: f64) -> f64 {
f64::sqrt(f64::ln(2.0 / alpha) / (2.0 * sample_size))
}
// alpha = 1 - probability of being in the interval
dkw_cdf_bands_width_formula(number_of_samples as f64, 1.0 - confidence_interval)
}
pub fn dkw_cdf_bands_width_formula(sample_size: f64, alpha: f64) -> f64 {
f64::sqrt(f64::ln(2.0 / alpha) / (2.0 * sample_size))
}
// https://en.wikipedia.org/wiki/Dvoretzky%E2%80%93Kiefer%E2%80%93Wolfowitz_inequality#The_DKW_inequality
pub fn dkw_alpha_from_epsilon(sample_size: f64, epsilon: f64) -> f64 {
2.0 * (-epsilon * epsilon * (2.0 * sample_size)).exp()
}
fn test_random_from_distribution_custom_mod<Scalar, D>(
creation_infos: D::CreationInfos,
ciphertext_modulus: CiphertextModulus<Scalar>,
@@ -490,43 +498,22 @@ fn test_random_from_distribution_custom_mod<Scalar, D>(
}
}
let mut cumulative_sums = vec![0u64; distinct_values];
let cumulative_bins = cumulate(&bins);
let mut curr_sum = 0;
// Compute the cumulative sums
for (bin_count, cum_sum) in bins.iter().zip(cumulative_sums.iter_mut()) {
curr_sum += bin_count;
*cum_sum = curr_sum;
}
let theoretical_cdf: Vec<f64> = (0..distinct_values)
.map(|bin_idx| {
let integer_value: Scalar =
distribution.map_usize_to_value(bin_idx, ciphertext_modulus);
// CDF for the uniform distribution
distribution.cumulative_distribution_function(integer_value, ciphertext_modulus)
})
.collect();
// Inaccurate if modulus >~ 2^53 / number_of_samples_per_bin, but if that's the case your
// memory most likely blew up before (or the universe died its heat death)
let number_of_samples = NUMBER_OF_SAMPLES_PER_VALUE * distinct_values;
let sup_diff: f64 = cumulative_sums
.iter()
.copied()
.enumerate()
.map(|(bin_idx, x)| {
// Compute the observed CDF
let empirical_cdf = x as f64 / number_of_samples as f64;
let integer_value: Scalar =
distribution.map_usize_to_value(bin_idx, ciphertext_modulus);
// CDF for the uniform distribution
let theoretical_cdf = distribution
.cumulative_distribution_function(integer_value, ciphertext_modulus);
if theoretical_cdf == 1.0 {
assert_eq!(empirical_cdf, 1.0);
}
let diff = empirical_cdf - theoretical_cdf;
diff.abs()
})
.max_by(f64::total_cmp)
.unwrap();
let sup_diff = sup_diff(&cumulative_bins, &theoretical_cdf);
let upper_bound_for_cdf_abs_diff =
dkw_cdf_bands_width(number_of_samples, CONFIDENCE_INTERVAL);
@@ -545,6 +532,36 @@ fn test_random_from_distribution_custom_mod<Scalar, D>(
);
}
pub fn sup_diff(cumulative_bins: &[u64], theoretical_cdf: &[f64]) -> f64 {
let number_of_samples = *cumulative_bins.last().unwrap();
cumulative_bins
.iter()
.copied()
.zip_eq(theoretical_cdf.iter().copied())
.map(|(x, theoretical_cdf)| {
let empirical_cdf = x as f64 / number_of_samples as f64;
if theoretical_cdf == 1.0 {
assert_eq!(empirical_cdf, 1.0);
}
let diff = empirical_cdf - theoretical_cdf;
diff.abs()
})
.max_by(f64::total_cmp)
.unwrap()
}
pub fn cumulate<T: AddAssign + Default + Copy>(bins: &[T]) -> Vec<T> {
bins.iter()
.scan(T::default(), |sum, x| {
*sum += *x;
Some(*sum)
})
.collect()
}
impl<Scalar: UnsignedInteger + CastFrom<usize> + CastInto<usize>> DistributionTestHelper<Scalar>
for Uniform
{

View File

@@ -40,7 +40,6 @@ impl ServerKey {
let sk = &self.key;
assert!(self.message_modulus().0.is_power_of_two());
let message_bits_count = self.message_modulus().0.ilog2() as u64;
let mut deterministic_seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);
@@ -117,7 +116,6 @@ impl ServerKey {
"The range asked for a random value (=[0, 2^{random_bits_count}[) does not fit in the available range [0, 2^{range_log_size}[",
);
assert!(self.message_modulus().0.is_power_of_two());
let message_bits_count = self.message_modulus().0.ilog2() as u64;
let mut deterministic_seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);
@@ -154,6 +152,83 @@ impl ServerKey {
RadixCiphertext::from(blocks)
}
/// Generates an encrypted `num_blocks_output` blocks unsigned integer
/// taken almost uniformly in [0, excluded_upper_bound[ using the given seed.
/// The encrypted value is oblivious to the server.
/// It can be useful to make server random generation deterministic.
/// The higher num_input_random_bits, the closer to a uniform the distribution will be (at the
/// cost of computation time).
/// It is recommended to use a multiple of `log2_message_modulus`
/// as `num_input_random_bits`
///
/// ```rust
/// use tfhe::integer::gen_keys_radix;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
/// use tfhe::Seed;
///
/// let size = 4;
///
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, size);
///
/// let num_input_random_bits = 5;
/// let excluded_upper_bound = 3;
/// let num_blocks_output = 8;
///
/// let ct_res = sks.par_generate_oblivious_pseudo_random_unsigned_custom_range(
/// Seed(0),
/// num_input_random_bits,
/// excluded_upper_bound,
/// num_blocks_output,
/// );
///
/// // Decrypt:
/// let dec_result: u64 = cks.decrypt(&ct_res);
///
/// assert!(dec_result < excluded_upper_bound);
/// ```
pub fn par_generate_oblivious_pseudo_random_unsigned_custom_range(
&self,
seed: Seed,
num_input_random_bits: u64,
excluded_upper_bound: u64,
num_blocks_output: u64,
) -> RadixCiphertext {
assert!(self.message_modulus().0.is_power_of_two());
let message_bits_count = self.message_modulus().0.ilog2() as u64;
assert!(
!excluded_upper_bound.is_power_of_two(),
"Use the cheaper par_generate_oblivious_pseudo_random_unsigned_integer_bounded function instead"
);
let num_bits_output = num_blocks_output * message_bits_count;
assert!((excluded_upper_bound as f64) < 2_f64.powi(num_bits_output as i32), "num_blocks_output(={num_blocks_output}) is too small to hold an integer up to excluded_upper_bound(=excluded_upper_bound)");
let post_mul_num_bits =
num_input_random_bits + (excluded_upper_bound as f64).log2().ceil() as u64;
let num_blocks = post_mul_num_bits.div_ceil(message_bits_count);
let random_input = self.par_generate_oblivious_pseudo_random_unsigned_integer_bounded(
seed,
num_input_random_bits,
num_blocks,
);
let random_multiplied = self.scalar_mul_parallelized(&random_input, excluded_upper_bound);
let mut result =
self.scalar_right_shift_parallelized(&random_multiplied, num_input_random_bits);
// Adjust the number of leading (MSB) trivial zeros blocks
result
.blocks
.resize(num_blocks_output as usize, self.key.create_trivial(0));
result
}
}
impl ServerKey {
@@ -250,7 +325,6 @@ impl ServerKey {
);
}
assert!(self.message_modulus().0.is_power_of_two());
let message_bits_count = self.message_modulus().0.ilog2() as u64;
let mut deterministic_seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);
@@ -291,8 +365,14 @@ impl ServerKey {
#[cfg(test)]
pub(crate) mod test {
use crate::core_crypto::commons::math::random::tests::{
cumulate, dkw_alpha_from_epsilon, sup_diff,
};
use crate::integer::keycache::KEY_CACHE;
use crate::integer::IntegerKeyKind;
use crate::shortint::oprf::test::test_uniformity;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use tfhe_csprng::seeders::Seed;
#[test]
@@ -305,9 +385,10 @@ pub(crate) mod test {
let num_blocks = 2;
use crate::integer::gen_keys_radix;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
let (ck, sk) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, num_blocks);
let (ck, sk) = KEY_CACHE.get_from_params(
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
IntegerKeyKind::Radix,
);
let test_uniformity = |distinct_values: u64, f: &(dyn Fn(usize) -> u64 + Sync)| {
test_uniformity(sample_count, p_value_limit, distinct_values, f)
@@ -319,7 +400,7 @@ pub(crate) mod test {
random_bits_count,
num_blocks as u64,
);
ck.decrypt(&img)
ck.decrypt_radix(&img)
});
test_uniformity(1 << random_bits_count, &|seed| {
@@ -328,7 +409,7 @@ pub(crate) mod test {
random_bits_count,
num_blocks as u64,
);
let result = ck.decrypt_signed::<i64>(&img);
let result = ck.decrypt_signed_radix::<i64>(&img);
assert!(result >= 0);
@@ -342,11 +423,110 @@ pub(crate) mod test {
);
// Move from [-2^(p-1), 2^(p-1)[ to [0, 2^p[ (p = 2 * num_blocks)
let result = ck.decrypt_signed::<i64>(&img) + (1 << (2 * num_blocks - 1));
let result = ck.decrypt_signed_radix::<i64>(&img) + (1 << (2 * num_blocks - 1));
assert!(result >= 0);
result as u64
});
}
#[test]
fn oprf_test_any_range_ci_run_filter() {
let (ck, sk) = KEY_CACHE.get_from_params(
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
IntegerKeyKind::Radix,
);
let num_loops = 100;
for seed in 0..num_loops {
let seed = Seed(seed);
for num_input_random_bits in [1, 2, 63, 64] {
for (excluded_upper_bound, num_blocks_output) in
[(3, 1), (3, 32), ((1 << 32) + 1, 64)]
{
let img = sk.par_generate_oblivious_pseudo_random_unsigned_custom_range(
seed,
num_input_random_bits,
excluded_upper_bound,
num_blocks_output as u64,
);
assert_eq!(img.blocks.len(), num_blocks_output);
let decrypted: u64 = ck.decrypt_radix(&img);
assert!(decrypted < excluded_upper_bound);
}
}
}
}
#[test]
fn oprf_test_almost_uniformity_ci_run_filter() {
let sample_count: usize = 10_000;
let p_value_limit: f64 = 0.001;
let num_input_random_bits: usize = 4;
let num_blocks_output = 64;
let excluded_upper_bound = 10;
let (ck, sk) = KEY_CACHE.get_from_params(
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
IntegerKeyKind::Radix,
);
let random_input_upper_bound = 1 << num_input_random_bits;
let mut density = vec![0_usize; excluded_upper_bound];
for i in 0..random_input_upper_bound {
let index =
((i * excluded_upper_bound) as f64 / random_input_upper_bound as f64) as usize;
density[index] += 1;
}
//probability density function
let theoretical_pdf: Vec<f64> = density
.iter()
.map(|count| *count as f64 / random_input_upper_bound as f64)
.collect();
let values: Vec<u64> = (0..sample_count)
.into_par_iter()
.map(|seed| {
let img = sk.par_generate_oblivious_pseudo_random_unsigned_custom_range(
Seed(seed as u128),
num_input_random_bits as u64,
excluded_upper_bound as u64,
num_blocks_output as u64,
);
ck.decrypt_radix(&img)
})
.collect();
let mut bins = vec![0_u64; excluded_upper_bound];
for value in values {
bins[value as usize] += 1;
}
let cumulative_bins = cumulate(&bins);
let theoretical_cdf = cumulate(&theoretical_pdf);
let sup_diff = sup_diff(&cumulative_bins, &theoretical_cdf);
let p_value_upper_bound = dkw_alpha_from_epsilon(sample_count as f64, sup_diff);
println!("p_value_upper_bound {p_value_upper_bound}");
assert!(p_value_limit < p_value_upper_bound);
}
}