refactor(gpu): oprf_unsigned_custom_range + tests

This commit is contained in:
Enzo Di Maria
2025-10-15 16:38:55 +02:00
committed by Agnès Leroy
parent 353237c0d6
commit 126e779533
11 changed files with 230 additions and 160 deletions

View File

@@ -700,10 +700,10 @@ uint64_t scratch_cuda_integer_grouped_oprf_custom_range_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t grouping_factor, uint32_t num_blocks_to_process,
uint32_t grouping_factor, uint32_t num_blocks_intermediate,
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
bool allocate_gpu_memory, uint32_t message_bits_per_block,
uint32_t total_random_bits, uint32_t num_scalar_bits,
uint32_t num_input_random_bits, uint32_t num_scalar_bits,
PBS_MS_REDUCTION_T noise_reduction_type);
void cuda_integer_grouped_oprf_custom_range_64(

View File

@@ -157,32 +157,37 @@ template <typename Torus> struct int_grouped_oprf_custom_range_memory {
int_scalar_mul_buffer<Torus> *scalar_mul_buffer;
int_logical_scalar_shift_buffer<Torus> *logical_scalar_shift_buffer;
CudaRadixCiphertextFFI *tmp_oprf_output;
uint32_t num_random_input_blocks;
int_grouped_oprf_custom_range_memory(
CudaStreams streams, int_radix_params params,
uint32_t num_blocks_to_process, uint32_t message_bits_per_block,
uint64_t total_random_bits, uint32_t num_scalar_bits,
uint32_t num_blocks_intermediate, uint32_t message_bits_per_block,
uint64_t num_input_random_bits, uint32_t num_scalar_bits,
bool allocate_gpu_memory, uint64_t &size_tracker) {
this->params = params;
this->allocate_gpu_memory = allocate_gpu_memory;
this->num_random_input_blocks =
(num_input_random_bits + message_bits_per_block - 1) /
message_bits_per_block;
this->grouped_oprf_memory = new int_grouped_oprf_memory<Torus>(
streams, params, num_blocks_to_process, message_bits_per_block,
total_random_bits, allocate_gpu_memory, size_tracker);
streams, params, this->num_random_input_blocks, message_bits_per_block,
num_input_random_bits, allocate_gpu_memory, size_tracker);
this->scalar_mul_buffer = new int_scalar_mul_buffer<Torus>(
streams, params, num_blocks_to_process, num_scalar_bits,
streams, params, num_blocks_intermediate, num_scalar_bits,
allocate_gpu_memory, true, size_tracker);
this->logical_scalar_shift_buffer =
new int_logical_scalar_shift_buffer<Torus>(
streams, RIGHT_SHIFT, params, num_blocks_to_process,
streams, RIGHT_SHIFT, params, num_blocks_intermediate,
allocate_gpu_memory, size_tracker);
this->tmp_oprf_output = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->tmp_oprf_output,
num_blocks_to_process, params.big_lwe_dimension, size_tracker,
num_blocks_intermediate, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
}
@@ -204,5 +209,7 @@ template <typename Torus> struct int_grouped_oprf_custom_range_memory {
this->allocate_gpu_memory);
delete this->tmp_oprf_output;
this->tmp_oprf_output = nullptr;
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
}
};

View File

