fix(gpu): Fix expand bench on multi-gpus

Pedro Alves
2025-07-07 12:46:08 -03:00
committed by Agnès Leroy
parent 776f08b534
commit 9960f5e8b6
4 changed files with 172 additions and 146 deletions


@@ -421,8 +421,6 @@ mod cuda {
             .sample_size(15)
             .measurement_time(std::time::Duration::from_secs(60));
-        let streams = CudaStreams::new_multi_gpu();
         File::create(results_file).expect("create results file failed");
         let mut file = OpenOptions::new()
             .append(true)
@@ -439,17 +437,10 @@ mod cuda {
         let cks = ClientKey::new(param_fhe);
         let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
         let sk = compressed_server_key.decompress();
-        let gpu_sks = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
         let compact_private_key = CompactPrivateKey::new(param_pke);
         let pk = CompactPublicKey::new(&compact_private_key);
         let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), param_ksk);
-        let d_ksk_material =
-            CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
-        let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
-            &d_ksk_material,
-            &gpu_sks,
-        );
 
         // We have a use case with 320 bits of metadata
         let mut metadata = [0u8; (320 / u8::BITS) as usize];
@@ -509,6 +500,18 @@ mod cuda {
         match get_bench_type() {
             BenchmarkType::Latency => {
+                let streams = CudaStreams::new_multi_gpu();
+                let gpu_sks = CudaServerKey::decompress_from_cpu(
+                    &compressed_server_key,
+                    &streams,
+                );
+                let d_ksk_material =
+                    CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
+                let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
+                    &d_ksk_material,
+                    &gpu_sks,
+                );
                 bench_id_verify = format!(
                     "{bench_name}::{param_name}_{bits}_bits_packed_{zk_load}_ZK{zk_vers:?}"
                 );
@@ -599,9 +602,7 @@ mod cuda {
                 });
             }
             BenchmarkType::Throughput => {
-                let gpu_count = get_number_of_gpus() as usize;
-                let elements = zk_throughput_num_elements();
+                let elements = 100 * get_number_of_gpus() as u64; // This value, found empirically, ensures saturation of 8x H100 SXM5
                 bench_group.throughput(Throughput::Elements(elements));
                 bench_id_verify = format!(
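The Throughput arm now derives the element count from the number of visible GPUs instead of a fixed helper. A minimal sketch of that Criterion pattern, for orientation only; `gpu_count` and the `NUM_GPUS` variable are illustrative stand-ins for the crate's `get_number_of_gpus`:

use criterion::{Criterion, Throughput};

// Illustrative stand-in for get_number_of_gpus(); not part of the crate.
fn gpu_count() -> u64 {
    std::env::var("NUM_GPUS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(1)
}

fn configure_throughput(c: &mut Criterion) {
    // One saturating batch of 100 elements per device, as in the bench above.
    let elements = 100 * gpu_count();
    let mut group = c.benchmark_group("zk::expand");
    // Criterion divides measured time by `elements` to report per-element throughput.
    group.throughput(Throughput::Elements(elements));
    group.finish();
}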
@@ -636,8 +637,6 @@ mod cuda {
                     })
                     .collect::<Vec<_>>();
-                assert_eq!(d_ksk_material_vec.len(), gpu_count);
                 bench_group.bench_function(&bench_id_verify, |b| {
                     b.iter(|| {
                         cts.par_iter().for_each(|ct1| {
@@ -648,23 +647,25 @@ mod cuda {
                 bench_group.bench_function(&bench_id_expand_without_verify, |b| {
                     let setup_encrypted_values = || {
-                        let local_streams = cuda_local_streams(num_block, elements as usize);
                         let gpu_cts = cts.iter().enumerate().map(|(i, ct)| {
+                            let local_stream = &local_streams[i % local_streams.len()];
                             CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
-                                ct, &local_streams[i],
+                                ct, local_stream,
                             )
                         }).collect_vec();
-                        (gpu_cts, local_streams)
+                        gpu_cts
                     };
                     b.iter_batched(setup_encrypted_values,
-                        |(gpu_cts, local_streams)| {
-                            gpu_cts.par_iter().zip(local_streams.par_iter()).enumerate().for_each
-                            (|(i, (gpu_ct, local_stream))| {
+                        |gpu_cts| {
+                            gpu_cts.par_iter().enumerate().for_each
+                            (|(i, gpu_ct)| {
+                                let local_stream = &local_streams[i % local_streams.len()];
+                                let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, local_stream);
                                 let d_ksk =
-                                    CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % gpu_count], &gpu_sks);
+                                    CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % local_streams.len()], &gpu_sk);
                                 gpu_ct
                                     .expand_without_verification(&d_ksk, local_stream)
@@ -675,21 +676,24 @@ mod cuda {
                 bench_group.bench_function(&bench_id_verify_and_expand, |b| {
                     let setup_encrypted_values = || {
-                        let local_streams = cuda_local_streams(num_block, elements as usize);
                         let gpu_cts = cts.iter().enumerate().map(|(i, ct)| {
                             CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
-                                ct, &local_streams[i],
+                                ct, &local_streams[i % local_streams.len()],
                             )
                         }).collect_vec();
-                        (gpu_cts, local_streams)
+                        gpu_cts
                     };
                     b.iter_batched(setup_encrypted_values,
-                        |(gpu_cts, local_streams)| {
-                            gpu_cts.par_iter().zip(local_streams.par_iter()).for_each
-                            (|(gpu_ct, local_stream)| {
+                        |gpu_cts| {
+                            gpu_cts.par_iter().enumerate().for_each
+                            (|(i, gpu_ct)| {
+                                let local_stream = &local_streams[i % local_streams.len()];
+                                let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, local_stream);
+                                let d_ksk =
+                                    CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % local_streams.len()], &gpu_sk);
                                 gpu_ct
                                     .verify_and_expand(
                                         &crs, &pk, &metadata, &d_ksk, local_stream,
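Both throughput closures above follow the same pattern: work items are spread round-robin over a fixed pool of per-GPU streams via `i % local_streams.len()`, and the server key and key-switching key are rebuilt against each item's own stream, so every GPU operates on keys resident in its own memory. A self-contained sketch of the round-robin distribution with rayon, using plain stand-in types (`Stream`, `Ct`, and `process` are illustrative, not tfhe-rs APIs):

use rayon::prelude::*;

struct Stream { gpu_index: usize }
struct Ct(u64);

// Stand-in for the per-ciphertext GPU work (key rebuild + expand).
fn process(ct: &Ct, stream: &Stream) {
    let _ = (ct.0, stream.gpu_index);
}

fn run(cts: &[Ct], local_streams: &[Stream]) {
    cts.par_iter().enumerate().for_each(|(i, ct)| {
        // i % len assigns item i to a stream (and thus a GPU) round-robin,
        // however many more items than streams there are.
        let local_stream = &local_streams[i % local_streams.len()];
        process(ct, local_stream);
    });
}

fn main() {
    let cts: Vec<Ct> = (0..8).map(Ct).collect();
    let streams = vec![Stream { gpu_index: 0 }, Stream { gpu_index: 1 }];
    run(&cts, &streams);
}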


@@ -8,12 +8,10 @@ use crate::core_crypto::prelude::{
 use crate::integer::ciphertext::{CompactCiphertextListExpander, DataKind};
 use crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaExpandable;
 use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo};
-use crate::integer::gpu::ciphertext::{
-    expand_async, CudaRadixCiphertext, CudaVec, KsType, LweDimension,
-};
+use crate::integer::gpu::ciphertext::{CudaRadixCiphertext, CudaVec, KsType, LweDimension};
 use crate::integer::gpu::key_switching_key::CudaKeySwitchingKey;
 use crate::integer::gpu::server_key::CudaBootstrappingKey;
-use crate::integer::gpu::PBSType;
+use crate::integer::gpu::{expand_async, PBSType};
 use crate::shortint::ciphertext::CompactCiphertextList;
 use crate::shortint::parameters::{
     CompactCiphertextListExpansionKind, Degree, LweBskGroupingFactor, NoiseLevel,


@@ -4,27 +4,15 @@ pub mod compressed_ciphertext_list;
 pub mod info;
 pub mod squashed_noise;
-use crate::core_crypto::gpu::lwe_bootstrap_key::{
-    prepare_cuda_ms_noise_reduction_key_ffi, CudaModulusSwitchNoiseReductionKey,
-};
 use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
 use crate::core_crypto::gpu::vec::CudaVec;
 use crate::core_crypto::gpu::CudaStreams;
-use crate::core_crypto::prelude::{
-    LweBskGroupingFactor, LweCiphertextList, LweCiphertextOwned, Numeric, UnsignedInteger,
-};
+use crate::core_crypto::prelude::{LweCiphertextList, LweCiphertextOwned};
 use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo};
-use crate::integer::gpu::PBSType;
-use crate::integer::parameters::{
-    DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize,
-};
+use crate::integer::parameters::LweDimension;
 use crate::integer::{IntegerCiphertext, RadixCiphertext, SignedRadixCiphertext};
-use crate::shortint::{CarryModulus, Ciphertext, EncryptionKeyChoice, MessageModulus};
+use crate::shortint::{Ciphertext, EncryptionKeyChoice};
 use crate::GpuIndex;
-use tfhe_cuda_backend::bindings::{
-    cleanup_expand_without_verification_64, cuda_expand_without_verification_64,
-    scratch_cuda_expand_without_verification_64,
-};
 
 pub trait CudaIntegerRadixCiphertext: Sized {
     const IS_SIGNED: bool;
@@ -527,100 +515,3 @@ impl From<EncryptionKeyChoice> for KsType {
         }
     }
 }
-
-#[allow(clippy::too_many_arguments)]
-/// # Safety
-///
-/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
-///   not be dropped until the streams are synchronized
-///
-/// In this method, the input `lwe_flattened_compact_array_in` represents a flattened compact
-/// list. Instead of receiving a `Vec<CompactCiphertextList>`, it takes a concatenation of all
-/// LWEs that were inside that vector of compact lists. Handling the input this way removes the
-/// need to process multiple compact lists separately, simplifying GPU-based operations. The
-/// variable name `lwe_flattened_compact_array_in` makes this intent explicit.
-pub unsafe fn expand_async<T: UnsignedInteger, B: Numeric>(
-    streams: &CudaStreams,
-    lwe_array_out: &mut CudaLweCiphertextList<T>,
-    lwe_flattened_compact_array_in: &CudaVec<T>,
-    bootstrapping_key: &CudaVec<B>,
-    computing_ks_key: &CudaVec<T>,
-    casting_key: &CudaVec<T>,
-    message_modulus: MessageModulus,
-    carry_modulus: CarryModulus,
-    computing_glwe_dimension: GlweDimension,
-    computing_polynomial_size: PolynomialSize,
-    computing_lwe_dimension: LweDimension,
-    computing_ks_level: DecompositionLevelCount,
-    computing_ks_base_log: DecompositionBaseLog,
-    casting_input_lwe_dimension: LweDimension,
-    casting_output_lwe_dimension: LweDimension,
-    casting_ks_level: DecompositionLevelCount,
-    casting_ks_base_log: DecompositionBaseLog,
-    pbs_level: DecompositionLevelCount,
-    pbs_base_log: DecompositionBaseLog,
-    pbs_type: PBSType,
-    casting_key_type: KsType,
-    grouping_factor: LweBskGroupingFactor,
-    num_lwes_per_compact_list: &[u32],
-    is_boolean: &[bool],
-    noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
-) {
-    let ct_modulus = lwe_array_out.ciphertext_modulus().raw_modulus_float();
-    let mut mem_ptr: *mut i8 = std::ptr::null_mut();
-    let num_compact_lists = num_lwes_per_compact_list.len();
-    let ms_noise_reduction_key_ffi =
-        prepare_cuda_ms_noise_reduction_key_ffi(noise_reduction_key, ct_modulus);
-    let allocate_ms_noise_array = noise_reduction_key.is_some();
-    scratch_cuda_expand_without_verification_64(
-        streams.ptr.as_ptr(),
-        streams.gpu_indexes_ptr(),
-        streams.len() as u32,
-        std::ptr::addr_of_mut!(mem_ptr),
-        computing_glwe_dimension.0 as u32,
-        computing_polynomial_size.0 as u32,
-        computing_glwe_dimension
-            .to_equivalent_lwe_dimension(computing_polynomial_size)
-            .0 as u32,
-        computing_lwe_dimension.0 as u32,
-        computing_ks_level.0 as u32,
-        computing_ks_base_log.0 as u32,
-        casting_input_lwe_dimension.0 as u32,
-        casting_output_lwe_dimension.0 as u32,
-        casting_ks_level.0 as u32,
-        casting_ks_base_log.0 as u32,
-        pbs_level.0 as u32,
-        pbs_base_log.0 as u32,
-        grouping_factor.0 as u32,
-        num_lwes_per_compact_list.as_ptr(),
-        is_boolean.as_ptr(),
-        num_compact_lists as u32,
-        message_modulus.0 as u32,
-        carry_modulus.0 as u32,
-        pbs_type as u32,
-        casting_key_type as u32,
-        true,
-        allocate_ms_noise_array,
-    );
-    cuda_expand_without_verification_64(
-        streams.ptr.as_ptr(),
-        streams.gpu_indexes_ptr(),
-        streams.len() as u32,
-        lwe_array_out.0.d_vec.as_mut_c_ptr(0),
-        lwe_flattened_compact_array_in.as_c_ptr(0),
-        mem_ptr,
-        bootstrapping_key.ptr.as_ptr(),
-        computing_ks_key.ptr.as_ptr(),
-        casting_key.ptr.as_ptr(),
-        &raw const ms_noise_reduction_key_ffi,
-    );
-    cleanup_expand_without_verification_64(
-        streams.ptr.as_ptr(),
-        streams.gpu_indexes_ptr(),
-        streams.len() as u32,
-        std::ptr::addr_of_mut!(mem_ptr),
-    );
-}


@@ -10,6 +10,7 @@ pub mod zk;
 use crate::core_crypto::gpu::lwe_bootstrap_key::{
     prepare_cuda_ms_noise_reduction_key_ffi, CudaModulusSwitchNoiseReductionKey,
 };
+use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
 use crate::core_crypto::gpu::slice::{CudaSlice, CudaSliceMut};
 use crate::core_crypto::gpu::vec::CudaVec;
 use crate::core_crypto::gpu::CudaStreams;
@@ -19,7 +20,7 @@ use crate::core_crypto::prelude::{
 };
 use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
 use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
-use crate::integer::gpu::ciphertext::CudaRadixCiphertext;
+use crate::integer::gpu::ciphertext::{CudaRadixCiphertext, KsType};
 use crate::integer::server_key::radix_parallel::OutputFlag;
 use crate::integer::server_key::ScalarMultiplier;
 use crate::integer::{ClientKey, RadixClientKey};
@@ -6643,3 +6644,135 @@ pub unsafe fn noise_squashing_async<T: UnsignedInteger, B: Numeric>(
         std::ptr::addr_of_mut!(mem_ptr),
     );
 }
+
+#[allow(clippy::too_many_arguments)]
+/// # Safety
+///
+/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
+///   not be dropped until the streams are synchronized
+///
+/// In this method, the input `lwe_flattened_compact_array_in` represents a flattened compact
+/// list. Instead of receiving a `Vec<CompactCiphertextList>`, it takes a concatenation of all
+/// LWEs that were inside that vector of compact lists. Handling the input this way removes the
+/// need to process multiple compact lists separately, simplifying GPU-based operations. The
+/// variable name `lwe_flattened_compact_array_in` makes this intent explicit.
+pub unsafe fn expand_async<T: UnsignedInteger, B: Numeric>(
+    streams: &CudaStreams,
+    lwe_array_out: &mut CudaLweCiphertextList<T>,
+    lwe_flattened_compact_array_in: &CudaVec<T>,
+    bootstrapping_key: &CudaVec<B>,
+    computing_ks_key: &CudaVec<T>,
+    casting_key: &CudaVec<T>,
+    message_modulus: MessageModulus,
+    carry_modulus: CarryModulus,
+    computing_glwe_dimension: GlweDimension,
+    computing_polynomial_size: PolynomialSize,
+    computing_lwe_dimension: LweDimension,
+    computing_ks_level: DecompositionLevelCount,
+    computing_ks_base_log: DecompositionBaseLog,
+    casting_input_lwe_dimension: LweDimension,
+    casting_output_lwe_dimension: LweDimension,
+    casting_ks_level: DecompositionLevelCount,
+    casting_ks_base_log: DecompositionBaseLog,
+    pbs_level: DecompositionLevelCount,
+    pbs_base_log: DecompositionBaseLog,
+    pbs_type: PBSType,
+    casting_key_type: KsType,
+    grouping_factor: LweBskGroupingFactor,
+    num_lwes_per_compact_list: &[u32],
+    is_boolean: &[bool],
+    noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
+) {
+    assert_eq!(
+        streams.gpu_indexes[0],
+        lwe_array_out.0.d_vec.gpu_index(0),
+        "GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
+        streams.gpu_indexes[0].get(),
+        lwe_array_out.0.d_vec.gpu_index(0).get(),
+    );
+    assert_eq!(
+        streams.gpu_indexes[0],
+        lwe_flattened_compact_array_in.gpu_index(0),
+        "GPU error: first stream is on GPU {}, first input pointer is on GPU {}",
+        streams.gpu_indexes[0].get(),
+        lwe_flattened_compact_array_in.gpu_index(0).get(),
+    );
+    assert_eq!(
+        streams.gpu_indexes[0],
+        bootstrapping_key.gpu_indexes[0],
+        "GPU error: first stream is on GPU {}, first input pointer is on GPU {}",
+        streams.gpu_indexes[0].get(),
+        bootstrapping_key.gpu_indexes[0].get(),
+    );
+    assert_eq!(
+        streams.gpu_indexes[0],
+        computing_ks_key.gpu_indexes[0],
+        "GPU error: first stream is on GPU {}, first input pointer is on GPU {}",
+        streams.gpu_indexes[0].get(),
+        computing_ks_key.gpu_indexes[0].get(),
+    );
+    assert_eq!(
+        streams.gpu_indexes[0],
+        casting_key.gpu_indexes[0],
+        "GPU error: first stream is on GPU {}, first input pointer is on GPU {}",
+        streams.gpu_indexes[0].get(),
+        casting_key.gpu_indexes[0].get(),
+    );
+    let ct_modulus = lwe_array_out.ciphertext_modulus().raw_modulus_float();
+    let mut mem_ptr: *mut i8 = std::ptr::null_mut();
+    let num_compact_lists = num_lwes_per_compact_list.len();
+    let ms_noise_reduction_key_ffi =
+        prepare_cuda_ms_noise_reduction_key_ffi(noise_reduction_key, ct_modulus);
+    let allocate_ms_noise_array = noise_reduction_key.is_some();
+    scratch_cuda_expand_without_verification_64(
+        streams.ptr.as_ptr(),
+        streams.gpu_indexes_ptr(),
+        streams.len() as u32,
+        std::ptr::addr_of_mut!(mem_ptr),
+        computing_glwe_dimension.0 as u32,
+        computing_polynomial_size.0 as u32,
+        computing_glwe_dimension
+            .to_equivalent_lwe_dimension(computing_polynomial_size)
+            .0 as u32,
+        computing_lwe_dimension.0 as u32,
+        computing_ks_level.0 as u32,
+        computing_ks_base_log.0 as u32,
+        casting_input_lwe_dimension.0 as u32,
+        casting_output_lwe_dimension.0 as u32,
+        casting_ks_level.0 as u32,
+        casting_ks_base_log.0 as u32,
+        pbs_level.0 as u32,
+        pbs_base_log.0 as u32,
+        grouping_factor.0 as u32,
+        num_lwes_per_compact_list.as_ptr(),
+        is_boolean.as_ptr(),
+        num_compact_lists as u32,
+        message_modulus.0 as u32,
+        carry_modulus.0 as u32,
+        pbs_type as u32,
+        casting_key_type as u32,
+        true,
+        allocate_ms_noise_array,
+    );
+    cuda_expand_without_verification_64(
+        streams.ptr.as_ptr(),
+        streams.gpu_indexes_ptr(),
+        streams.len() as u32,
+        lwe_array_out.0.d_vec.as_mut_c_ptr(0),
+        lwe_flattened_compact_array_in.as_c_ptr(0),
+        mem_ptr,
+        bootstrapping_key.ptr.as_ptr(),
+        computing_ks_key.ptr.as_ptr(),
+        casting_key.ptr.as_ptr(),
+        &raw const ms_noise_reduction_key_ffi,
+    );
+    cleanup_expand_without_verification_64(
+        streams.ptr.as_ptr(),
+        streams.gpu_indexes_ptr(),
+        streams.len() as u32,
+        std::ptr::addr_of_mut!(mem_ptr),
+    );
+}
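The doc comment on `expand_async` describes the flattened input convention: the GPU entry point receives one concatenated buffer of LWEs plus a per-list count array, rather than a `Vec` of compact lists. A host-side sketch of that flattening step under simplified assumptions (`CompactList` and `flatten` are illustrative, not crate APIs):

// Each compact list carries some number of LWEs; payload words are simplified to u64.
struct CompactList {
    lwes: Vec<u64>,
}

// Concatenate all LWEs and record how many came from each list, mirroring the
// `lwe_flattened_compact_array_in` / `num_lwes_per_compact_list` pair above.
fn flatten(lists: &[CompactList]) -> (Vec<u64>, Vec<u32>) {
    let mut flattened = Vec::new();
    let mut num_lwes_per_compact_list = Vec::with_capacity(lists.len());
    for list in lists {
        num_lwes_per_compact_list.push(list.lwes.len() as u32);
        flattened.extend_from_slice(&list.lwes);
    }
    (flattened, num_lwes_per_compact_list)
}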