mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(integer): add oprf over any range
This commit is contained in:
@@ -29,7 +29,7 @@ pub use uniform_binary::*;
|
||||
pub use uniform_ternary::*;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
pub(crate) mod tests;
|
||||
|
||||
mod gaussian;
|
||||
mod generator;
|
||||
|
||||
@@ -4,6 +4,8 @@ use crate::core_crypto::commons::math::random::{Distribution, RandomGenerable, T
|
||||
use crate::core_crypto::commons::math::torus::UnsignedTorus;
|
||||
use crate::core_crypto::commons::numeric::{CastFrom, CastInto, Numeric, UnsignedInteger};
|
||||
use crate::core_crypto::commons::test_tools::*;
|
||||
use itertools::Itertools;
|
||||
use std::ops::AddAssign;
|
||||
|
||||
fn test_normal_random_three_sigma<T: UnsignedTorus>() {
|
||||
//! test if the normal random generation with std_dev is below 3*std_dev (99.7%)
|
||||
@@ -432,14 +434,20 @@ fn dkw_cdf_bands_width(number_of_samples: usize, confidence_interval: f64) -> f6
|
||||
// https://en.wikipedia.org/wiki/Dvoretzky%E2%80%93Kiefer%E2%80%93Wolfowitz_inequality#Building_CDF_bands
|
||||
// the true CDF is between the empirical CDF +/- this band width with probability 1 - alpha
|
||||
// Said otherwise, the abs diff should be less than that value with high probability
|
||||
fn dkw_cdf_bands_width_formula(sample_size: f64, alpha: f64) -> f64 {
|
||||
f64::sqrt(f64::ln(2.0 / alpha) / (2.0 * sample_size))
|
||||
}
|
||||
|
||||
// alpha = 1 - probability of being in the interval
|
||||
dkw_cdf_bands_width_formula(number_of_samples as f64, 1.0 - confidence_interval)
|
||||
}
|
||||
|
||||
pub fn dkw_cdf_bands_width_formula(sample_size: f64, alpha: f64) -> f64 {
|
||||
f64::sqrt(f64::ln(2.0 / alpha) / (2.0 * sample_size))
|
||||
}
|
||||
|
||||
// https://en.wikipedia.org/wiki/Dvoretzky%E2%80%93Kiefer%E2%80%93Wolfowitz_inequality#The_DKW_inequality
|
||||
pub fn dkw_alpha_from_epsilon(sample_size: f64, epsilon: f64) -> f64 {
|
||||
2.0 * (-epsilon * epsilon * (2.0 * sample_size)).exp()
|
||||
}
|
||||
|
||||
fn test_random_from_distribution_custom_mod<Scalar, D>(
|
||||
creation_infos: D::CreationInfos,
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
@@ -490,43 +498,22 @@ fn test_random_from_distribution_custom_mod<Scalar, D>(
|
||||
}
|
||||
}
|
||||
|
||||
let mut cumulative_sums = vec![0u64; distinct_values];
|
||||
let cumulative_bins = cumulate(&bins);
|
||||
|
||||
let mut curr_sum = 0;
|
||||
|
||||
// Compute the cumulative sums
|
||||
for (bin_count, cum_sum) in bins.iter().zip(cumulative_sums.iter_mut()) {
|
||||
curr_sum += bin_count;
|
||||
*cum_sum = curr_sum;
|
||||
}
|
||||
let theoretical_cdf: Vec<f64> = (0..distinct_values)
|
||||
.map(|bin_idx| {
|
||||
let integer_value: Scalar =
|
||||
distribution.map_usize_to_value(bin_idx, ciphertext_modulus);
|
||||
// CDF for the uniform distribution
|
||||
distribution.cumulative_distribution_function(integer_value, ciphertext_modulus)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Inaccurate if modulus >~ 2^53 / number_of_samples_per_bin, but if that's the case your
|
||||
// memory most likely blew up before (or the universe died its heat death)
|
||||
let number_of_samples = NUMBER_OF_SAMPLES_PER_VALUE * distinct_values;
|
||||
|
||||
let sup_diff: f64 = cumulative_sums
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.map(|(bin_idx, x)| {
|
||||
// Compute the observed CDF
|
||||
let empirical_cdf = x as f64 / number_of_samples as f64;
|
||||
|
||||
let integer_value: Scalar =
|
||||
distribution.map_usize_to_value(bin_idx, ciphertext_modulus);
|
||||
// CDF for the uniform distribution
|
||||
let theoretical_cdf = distribution
|
||||
.cumulative_distribution_function(integer_value, ciphertext_modulus);
|
||||
|
||||
if theoretical_cdf == 1.0 {
|
||||
assert_eq!(empirical_cdf, 1.0);
|
||||
}
|
||||
|
||||
let diff = empirical_cdf - theoretical_cdf;
|
||||
diff.abs()
|
||||
})
|
||||
.max_by(f64::total_cmp)
|
||||
.unwrap();
|
||||
let sup_diff = sup_diff(&cumulative_bins, &theoretical_cdf);
|
||||
|
||||
let upper_bound_for_cdf_abs_diff =
|
||||
dkw_cdf_bands_width(number_of_samples, CONFIDENCE_INTERVAL);
|
||||
@@ -545,6 +532,36 @@ fn test_random_from_distribution_custom_mod<Scalar, D>(
|
||||
);
|
||||
}
|
||||
|
||||
pub fn sup_diff(cumulative_bins: &[u64], theoretical_cdf: &[f64]) -> f64 {
|
||||
let number_of_samples = *cumulative_bins.last().unwrap();
|
||||
|
||||
cumulative_bins
|
||||
.iter()
|
||||
.copied()
|
||||
.zip_eq(theoretical_cdf.iter().copied())
|
||||
.map(|(x, theoretical_cdf)| {
|
||||
let empirical_cdf = x as f64 / number_of_samples as f64;
|
||||
|
||||
if theoretical_cdf == 1.0 {
|
||||
assert_eq!(empirical_cdf, 1.0);
|
||||
}
|
||||
|
||||
let diff = empirical_cdf - theoretical_cdf;
|
||||
diff.abs()
|
||||
})
|
||||
.max_by(f64::total_cmp)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn cumulate<T: AddAssign + Default + Copy>(bins: &[T]) -> Vec<T> {
|
||||
bins.iter()
|
||||
.scan(T::default(), |sum, x| {
|
||||
*sum += *x;
|
||||
Some(*sum)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
impl<Scalar: UnsignedInteger + CastFrom<usize> + CastInto<usize>> DistributionTestHelper<Scalar>
|
||||
for Uniform
|
||||
{
|
||||
|
||||
@@ -40,7 +40,6 @@ impl ServerKey {
|
||||
|
||||
let sk = &self.key;
|
||||
|
||||
assert!(self.message_modulus().0.is_power_of_two());
|
||||
let message_bits_count = self.message_modulus().0.ilog2() as u64;
|
||||
|
||||
let mut deterministic_seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);
|
||||
@@ -117,7 +116,6 @@ impl ServerKey {
|
||||
"The range asked for a random value (=[0, 2^{random_bits_count}[) does not fit in the available range [0, 2^{range_log_size}[",
|
||||
);
|
||||
|
||||
assert!(self.message_modulus().0.is_power_of_two());
|
||||
let message_bits_count = self.message_modulus().0.ilog2() as u64;
|
||||
|
||||
let mut deterministic_seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);
|
||||
@@ -154,6 +152,83 @@ impl ServerKey {
|
||||
|
||||
RadixCiphertext::from(blocks)
|
||||
}
|
||||
|
||||
/// Generates an encrypted `num_blocks_output` blocks unsigned integer
|
||||
/// taken almost uniformly in [0, excluded_upper_bound[ using the given seed.
|
||||
/// The encrypted value is oblivious to the server.
|
||||
/// It can be useful to make server random generation deterministic.
|
||||
/// The higher num_input_random_bits, the closer to a uniform the distribution will be (at the
|
||||
/// cost of computation time).
|
||||
/// It is recommended to use a multiple of `log2_message_modulus`
|
||||
/// as `num_input_random_bits`
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
|
||||
/// use tfhe::Seed;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, size);
|
||||
///
|
||||
/// let num_input_random_bits = 5;
|
||||
/// let excluded_upper_bound = 3;
|
||||
/// let num_blocks_output = 8;
|
||||
///
|
||||
/// let ct_res = sks.par_generate_oblivious_pseudo_random_unsigned_custom_range(
|
||||
/// Seed(0),
|
||||
/// num_input_random_bits,
|
||||
/// excluded_upper_bound,
|
||||
/// num_blocks_output,
|
||||
/// );
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result: u64 = cks.decrypt(&ct_res);
|
||||
///
|
||||
/// assert!(dec_result < excluded_upper_bound);
|
||||
/// ```
|
||||
pub fn par_generate_oblivious_pseudo_random_unsigned_custom_range(
|
||||
&self,
|
||||
seed: Seed,
|
||||
num_input_random_bits: u64,
|
||||
excluded_upper_bound: u64,
|
||||
num_blocks_output: u64,
|
||||
) -> RadixCiphertext {
|
||||
assert!(self.message_modulus().0.is_power_of_two());
|
||||
let message_bits_count = self.message_modulus().0.ilog2() as u64;
|
||||
|
||||
assert!(
|
||||
!excluded_upper_bound.is_power_of_two(),
|
||||
"Use the cheaper par_generate_oblivious_pseudo_random_unsigned_integer_bounded function instead"
|
||||
);
|
||||
|
||||
let num_bits_output = num_blocks_output * message_bits_count;
|
||||
assert!((excluded_upper_bound as f64) < 2_f64.powi(num_bits_output as i32), "num_blocks_output(={num_blocks_output}) is too small to hold an integer up to excluded_upper_bound(=excluded_upper_bound)");
|
||||
|
||||
let post_mul_num_bits =
|
||||
num_input_random_bits + (excluded_upper_bound as f64).log2().ceil() as u64;
|
||||
|
||||
let num_blocks = post_mul_num_bits.div_ceil(message_bits_count);
|
||||
|
||||
let random_input = self.par_generate_oblivious_pseudo_random_unsigned_integer_bounded(
|
||||
seed,
|
||||
num_input_random_bits,
|
||||
num_blocks,
|
||||
);
|
||||
|
||||
let random_multiplied = self.scalar_mul_parallelized(&random_input, excluded_upper_bound);
|
||||
|
||||
let mut result =
|
||||
self.scalar_right_shift_parallelized(&random_multiplied, num_input_random_bits);
|
||||
|
||||
// Adjust the number of leading (MSB) trivial zeros blocks
|
||||
result
|
||||
.blocks
|
||||
.resize(num_blocks_output as usize, self.key.create_trivial(0));
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerKey {
|
||||
@@ -250,7 +325,6 @@ impl ServerKey {
|
||||
);
|
||||
}
|
||||
|
||||
assert!(self.message_modulus().0.is_power_of_two());
|
||||
let message_bits_count = self.message_modulus().0.ilog2() as u64;
|
||||
|
||||
let mut deterministic_seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);
|
||||
@@ -291,8 +365,14 @@ impl ServerKey {
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test {
|
||||
|
||||
use crate::core_crypto::commons::math::random::tests::{
|
||||
cumulate, dkw_alpha_from_epsilon, sup_diff,
|
||||
};
|
||||
use crate::integer::keycache::KEY_CACHE;
|
||||
use crate::integer::IntegerKeyKind;
|
||||
use crate::shortint::oprf::test::test_uniformity;
|
||||
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
||||
use tfhe_csprng::seeders::Seed;
|
||||
|
||||
#[test]
|
||||
@@ -305,9 +385,10 @@ pub(crate) mod test {
|
||||
|
||||
let num_blocks = 2;
|
||||
|
||||
use crate::integer::gen_keys_radix;
|
||||
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
|
||||
let (ck, sk) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, num_blocks);
|
||||
let (ck, sk) = KEY_CACHE.get_from_params(
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
|
||||
IntegerKeyKind::Radix,
|
||||
);
|
||||
|
||||
let test_uniformity = |distinct_values: u64, f: &(dyn Fn(usize) -> u64 + Sync)| {
|
||||
test_uniformity(sample_count, p_value_limit, distinct_values, f)
|
||||
@@ -319,7 +400,7 @@ pub(crate) mod test {
|
||||
random_bits_count,
|
||||
num_blocks as u64,
|
||||
);
|
||||
ck.decrypt(&img)
|
||||
ck.decrypt_radix(&img)
|
||||
});
|
||||
|
||||
test_uniformity(1 << random_bits_count, &|seed| {
|
||||
@@ -328,7 +409,7 @@ pub(crate) mod test {
|
||||
random_bits_count,
|
||||
num_blocks as u64,
|
||||
);
|
||||
let result = ck.decrypt_signed::<i64>(&img);
|
||||
let result = ck.decrypt_signed_radix::<i64>(&img);
|
||||
|
||||
assert!(result >= 0);
|
||||
|
||||
@@ -342,11 +423,110 @@ pub(crate) mod test {
|
||||
);
|
||||
|
||||
// Move from [-2^(p-1), 2^(p-1)[ to [0, 2^p[ (p = 2 * num_blocks)
|
||||
let result = ck.decrypt_signed::<i64>(&img) + (1 << (2 * num_blocks - 1));
|
||||
let result = ck.decrypt_signed_radix::<i64>(&img) + (1 << (2 * num_blocks - 1));
|
||||
|
||||
assert!(result >= 0);
|
||||
|
||||
result as u64
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oprf_test_any_range_ci_run_filter() {
|
||||
let (ck, sk) = KEY_CACHE.get_from_params(
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
|
||||
IntegerKeyKind::Radix,
|
||||
);
|
||||
|
||||
let num_loops = 100;
|
||||
|
||||
for seed in 0..num_loops {
|
||||
let seed = Seed(seed);
|
||||
|
||||
for num_input_random_bits in [1, 2, 63, 64] {
|
||||
for (excluded_upper_bound, num_blocks_output) in
|
||||
[(3, 1), (3, 32), ((1 << 32) + 1, 64)]
|
||||
{
|
||||
let img = sk.par_generate_oblivious_pseudo_random_unsigned_custom_range(
|
||||
seed,
|
||||
num_input_random_bits,
|
||||
excluded_upper_bound,
|
||||
num_blocks_output as u64,
|
||||
);
|
||||
|
||||
assert_eq!(img.blocks.len(), num_blocks_output);
|
||||
|
||||
let decrypted: u64 = ck.decrypt_radix(&img);
|
||||
|
||||
assert!(decrypted < excluded_upper_bound);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oprf_test_almost_uniformity_ci_run_filter() {
|
||||
let sample_count: usize = 10_000;
|
||||
|
||||
let p_value_limit: f64 = 0.001;
|
||||
|
||||
let num_input_random_bits: usize = 4;
|
||||
|
||||
let num_blocks_output = 64;
|
||||
|
||||
let excluded_upper_bound = 10;
|
||||
|
||||
let (ck, sk) = KEY_CACHE.get_from_params(
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
|
||||
IntegerKeyKind::Radix,
|
||||
);
|
||||
|
||||
let random_input_upper_bound = 1 << num_input_random_bits;
|
||||
|
||||
let mut density = vec![0_usize; excluded_upper_bound];
|
||||
|
||||
for i in 0..random_input_upper_bound {
|
||||
let index =
|
||||
((i * excluded_upper_bound) as f64 / random_input_upper_bound as f64) as usize;
|
||||
|
||||
density[index] += 1;
|
||||
}
|
||||
|
||||
//probability density function
|
||||
let theoretical_pdf: Vec<f64> = density
|
||||
.iter()
|
||||
.map(|count| *count as f64 / random_input_upper_bound as f64)
|
||||
.collect();
|
||||
|
||||
let values: Vec<u64> = (0..sample_count)
|
||||
.into_par_iter()
|
||||
.map(|seed| {
|
||||
let img = sk.par_generate_oblivious_pseudo_random_unsigned_custom_range(
|
||||
Seed(seed as u128),
|
||||
num_input_random_bits as u64,
|
||||
excluded_upper_bound as u64,
|
||||
num_blocks_output as u64,
|
||||
);
|
||||
ck.decrypt_radix(&img)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut bins = vec![0_u64; excluded_upper_bound];
|
||||
|
||||
for value in values {
|
||||
bins[value as usize] += 1;
|
||||
}
|
||||
|
||||
let cumulative_bins = cumulate(&bins);
|
||||
|
||||
let theoretical_cdf = cumulate(&theoretical_pdf);
|
||||
|
||||
let sup_diff = sup_diff(&cumulative_bins, &theoretical_cdf);
|
||||
|
||||
let p_value_upper_bound = dkw_alpha_from_epsilon(sample_count as f64, sup_diff);
|
||||
|
||||
println!("p_value_upper_bound {p_value_upper_bound}");
|
||||
|
||||
assert!(p_value_limit < p_value_upper_bound);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user