@@ -48,10 +48,10 @@ uint64_t scratch_cuda_integer_grouped_oprf_custom_range_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t grouping_factor, uint32_t num_blocks_to_process,
uint32_t grouping_factor, uint32_t num_blocks_intermediate,
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
bool allocate_gpu_memory, uint32_t message_bits_per_block,
uint32_t total_random_bits, uint32_t num_scalar_bits,
uint32_t num_input_random_bits, uint32_t num_scalar_bits,
PBS_MS_REDUCTION_T noise_reduction_type) {
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
@@ -63,7 +63,7 @@ uint64_t scratch_cuda_integer_grouped_oprf_custom_range_64(
return scratch_cuda_integer_grouped_oprf_custom_range<uint64_t>(
CudaStreams(streams),
(int_grouped_oprf_custom_range_memory<uint64_t> **)mem_ptr, params,
num_blocks_to_process, message_bits_per_block, total_random_bits,
num_blocks_intermediate, message_bits_per_block, num_input_random_bits,
num_scalar_bits, allocate_gpu_memory);
}

View File

@@ -93,14 +93,15 @@ void host_integer_grouped_oprf(CudaStreams streams,
template <typename Torus>
uint64_t scratch_cuda_integer_grouped_oprf_custom_range(
CudaStreams streams, int_grouped_oprf_custom_range_memory<Torus> **mem_ptr,
int_radix_params params, uint32_t num_blocks_to_process,
uint32_t message_bits_per_block, uint64_t total_random_bits,
int_radix_params params, uint32_t num_blocks_intermediate,
uint32_t message_bits_per_block, uint64_t num_input_random_bits,
uint32_t num_scalar_bits, bool allocate_gpu_memory) {
uint64_t size_tracker = 0;
*mem_ptr = new int_grouped_oprf_custom_range_memory<Torus>(
streams, params, num_blocks_to_process, message_bits_per_block,
total_random_bits, num_scalar_bits, allocate_gpu_memory, size_tracker);
streams, params, num_blocks_intermediate, message_bits_per_block,
num_input_random_bits, num_scalar_bits, allocate_gpu_memory,
size_tracker);
return size_tracker;
}
@@ -114,40 +115,39 @@ void host_integer_grouped_oprf_custom_range(
int_grouped_oprf_custom_range_memory<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks) {
uint32_t num_blocks_output = radix_lwe_out->num_radix_blocks;
if (num_blocks_output < num_blocks_intermediate) {
num_blocks_intermediate = num_blocks_output;
}
CudaRadixCiphertextFFI intermediate_slice;
as_radix_ciphertext_slice<Torus>(&intermediate_slice, radix_lwe_out, 0,
num_blocks_intermediate);
host_integer_grouped_oprf<Torus>(streams, mem_ptr->tmp_oprf_output,
seeded_lwe_input, num_blocks_intermediate,
mem_ptr->grouped_oprf_memory, bsks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), &intermediate_slice, 0,
num_blocks_intermediate, mem_ptr->tmp_oprf_output, 0,
CudaRadixCiphertextFFI *computation_buffer = mem_ptr->tmp_oprf_output;
set_zero_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), computation_buffer, 0,
num_blocks_intermediate);
host_integer_grouped_oprf<Torus>(
streams, computation_buffer, seeded_lwe_input,
mem_ptr->num_random_input_blocks, mem_ptr->grouped_oprf_memory, bsks);
host_integer_scalar_mul_radix<Torus>(
streams, &intermediate_slice, decomposed_scalar, has_at_least_one_set,
streams, computation_buffer, decomposed_scalar, has_at_least_one_set,
mem_ptr->scalar_mul_buffer, bsks, ksks, mem_ptr->params.message_modulus,
num_scalars);
host_logical_scalar_shift_inplace<Torus>(streams, &intermediate_slice, shift,
host_logical_scalar_shift_inplace<Torus>(streams, computation_buffer, shift,
mem_ptr->logical_scalar_shift_buffer,
bsks, ksks, num_blocks_intermediate);
if (num_blocks_output > num_blocks_intermediate) {
set_zero_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), radix_lwe_out,
num_blocks_intermediate, num_blocks_output);
uint32_t num_blocks_output = radix_lwe_out->num_radix_blocks;
uint32_t blocks_to_copy =
std::min(num_blocks_output, num_blocks_intermediate);
if (blocks_to_copy > 0) {
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), radix_lwe_out, 0,
blocks_to_copy, computation_buffer, 0, blocks_to_copy);
}
if (num_blocks_output > blocks_to_copy) {
set_zero_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), radix_lwe_out, blocks_to_copy,
num_blocks_output);
}
radix_lwe_out->num_radix_blocks = num_blocks_output;
}
#endif

