From 22ddba71450f3e32c5892c6f6e40ae0e98701eca Mon Sep 17 00:00:00 2001 From: Pedro Alves Date: Wed, 2 Jul 2025 09:34:01 -0300 Subject: [PATCH] fix(gpu): refactor the (128-bit and regular) classical PBS entry point to remove the num_samples parameter - fixes the throughput for those PBSs - also fixes the throughput benchmark for regular PBSs --- tfhe-benchmark/benches/core_crypto/ks_pbs_bench.rs | 2 -- tfhe-benchmark/benches/core_crypto/pbs128_bench.rs | 11 ++++++----- tfhe-benchmark/benches/core_crypto/pbs_bench.rs | 4 +--- tfhe-benchmark/src/utilities.rs | 4 +++- .../gpu/algorithms/lwe_programmable_bootstrapping.rs | 11 +++-------- .../algorithms/test/lwe_programmable_bootstrapping.rs | 1 - .../test/lwe_programmable_bootstrapping_128.rs | 2 -- .../lwe_programmable_bootstrapping_noise.rs | 1 - tfhe/src/integer/gpu/server_key/radix/oprf.rs | 3 +-- 9 files changed, 14 insertions(+), 25 deletions(-) diff --git a/tfhe-benchmark/benches/core_crypto/ks_pbs_bench.rs b/tfhe-benchmark/benches/core_crypto/ks_pbs_bench.rs index 9e959db89..78a3b88cd 100644 --- a/tfhe-benchmark/benches/core_crypto/ks_pbs_bench.rs +++ b/tfhe-benchmark/benches/core_crypto/ks_pbs_bench.rs @@ -640,7 +640,6 @@ mod cuda { &cuda_indexes.d_lut, &cuda_indexes.d_output, &cuda_indexes.d_input, - LweCiphertextCount(1), gpu_keys.bsk.as_ref().unwrap(), &streams, ); @@ -793,7 +792,6 @@ mod cuda { &cuda_indexes_vec[i].d_lut, &cuda_indexes_vec[i].d_output, &cuda_indexes_vec[i].d_input, - LweCiphertextCount(1), gpu_keys_vec[i].bsk.as_ref().unwrap(), local_stream, ); diff --git a/tfhe-benchmark/benches/core_crypto/pbs128_bench.rs b/tfhe-benchmark/benches/core_crypto/pbs128_bench.rs index 2f65b20b6..8cb90e6c3 100644 --- a/tfhe-benchmark/benches/core_crypto/pbs128_bench.rs +++ b/tfhe-benchmark/benches/core_crypto/pbs128_bench.rs @@ -302,7 +302,6 @@ mod cuda { &lwe_ciphertext_in_gpu, &mut out_pbs_ct_gpu, &accumulator_gpu, - LweCiphertextCount(1), gpu_keys.bsk.as_ref().unwrap(), &streams, ); @@ -398,12 +397,14 @@ mod cuda { .zip(accumulators.par_iter()) .zip(local_streams.par_iter()) .for_each( - |((((i, input_ct), output_ct), accumulator), local_stream)| { + |( + (((i, input_batch), output_batch), accumulator), + local_stream, + )| { cuda_programmable_bootstrap_128_lwe_ciphertext( - input_ct, - output_ct, + input_batch, + output_batch, accumulator, - LweCiphertextCount(1), gpu_keys_vec[i].bsk.as_ref().unwrap(), local_stream, ); diff --git a/tfhe-benchmark/benches/core_crypto/pbs_bench.rs b/tfhe-benchmark/benches/core_crypto/pbs_bench.rs index 2d6bf17f9..fb1ee9285 100644 --- a/tfhe-benchmark/benches/core_crypto/pbs_bench.rs +++ b/tfhe-benchmark/benches/core_crypto/pbs_bench.rs @@ -1031,7 +1031,6 @@ mod cuda { &cuda_indexes.d_lut, &cuda_indexes.d_output, &cuda_indexes.d_input, - LweCiphertextCount(1), gpu_keys.bsk.as_ref().unwrap(), &streams, ); @@ -1113,7 +1112,7 @@ mod cuda { }) .collect::>(); - let h_indexes = (0..(elements / gpu_count as u64)) + let h_indexes = (0..elements_per_stream as u64) .map(CastFrom::cast_from) .collect::>(); let cuda_indexes_vec = (0..gpu_count) @@ -1157,7 +1156,6 @@ mod cuda { &cuda_indexes_vec[i].d_lut, &cuda_indexes_vec[i].d_output, &cuda_indexes_vec[i].d_input, - LweCiphertextCount(1), gpu_keys_vec[i].bsk.as_ref().unwrap(), local_stream, ); diff --git a/tfhe-benchmark/src/utilities.rs b/tfhe-benchmark/src/utilities.rs index 462a90b9f..4330e07ad 100644 --- a/tfhe-benchmark/src/utilities.rs +++ b/tfhe-benchmark/src/utilities.rs @@ -579,10 +579,12 @@ mod cuda_utils { let mut d_input = unsafe { CudaVec::::new_async(length, stream, stream_index) }; let mut d_output = unsafe { CudaVec::::new_async(length, stream, stream_index) }; let mut d_lut = unsafe { CudaVec::::new_async(length, stream, stream_index) }; + let zeros = vec![T::ZERO; length]; + unsafe { d_input.copy_from_cpu_async(indexes.as_ref(), stream, stream_index); d_output.copy_from_cpu_async(indexes.as_ref(), stream, stream_index); - d_lut.copy_from_cpu_async(indexes.as_ref(), stream, stream_index); + d_lut.copy_from_cpu_async(zeros.as_ref(), stream, stream_index); } stream.synchronize(); diff --git a/tfhe/src/core_crypto/gpu/algorithms/lwe_programmable_bootstrapping.rs b/tfhe/src/core_crypto/gpu/algorithms/lwe_programmable_bootstrapping.rs index 6e39aa0c7..318573330 100644 --- a/tfhe/src/core_crypto/gpu/algorithms/lwe_programmable_bootstrapping.rs +++ b/tfhe/src/core_crypto/gpu/algorithms/lwe_programmable_bootstrapping.rs @@ -5,7 +5,7 @@ use crate::core_crypto::gpu::vec::CudaVec; use crate::core_crypto::gpu::{ programmable_bootstrap_128_async, programmable_bootstrap_async, CudaStreams, }; -use crate::core_crypto::prelude::{CastInto, LweCiphertextCount, UnsignedTorus}; +use crate::core_crypto::prelude::{CastInto, UnsignedTorus}; /// # Safety /// @@ -19,7 +19,6 @@ pub unsafe fn cuda_programmable_bootstrap_lwe_ciphertext_async( lut_indexes: &CudaVec, output_indexes: &CudaVec, input_indexes: &CudaVec, - num_samples: LweCiphertextCount, bsk: &CudaLweBootstrapKey, streams: &CudaStreams, ) where @@ -129,7 +128,7 @@ pub unsafe fn cuda_programmable_bootstrap_lwe_ciphertext_async( let lwe_dimension = input.lwe_dimension(); let ct_modulus = input.ciphertext_modulus().raw_modulus_float(); - + let num_samples = input.lwe_ciphertext_count(); programmable_bootstrap_async( streams, &mut output.0.d_vec, @@ -159,7 +158,6 @@ pub unsafe fn cuda_programmable_bootstrap_128_lwe_ciphertext_async( input: &CudaLweCiphertextList, output: &mut CudaLweCiphertextList, accumulator: &CudaGlweCiphertextList, - num_samples: LweCiphertextCount, bsk: &CudaLweBootstrapKey, streams: &CudaStreams, ) where @@ -255,6 +253,7 @@ pub unsafe fn cuda_programmable_bootstrap_128_lwe_ciphertext_async( ); let lwe_dimension = input.lwe_dimension(); let ct_modulus = input.ciphertext_modulus().raw_modulus_float(); + let num_samples = input.lwe_ciphertext_count(); programmable_bootstrap_128_async( streams, &mut output.0.d_vec, @@ -280,7 +279,6 @@ pub fn cuda_programmable_bootstrap_lwe_ciphertext( lut_indexes: &CudaVec, output_indexes: &CudaVec, input_indexes: &CudaVec, - num_samples: LweCiphertextCount, bsk: &CudaLweBootstrapKey, streams: &CudaStreams, ) where @@ -294,7 +292,6 @@ pub fn cuda_programmable_bootstrap_lwe_ciphertext( lut_indexes, output_indexes, input_indexes, - num_samples, bsk, streams, ); @@ -318,7 +315,6 @@ pub fn cuda_programmable_bootstrap_128_lwe_ciphertext( input: &CudaLweCiphertextList, output: &mut CudaLweCiphertextList, accumulator: &CudaGlweCiphertextList, - num_samples: LweCiphertextCount, bsk: &CudaLweBootstrapKey, streams: &CudaStreams, ) where @@ -329,7 +325,6 @@ pub fn cuda_programmable_bootstrap_128_lwe_ciphertext( input, output, accumulator, - num_samples, bsk, streams, ); diff --git a/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping.rs b/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping.rs index bbae07142..c02032334 100644 --- a/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping.rs +++ b/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping.rs @@ -153,7 +153,6 @@ fn lwe_encrypt_pbs_decrypt< &d_test_vector_indexes, &d_output_indexes, &d_input_indexes, - LweCiphertextCount(num_blocks), &d_bsk, &stream, ); diff --git a/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping_128.rs b/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping_128.rs index fa258970c..0f7a7ecbd 100644 --- a/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping_128.rs +++ b/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping_128.rs @@ -180,12 +180,10 @@ pub fn execute_bootstrap_u128( let d_accumulator = CudaGlweCiphertextList::from_glwe_ciphertext(&accumulator, &stream); - let num_blocks = d_lwe_ciphertext_in.0.lwe_ciphertext_count.0; cuda_programmable_bootstrap_128_lwe_ciphertext( &d_lwe_ciphertext_in, &mut d_out_pbs_ct, &d_accumulator, - LweCiphertextCount(num_blocks), &d_bsk, &stream, ); diff --git a/tfhe/src/core_crypto/gpu/algorithms/test/noise_distribution/lwe_programmable_bootstrapping_noise.rs b/tfhe/src/core_crypto/gpu/algorithms/test/noise_distribution/lwe_programmable_bootstrapping_noise.rs index e7de4ada4..2b4505fe9 100644 --- a/tfhe/src/core_crypto/gpu/algorithms/test/noise_distribution/lwe_programmable_bootstrapping_noise.rs +++ b/tfhe/src/core_crypto/gpu/algorithms/test/noise_distribution/lwe_programmable_bootstrapping_noise.rs @@ -181,7 +181,6 @@ where &d_test_vector_indexes, &d_output_indexes, &d_input_indexes, - LweCiphertextCount(num_blocks), &d_bsk, &stream, ); diff --git a/tfhe/src/integer/gpu/server_key/radix/oprf.rs b/tfhe/src/integer/gpu/server_key/radix/oprf.rs index 5471771f1..b254c8569 100644 --- a/tfhe/src/integer/gpu/server_key/radix/oprf.rs +++ b/tfhe/src/integer/gpu/server_key/radix/oprf.rs @@ -25,7 +25,7 @@ use crate::core_crypto::commons::numeric::Numeric; use crate::core_crypto::gpu::add_lwe_ciphertext_vector_plaintext_scalar_async; use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList; use crate::core_crypto::prelude::CastInto; -use crate::integer::gpu::server_key::radix::{CudaLweCiphertextList, LweCiphertextCount}; +use crate::integer::gpu::server_key::radix::CudaLweCiphertextList; use crate::integer::gpu::CudaVec; use itertools::Itertools; @@ -574,7 +574,6 @@ impl CudaServerKey { &d_lut_vector_indexes, &d_output_indexes, &d_input_indexes, - LweCiphertextCount(num_ct_blocks), d_bsk, streams, );