View File

@@ -1558,13 +1558,13 @@ unsafe extern "C" {
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
num_blocks_to_process: u32,
num_blocks_intermediate: u32,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
allocate_gpu_memory: bool,
message_bits_per_block: u32,
total_random_bits: u32,
num_input_random_bits: u32,
num_scalar_bits: u32,
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;

View File

@@ -2828,7 +2828,7 @@ pub(crate) unsafe fn cuda_backend_grouped_oprf_custom_range<T: UnsignedInteger,
carry_modulus: CarryModulus,
pbs_type: PBSType,
message_bits_per_block: u32,
total_random_bits: u32,
_total_random_bits: u32,
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
) {
assert_eq!(
@@ -2871,13 +2871,13 @@ pub(crate) unsafe fn cuda_backend_grouped_oprf_custom_range<T: UnsignedInteger,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
radix_lwe_out.d_blocks.0.d_vec.len() as u32,
num_blocks_intermediate,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
message_bits_per_block,
total_random_bits,
shift,
num_scalars,
noise_reduction_type as u32,
);

View File

@@ -502,7 +502,7 @@ impl CudaServerKey {
streams.synchronize();
let mut result: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(num_blocks_intermediate as usize, streams);
self.create_trivial_zero_radix(num_blocks_output as usize, streams);
unsafe {
match &self.bootstrapping_key {
@@ -563,7 +563,6 @@ impl CudaServerKey {
}
}
streams.synchronize();
result
}

View File

@@ -1064,6 +1064,34 @@ where
}
}
// Adapter that lets a GPU function taking `(Seed, u64, u64, u64)` plus the
// server key and CUDA streams be driven through the generic
// `FunctionExecutor` interface. The GPU result is converted with
// `to_radix_ciphertext` before being returned.
impl<F> FunctionExecutor<(Seed, u64, u64, u64), RadixCiphertext>
    for GpuMultiDeviceFunctionExecutor<F>
where
    F: Fn(&CudaServerKey, Seed, u64, u64, u64, &CudaStreams) -> CudaUnsignedRadixCiphertext,
{
    // Prepares the executor context from the client and server keys.
    fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
        self.setup_from_keys(cks, &sks);
    }

    // Forwards the tuple fields, in order, to the wrapped function.
    // Panics if `setup` has not populated the context beforehand.
    fn execute(&mut self, input: (Seed, u64, u64, u64)) -> RadixCiphertext {
        let ctx = self
            .context
            .as_ref()
            .expect("setup was not properly called");
        let (seed, first, second, third) = input;
        let on_device = (self.func)(&ctx.sks, seed, first, second, third, &ctx.streams);
        on_device.to_radix_ciphertext(&ctx.streams)
    }
}
impl<F> FunctionExecutor<(Seed, u64, u64), RadixCiphertext> for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, Seed, u64, u64, &CudaStreams) -> CudaUnsignedRadixCiphertext,

View File

@@ -1,7 +1,9 @@
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::create_gpu_parameterized_test;
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_unsigned::test_oprf::oprf_uniformity_test;
use crate::integer::server_key::radix_parallel::tests_unsigned::test_oprf::{
oprf_almost_uniformity_test, oprf_any_range_test, oprf_uniformity_test,
};
use crate::shortint::parameters::{
TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
};
@@ -9,6 +11,12 @@ use crate::shortint::parameters::{
create_gpu_parameterized_test!(oprf_uniformity_unsigned {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
// Register the custom-range OPRF range check for the listed GPU parameter set.
// NOTE(review): presumably the macro generates one test per parameter set that
// calls the function of the same name — confirm against the macro definition.
create_gpu_parameterized_test!(oprf_any_range_unsigned {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
// Register the custom-range OPRF distribution check for the same parameter set.
create_gpu_parameterized_test!(oprf_almost_uniformity_unsigned {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
fn oprf_uniformity_unsigned<P>(param: P)
where
@@ -19,3 +27,23 @@ where
);
oprf_uniformity_test(param, executor);
}
/// Runs the shared `oprf_any_range_test` scenario with an executor wrapping
/// `CudaServerKey::par_generate_oblivious_pseudo_random_unsigned_custom_range`.
fn oprf_any_range_unsigned<P>(param: P)
where
    P: Into<TestParameters>,
{
    oprf_any_range_test(
        param,
        GpuMultiDeviceFunctionExecutor::new(
            &CudaServerKey::par_generate_oblivious_pseudo_random_unsigned_custom_range,
        ),
    );
}
/// Runs the shared `oprf_almost_uniformity_test` scenario with an executor
/// wrapping
/// `CudaServerKey::par_generate_oblivious_pseudo_random_unsigned_custom_range`.
fn oprf_almost_uniformity_unsigned<P>(param: P)
where
    P: Into<TestParameters>,
{
    oprf_almost_uniformity_test(
        param,
        GpuMultiDeviceFunctionExecutor::new(
            &CudaServerKey::par_generate_oblivious_pseudo_random_unsigned_custom_range,
        ),
    );
}

View File

@@ -362,114 +362,3 @@ impl ServerKey {
SignedRadixCiphertext::from(blocks)
}
}
// Tests for `par_generate_oblivious_pseudo_random_unsigned_custom_range`:
// one checks the output range and block count, the other compares the
// empirical output distribution against the expected one.
#[cfg(test)]
pub(crate) mod test {
use crate::core_crypto::commons::math::random::tests::{
cumulate, dkw_alpha_from_epsilon, sup_diff,
};
use crate::integer::keycache::KEY_CACHE;
use crate::integer::IntegerKeyKind;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use tfhe_csprng::seeders::Seed;
// For 100 seeds and every combination of input bit count and
// (excluded upper bound, output block count), checks that the result has the
// requested number of blocks and decrypts to a value below the bound.
#[test]
fn oprf_test_any_range_ci_run_filter() {
let (ck, sk) = KEY_CACHE.get_from_params(
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
IntegerKeyKind::Radix,
);
let num_loops = 100;
for seed in 0..num_loops {
let seed = Seed(seed);
for num_input_random_bits in [1, 2, 63, 64] {
for (excluded_upper_bound, num_blocks_output) in
[(3, 1), (3, 32), ((1 << 32) + 1, 64)]
{
let img = sk.par_generate_oblivious_pseudo_random_unsigned_custom_range(
seed,
num_input_random_bits,
excluded_upper_bound,
num_blocks_output as u64,
);
assert_eq!(img.blocks.len(), num_blocks_output);
let decrypted: u64 = ck.decrypt_radix(&img);
assert!(decrypted < excluded_upper_bound);
}
}
}
}
// Draws 10_000 samples (one per seed) and compares their cumulative
// histogram with the expected cumulative distribution.
// NOTE(review): `dkw_alpha_from_epsilon` presumably applies the
// Dvoretzky–Kiefer–Wolfowitz bound — confirm against its definition.
#[test]
fn oprf_test_almost_uniformity_ci_run_filter() {
let sample_count: usize = 10_000;
let p_value_limit: f64 = 0.001;
let num_input_random_bits: usize = 4;
let num_blocks_output = 64;
let excluded_upper_bound = 10;
let (ck, sk) = KEY_CACHE.get_from_params(
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
IntegerKeyKind::Radix,
);
let random_input_upper_bound = 1 << num_input_random_bits;
// Count how many of the 2^num_input_random_bits inputs land on each output
// value when mapped through floor(i * excluded_upper_bound / 2^bits).
let mut density = vec![0_usize; excluded_upper_bound];
for i in 0..random_input_upper_bound {
let index =
((i * excluded_upper_bound) as f64 / random_input_upper_bound as f64) as usize;
density[index] += 1;
}
// Theoretical probability density function: each count normalized by the
// number of possible random inputs.
let theoretical_pdf: Vec<f64> = density
.iter()
.map(|count| *count as f64 / random_input_upper_bound as f64)
.collect();
// Generate and decrypt one sample per seed, in parallel via rayon.
let values: Vec<u64> = (0..sample_count)
.into_par_iter()
.map(|seed| {
let img = sk.par_generate_oblivious_pseudo_random_unsigned_custom_range(
Seed(seed as u128),
num_input_random_bits as u64,
excluded_upper_bound as u64,
num_blocks_output as u64,
);
ck.decrypt_radix(&img)
})
.collect();
// Empirical histogram; indexing panics if a value is >= excluded_upper_bound.
let mut bins = vec![0_u64; excluded_upper_bound];
for value in values {
bins[value as usize] += 1;
}
let cumulative_bins = cumulate(&bins);
let theoretical_cdf = cumulate(&theoretical_pdf);
let sup_diff = sup_diff(&cumulative_bins, &theoretical_cdf);
let p_value_upper_bound = dkw_alpha_from_epsilon(sample_count as f64, sup_diff);
println!("p_value_upper_bound {p_value_upper_bound}");
// The test passes only when the computed bound stays above the limit.
assert!(p_value_limit < p_value_upper_bound);
}
}

View File

@@ -1,3 +1,6 @@
use crate::core_crypto::commons::math::random::tests::{
cumulate, dkw_alpha_from_epsilon, sup_diff,
};
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
@@ -10,7 +13,13 @@ use std::sync::Arc;
use tfhe_csprng::seeders::Seed;
create_parameterized_test!(oprf_uniformity_unsigned {
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128
});
// Register the custom-range OPRF range check for the listed CPU parameter set.
// NOTE(review): presumably the macro generates one test per parameter set that
// calls the function of the same name — confirm against the macro definition.
create_parameterized_test!(oprf_any_range_unsigned {
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128
});
// Register the custom-range OPRF distribution check for the same parameter set.
create_parameterized_test!(oprf_almost_uniformity_unsigned {
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128
});
fn oprf_uniformity_unsigned<P>(param: P)
@@ -23,6 +32,26 @@ where
oprf_uniformity_test(param, executor);
}
/// Runs the shared `oprf_any_range_test` scenario with an executor wrapping
/// `ServerKey::par_generate_oblivious_pseudo_random_unsigned_custom_range`.
fn oprf_any_range_unsigned<P>(param: P)
where
    P: Into<TestParameters>,
{
    oprf_any_range_test(
        param,
        CpuFunctionExecutor::new(
            &ServerKey::par_generate_oblivious_pseudo_random_unsigned_custom_range,
        ),
    );
}
/// Runs the shared `oprf_almost_uniformity_test` scenario with an executor
/// wrapping `ServerKey::par_generate_oblivious_pseudo_random_unsigned_custom_range`.
fn oprf_almost_uniformity_unsigned<P>(param: P)
where
    P: Into<TestParameters>,
{
    oprf_almost_uniformity_test(
        param,
        CpuFunctionExecutor::new(
            &ServerKey::par_generate_oblivious_pseudo_random_unsigned_custom_range,
        ),
    );
}
/// Returns `value` multiplied by itself.
fn square(value: f64) -> f64 {
    value * value
}
@@ -83,3 +112,93 @@ where
cks.decrypt(&img)
});
}
/// Checks that the executor's custom-range generation stays within bounds.
///
/// For 100 seeds, every listed input bit count and every
/// `(excluded_upper_bound, num_blocks_output)` pair, the executor is called
/// with `(seed, num_input_random_bits, excluded_upper_bound,
/// num_blocks_output)`. The result must contain exactly `num_blocks_output`
/// blocks and decrypt to a value strictly below `excluded_upper_bound`.
pub fn oprf_any_range_test<P, E>(param: P, mut executor: E)
where
    P: Into<TestParameters>,
    E: for<'a> FunctionExecutor<(Seed, u64, u64, u64), RadixCiphertext>,
{
    let test_params = param.into();
    let (cks, sks) = KEY_CACHE.get_from_params(test_params, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((cks, 1));
    executor.setup(&cks, Arc::new(sks));

    let seed_count: u128 = 100;
    let input_bit_counts: [u64; 4] = [1, 2, 63, 64];
    // (excluded upper bound, number of output blocks)
    let range_cases: [(u64, usize); 3] = [(3, 1), (3, 32), ((1 << 32) + 1, 64)];

    for raw_seed in 0..seed_count {
        let seed = Seed(raw_seed);
        for num_input_random_bits in input_bit_counts {
            for (excluded_upper_bound, num_blocks_output) in range_cases {
                let args = (
                    seed,
                    num_input_random_bits,
                    excluded_upper_bound,
                    num_blocks_output as u64,
                );
                let img = executor.execute(args);
                assert_eq!(img.blocks.len(), num_blocks_output);

                let decrypted: u64 = cks.decrypt(&img);
                assert!(decrypted < excluded_upper_bound);
            }
        }
    }
}
/// Compares the empirical distribution of the executor's custom-range output
/// with the distribution expected from mapping 4 random input bits onto the
/// range `0..10`.
///
/// 10_000 samples (one per seed) are generated and decrypted sequentially,
/// binned, and their cumulative histogram is compared with the theoretical
/// cumulative distribution. The test passes when the value returned by
/// `dkw_alpha_from_epsilon` stays above `0.001`.
/// NOTE(review): `dkw_alpha_from_epsilon` presumably applies the
/// Dvoretzky–Kiefer–Wolfowitz bound — confirm against its definition.
pub fn oprf_almost_uniformity_test<P, E>(param: P, mut executor: E)
where
    P: Into<TestParameters>,
    E: for<'a> FunctionExecutor<(Seed, u64, u64, u64), RadixCiphertext>,
{
    let test_params = param.into();
    let (cks, sks) = KEY_CACHE.get_from_params(test_params, IntegerKeyKind::Radix);
    let cks = RadixClientKey::from((cks, 1));
    executor.setup(&cks, Arc::new(sks));

    let sample_count: usize = 10_000;
    let p_value_limit: f64 = 0.001;
    let num_input_random_bits: u64 = 4;
    let num_blocks_output = 64;
    let excluded_upper_bound = 10;
    let random_input_upper_bound = 1 << num_input_random_bits;

    // Count how many of the possible random inputs land on each output value
    // when mapped through floor(i * excluded_upper_bound / 2^bits).
    let mut density = vec![0_usize; excluded_upper_bound as usize];
    for i in 0..random_input_upper_bound {
        let scaled = (i * excluded_upper_bound) as f64;
        let index = (scaled / random_input_upper_bound as f64) as usize;
        density[index] += 1;
    }

    // Normalize the counts into the theoretical probability density.
    let theoretical_pdf: Vec<f64> = density
        .iter()
        .map(|&count| count as f64 / random_input_upper_bound as f64)
        .collect();

    // Generate and decrypt every sample before binning any of them.
    let values: Vec<u64> = (0..sample_count)
        .map(|seed| {
            let args = (
                Seed(seed as u128),
                num_input_random_bits,
                excluded_upper_bound as u64,
                num_blocks_output,
            );
            let img = executor.execute(args);
            cks.decrypt(&img)
        })
        .collect();

    // Empirical histogram; indexing panics if a value is out of range.
    let mut bins = vec![0_u64; excluded_upper_bound as usize];
    values
        .into_iter()
        .for_each(|value| bins[value as usize] += 1);

    let cumulative_bins = cumulate(&bins);
    let theoretical_cdf = cumulate(&theoretical_pdf);
    let max_cdf_gap = sup_diff(&cumulative_bins, &theoretical_cdf);
    let p_value_upper_bound = dkw_alpha_from_epsilon(sample_count as f64, max_cdf_gap);

    assert!(p_value_limit < p_value_upper_bound);
}