fix(gpu): Fix expand bench on multi-gpus
@@ -421,8 +421,6 @@ mod cuda {
         .sample_size(15)
         .measurement_time(std::time::Duration::from_secs(60));
 
-    let streams = CudaStreams::new_multi_gpu();
-
     File::create(results_file).expect("create results file failed");
     let mut file = OpenOptions::new()
         .append(true)
@@ -439,17 +437,10 @@ mod cuda {
     let cks = ClientKey::new(param_fhe);
     let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
     let sk = compressed_server_key.decompress();
-    let gpu_sks = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
 
     let compact_private_key = CompactPrivateKey::new(param_pke);
     let pk = CompactPublicKey::new(&compact_private_key);
     let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), param_ksk);
-    let d_ksk_material =
-        CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
-    let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
-        &d_ksk_material,
-        &gpu_sks,
-    );
 
     // We have a use case with 320 bits of metadata
     let mut metadata = [0u8; (320 / u8::BITS) as usize];
@@ -509,6 +500,18 @@ mod cuda {
 
     match get_bench_type() {
         BenchmarkType::Latency => {
+            let streams = CudaStreams::new_multi_gpu();
+            let gpu_sks = CudaServerKey::decompress_from_cpu(
+                &compressed_server_key,
+                &streams,
+            );
+            let d_ksk_material =
+                CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
+            let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
+                &d_ksk_material,
+                &gpu_sks,
+            );
+
             bench_id_verify = format!(
                 "{bench_name}::{param_name}_{bits}_bits_packed_{zk_load}_ZK{zk_vers:?}"
             );
@@ -599,9 +602,7 @@ mod cuda {
             });
         }
         BenchmarkType::Throughput => {
-            let gpu_count = get_number_of_gpus() as usize;
-
-            let elements = zk_throughput_num_elements();
+            let elements = 100 * get_number_of_gpus() as u64; // This value, found empirically, ensure saturation of 8XH100 SXM5
            bench_group.throughput(Throughput::Elements(elements));
 
             bench_id_verify = format!(
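For context, the `elements` value set in the Throughput arm above feeds Criterion's per-iteration throughput accounting. A minimal standalone sketch of that mechanism follows (illustrative only, not part of this commit; the group name, the benchmark body, and the element count are made up):

// Illustrative only: declaring group-level throughput in Criterion, as the
// Throughput arm above does with `elements` derived from the GPU count.
// The group name and the numbers here are made up.
use criterion::{criterion_group, criterion_main, Criterion, Throughput};

fn bench_sketch(c: &mut Criterion) {
    let mut group = c.benchmark_group("zk::expand_throughput_sketch");
    let elements: u64 = 100 * 8; // e.g. 100 work items per GPU on an 8-GPU box
    group.throughput(Throughput::Elements(elements));
    group.bench_function("noop", |b| b.iter(|| std::hint::black_box(elements)));
    group.finish();
}

criterion_group!(benches, bench_sketch);
criterion_main!(benches);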
@@ -636,8 +637,6 @@ mod cuda {
                 })
                 .collect::<Vec<_>>();
 
-            assert_eq!(d_ksk_material_vec.len(), gpu_count);
-
             bench_group.bench_function(&bench_id_verify, |b| {
                 b.iter(|| {
                     cts.par_iter().for_each(|ct1| {
@@ -648,23 +647,25 @@ mod cuda {
 
             bench_group.bench_function(&bench_id_expand_without_verify, |b| {
                 let setup_encrypted_values = || {
-                    let local_streams = cuda_local_streams(num_block, elements as usize);
-
                     let gpu_cts = cts.iter().enumerate().map(|(i, ct)| {
+                        let local_stream = &local_streams[i % local_streams.len()];
                         CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
-                            ct, &local_streams[i],
+                            ct, local_stream,
                         )
                     }).collect_vec();
 
-                    (gpu_cts, local_streams)
+                    gpu_cts
                 };
 
                 b.iter_batched(setup_encrypted_values,
-                    |(gpu_cts, local_streams)| {
-                        gpu_cts.par_iter().zip(local_streams.par_iter()).enumerate().for_each
-                            (|(i, (gpu_ct, local_stream))| {
+                    |gpu_cts| {
+                        gpu_cts.par_iter().enumerate().for_each
+                            (|(i, gpu_ct)| {
+                                let local_stream = &local_streams[i % local_streams.len()];
+
+                                let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, local_stream);
                                 let d_ksk =
-                                    CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % gpu_count], &gpu_sks);
+                                    CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % local_streams.len()], &gpu_sk);
 
                                 gpu_ct
                                     .expand_without_verification(&d_ksk, local_stream)
@@ -675,21 +676,24 @@ mod cuda {
 
             bench_group.bench_function(&bench_id_verify_and_expand, |b| {
                 let setup_encrypted_values = || {
-                    let local_streams = cuda_local_streams(num_block, elements as usize);
-
                     let gpu_cts = cts.iter().enumerate().map(|(i, ct)| {
                         CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
-                            ct, &local_streams[i],
+                            ct, &local_streams[i % local_streams.len()],
                         )
                     }).collect_vec();
 
-                    (gpu_cts, local_streams)
+                    gpu_cts
                 };
 
                 b.iter_batched(setup_encrypted_values,
-                    |(gpu_cts, local_streams)| {
-                        gpu_cts.par_iter().zip(local_streams.par_iter()).for_each
-                            (|(gpu_ct, local_stream)| {
+                    |gpu_cts| {
+                        gpu_cts.par_iter().enumerate().for_each
+                            (|(i, gpu_ct)| {
+                                let local_stream = &local_streams[i % local_streams.len()];
+                                let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, local_stream);
+                                let d_ksk =
+                                    CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % local_streams.len()], &gpu_sk);
+
                                 gpu_ct
                                     .verify_and_expand(
                                         &crs, &pk, &metadata, &d_ksk, local_stream,
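The substance of the fix in the two benchmark closures above is the element-to-stream mapping: the number of benchmarked elements can exceed the number of local CUDA streams, so indexing `local_streams` with the raw element index `i` would run past the end of the vector, while `i % local_streams.len()` wraps the elements onto the available streams round-robin. A minimal standalone sketch of that wrap-around follows (illustrative only; plain strings stand in for the CUDA streams and the counts are made up):

// Illustrative only: wrapping element indices onto a fixed pool of streams,
// as the bench does with `i % local_streams.len()`. Strings stand in for
// CUDA streams; the counts are made up.
fn main() {
    let stream_count = 4; // e.g. one local stream per GPU
    let local_streams: Vec<String> =
        (0..stream_count).map(|g| format!("stream-on-gpu-{g}")).collect();

    let elements = 10usize; // more work items than streams
    for i in 0..elements {
        // Without the modulo, `local_streams[i]` would panic once i >= stream_count.
        let local_stream = &local_streams[i % local_streams.len()];
        println!("element {i} -> {local_stream}");
    }
}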
@@ -8,12 +8,10 @@ use crate::core_crypto::prelude::{
 use crate::integer::ciphertext::{CompactCiphertextListExpander, DataKind};
 use crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaExpandable;
 use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo};
-use crate::integer::gpu::ciphertext::{
-    expand_async, CudaRadixCiphertext, CudaVec, KsType, LweDimension,
-};
+use crate::integer::gpu::ciphertext::{CudaRadixCiphertext, CudaVec, KsType, LweDimension};
 use crate::integer::gpu::key_switching_key::CudaKeySwitchingKey;
 use crate::integer::gpu::server_key::CudaBootstrappingKey;
-use crate::integer::gpu::PBSType;
+use crate::integer::gpu::{expand_async, PBSType};
 use crate::shortint::ciphertext::CompactCiphertextList;
 use crate::shortint::parameters::{
     CompactCiphertextListExpansionKind, Degree, LweBskGroupingFactor, NoiseLevel,
@@ -4,27 +4,15 @@ pub mod compressed_ciphertext_list;
 pub mod info;
 pub mod squashed_noise;
 
-use crate::core_crypto::gpu::lwe_bootstrap_key::{
-    prepare_cuda_ms_noise_reduction_key_ffi, CudaModulusSwitchNoiseReductionKey,
-};
 use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
 use crate::core_crypto::gpu::vec::CudaVec;
 use crate::core_crypto::gpu::CudaStreams;
-use crate::core_crypto::prelude::{
-    LweBskGroupingFactor, LweCiphertextList, LweCiphertextOwned, Numeric, UnsignedInteger,
-};
+use crate::core_crypto::prelude::{LweCiphertextList, LweCiphertextOwned};
 use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo};
-use crate::integer::gpu::PBSType;
-use crate::integer::parameters::{
-    DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize,
-};
+use crate::integer::parameters::LweDimension;
 use crate::integer::{IntegerCiphertext, RadixCiphertext, SignedRadixCiphertext};
-use crate::shortint::{CarryModulus, Ciphertext, EncryptionKeyChoice, MessageModulus};
+use crate::shortint::{Ciphertext, EncryptionKeyChoice};
 use crate::GpuIndex;
-use tfhe_cuda_backend::bindings::{
-    cleanup_expand_without_verification_64, cuda_expand_without_verification_64,
-    scratch_cuda_expand_without_verification_64,
-};
 
 pub trait CudaIntegerRadixCiphertext: Sized {
     const IS_SIGNED: bool;
@@ -527,100 +515,3 @@ impl From<EncryptionKeyChoice> for KsType {
         }
     }
 }
-
-#[allow(clippy::too_many_arguments)]
-/// # Safety
-///
-/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
-///   be dropped until stream is synchronised
-///
-///
-/// In this method, the input `lwe_flattened_compact_array_in` represents a flattened compact list.
-/// Instead of receiving a `Vec<CompactCiphertextList>`, it takes a concatenation of all LWEs
-/// that were inside that vector of compact list. Handling the input this way removes the need
-/// to process multiple compact lists separately, simplifying GPU-based operations. The variable
-/// name `lwe_flattened_compact_array_in` makes this intent explicit.
-pub unsafe fn expand_async<T: UnsignedInteger, B: Numeric>(
-    streams: &CudaStreams,
-    lwe_array_out: &mut CudaLweCiphertextList<T>,
-    lwe_flattened_compact_array_in: &CudaVec<T>,
-    bootstrapping_key: &CudaVec<B>,
-    computing_ks_key: &CudaVec<T>,
-    casting_key: &CudaVec<T>,
-    message_modulus: MessageModulus,
-    carry_modulus: CarryModulus,
-    computing_glwe_dimension: GlweDimension,
-    computing_polynomial_size: PolynomialSize,
-    computing_lwe_dimension: LweDimension,
-    computing_ks_level: DecompositionLevelCount,
-    computing_ks_base_log: DecompositionBaseLog,
-    casting_input_lwe_dimension: LweDimension,
-    casting_output_lwe_dimension: LweDimension,
-    casting_ks_level: DecompositionLevelCount,
-    casting_ks_base_log: DecompositionBaseLog,
-    pbs_level: DecompositionLevelCount,
-    pbs_base_log: DecompositionBaseLog,
-    pbs_type: PBSType,
-    casting_key_type: KsType,
-    grouping_factor: LweBskGroupingFactor,
-    num_lwes_per_compact_list: &[u32],
-    is_boolean: &[bool],
-    noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
-) {
-    let ct_modulus = lwe_array_out.ciphertext_modulus().raw_modulus_float();
-    let mut mem_ptr: *mut i8 = std::ptr::null_mut();
-    let num_compact_lists = num_lwes_per_compact_list.len();
-
-    let ms_noise_reduction_key_ffi =
-        prepare_cuda_ms_noise_reduction_key_ffi(noise_reduction_key, ct_modulus);
-    let allocate_ms_noise_array = noise_reduction_key.is_some();
-
-    scratch_cuda_expand_without_verification_64(
-        streams.ptr.as_ptr(),
-        streams.gpu_indexes_ptr(),
-        streams.len() as u32,
-        std::ptr::addr_of_mut!(mem_ptr),
-        computing_glwe_dimension.0 as u32,
-        computing_polynomial_size.0 as u32,
-        computing_glwe_dimension
-            .to_equivalent_lwe_dimension(computing_polynomial_size)
-            .0 as u32,
-        computing_lwe_dimension.0 as u32,
-        computing_ks_level.0 as u32,
-        computing_ks_base_log.0 as u32,
-        casting_input_lwe_dimension.0 as u32,
-        casting_output_lwe_dimension.0 as u32,
-        casting_ks_level.0 as u32,
-        casting_ks_base_log.0 as u32,
-        pbs_level.0 as u32,
-        pbs_base_log.0 as u32,
-        grouping_factor.0 as u32,
-        num_lwes_per_compact_list.as_ptr(),
-        is_boolean.as_ptr(),
-        num_compact_lists as u32,
-        message_modulus.0 as u32,
-        carry_modulus.0 as u32,
-        pbs_type as u32,
-        casting_key_type as u32,
-        true,
-        allocate_ms_noise_array,
-    );
-    cuda_expand_without_verification_64(
-        streams.ptr.as_ptr(),
-        streams.gpu_indexes_ptr(),
-        streams.len() as u32,
-        lwe_array_out.0.d_vec.as_mut_c_ptr(0),
-        lwe_flattened_compact_array_in.as_c_ptr(0),
-        mem_ptr,
-        bootstrapping_key.ptr.as_ptr(),
-        computing_ks_key.ptr.as_ptr(),
-        casting_key.ptr.as_ptr(),
-        &raw const ms_noise_reduction_key_ffi,
-    );
-    cleanup_expand_without_verification_64(
-        streams.ptr.as_ptr(),
-        streams.gpu_indexes_ptr(),
-        streams.len() as u32,
-        std::ptr::addr_of_mut!(mem_ptr),
-    );
-}
@@ -10,6 +10,7 @@ pub mod zk;
 use crate::core_crypto::gpu::lwe_bootstrap_key::{
     prepare_cuda_ms_noise_reduction_key_ffi, CudaModulusSwitchNoiseReductionKey,
 };
+use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
 use crate::core_crypto::gpu::slice::{CudaSlice, CudaSliceMut};
 use crate::core_crypto::gpu::vec::CudaVec;
 use crate::core_crypto::gpu::CudaStreams;
@@ -19,7 +20,7 @@ use crate::core_crypto::prelude::{
 };
 use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
 use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
-use crate::integer::gpu::ciphertext::CudaRadixCiphertext;
+use crate::integer::gpu::ciphertext::{CudaRadixCiphertext, KsType};
 use crate::integer::server_key::radix_parallel::OutputFlag;
 use crate::integer::server_key::ScalarMultiplier;
 use crate::integer::{ClientKey, RadixClientKey};
@@ -6643,3 +6644,135 @@ pub unsafe fn noise_squashing_async<T: UnsignedInteger, B: Numeric>(
         std::ptr::addr_of_mut!(mem_ptr),
     );
 }
+
+#[allow(clippy::too_many_arguments)]
+/// # Safety
+///
+/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
+///   be dropped until stream is synchronised
+///
+///
+/// In this method, the input `lwe_flattened_compact_array_in` represents a flattened compact list.
+/// Instead of receiving a `Vec<CompactCiphertextList>`, it takes a concatenation of all LWEs
+/// that were inside that vector of compact list. Handling the input this way removes the need
+/// to process multiple compact lists separately, simplifying GPU-based operations. The variable
+/// name `lwe_flattened_compact_array_in` makes this intent explicit.
+pub unsafe fn expand_async<T: UnsignedInteger, B: Numeric>(
+    streams: &CudaStreams,
+    lwe_array_out: &mut CudaLweCiphertextList<T>,
+    lwe_flattened_compact_array_in: &CudaVec<T>,
+    bootstrapping_key: &CudaVec<B>,
+    computing_ks_key: &CudaVec<T>,
+    casting_key: &CudaVec<T>,
+    message_modulus: MessageModulus,
+    carry_modulus: CarryModulus,
+    computing_glwe_dimension: GlweDimension,
+    computing_polynomial_size: PolynomialSize,
+    computing_lwe_dimension: LweDimension,
+    computing_ks_level: DecompositionLevelCount,
+    computing_ks_base_log: DecompositionBaseLog,
+    casting_input_lwe_dimension: LweDimension,
+    casting_output_lwe_dimension: LweDimension,
+    casting_ks_level: DecompositionLevelCount,
+    casting_ks_base_log: DecompositionBaseLog,
+    pbs_level: DecompositionLevelCount,
+    pbs_base_log: DecompositionBaseLog,
+    pbs_type: PBSType,
+    casting_key_type: KsType,
+    grouping_factor: LweBskGroupingFactor,
+    num_lwes_per_compact_list: &[u32],
+    is_boolean: &[bool],
+    noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
+) {
+    assert_eq!(
+        streams.gpu_indexes[0],
+        lwe_array_out.0.d_vec.gpu_index(0),
+        "GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
+        streams.gpu_indexes[0].get(),
+        lwe_array_out.0.d_vec.gpu_index(0).get(),
+    );
+    assert_eq!(
+        streams.gpu_indexes[0],
+        lwe_flattened_compact_array_in.gpu_index(0),
+        "GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
+        streams.gpu_indexes[0].get(),
+        lwe_flattened_compact_array_in.gpu_index(0).get(),
+    );
+    assert_eq!(
+        streams.gpu_indexes[0],
+        bootstrapping_key.gpu_indexes[0],
+        "GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
+        streams.gpu_indexes[0].get(),
+        bootstrapping_key.gpu_indexes[0].get(),
+    );
+    assert_eq!(
+        streams.gpu_indexes[0],
+        computing_ks_key.gpu_indexes[0],
+        "GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
+        streams.gpu_indexes[0].get(),
+        computing_ks_key.gpu_indexes[0].get(),
+    );
+    assert_eq!(
+        streams.gpu_indexes[0],
+        casting_key.gpu_indexes[0],
+        "GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
+        streams.gpu_indexes[0].get(),
+        casting_key.gpu_indexes[0].get(),
+    );
+    let ct_modulus = lwe_array_out.ciphertext_modulus().raw_modulus_float();
+    let mut mem_ptr: *mut i8 = std::ptr::null_mut();
+    let num_compact_lists = num_lwes_per_compact_list.len();
+
+    let ms_noise_reduction_key_ffi =
+        prepare_cuda_ms_noise_reduction_key_ffi(noise_reduction_key, ct_modulus);
+    let allocate_ms_noise_array = noise_reduction_key.is_some();
+
+    scratch_cuda_expand_without_verification_64(
+        streams.ptr.as_ptr(),
+        streams.gpu_indexes_ptr(),
+        streams.len() as u32,
+        std::ptr::addr_of_mut!(mem_ptr),
+        computing_glwe_dimension.0 as u32,
+        computing_polynomial_size.0 as u32,
+        computing_glwe_dimension
+            .to_equivalent_lwe_dimension(computing_polynomial_size)
+            .0 as u32,
+        computing_lwe_dimension.0 as u32,
+        computing_ks_level.0 as u32,
+        computing_ks_base_log.0 as u32,
+        casting_input_lwe_dimension.0 as u32,
+        casting_output_lwe_dimension.0 as u32,
+        casting_ks_level.0 as u32,
+        casting_ks_base_log.0 as u32,
+        pbs_level.0 as u32,
+        pbs_base_log.0 as u32,
+        grouping_factor.0 as u32,
+        num_lwes_per_compact_list.as_ptr(),
+        is_boolean.as_ptr(),
+        num_compact_lists as u32,
+        message_modulus.0 as u32,
+        carry_modulus.0 as u32,
+        pbs_type as u32,
+        casting_key_type as u32,
+        true,
+        allocate_ms_noise_array,
+    );
+    cuda_expand_without_verification_64(
+        streams.ptr.as_ptr(),
+        streams.gpu_indexes_ptr(),
+        streams.len() as u32,
+        lwe_array_out.0.d_vec.as_mut_c_ptr(0),
+        lwe_flattened_compact_array_in.as_c_ptr(0),
+        mem_ptr,
+        bootstrapping_key.ptr.as_ptr(),
+        computing_ks_key.ptr.as_ptr(),
+        casting_key.ptr.as_ptr(),
+        &raw const ms_noise_reduction_key_ffi,
+    );
+    cleanup_expand_without_verification_64(
+        streams.ptr.as_ptr(),
+        streams.gpu_indexes_ptr(),
+        streams.len() as u32,
+        std::ptr::addr_of_mut!(mem_ptr),
+    );
+}
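The doc comment on `expand_async` describes the flattened input layout: instead of a `Vec<CompactCiphertextList>`, the function receives one concatenated buffer plus `num_lwes_per_compact_list`, which records how many LWEs each original list contributed. A minimal standalone sketch of walking such a layout follows (illustrative only; the counts and the per-LWE word size are made up, and plain host vectors stand in for device memory):

// Illustrative only: how `num_lwes_per_compact_list` describes a flattened
// concatenation of compact lists. Plain host vectors stand in for device data.
fn main() {
    // Hypothetical: three compact lists contributed 2, 3 and 1 LWEs.
    let num_lwes_per_compact_list: [u32; 3] = [2, 3, 1];
    let lwe_size_in_words = 4usize; // hypothetical per-LWE size

    let total_lwes: u32 = num_lwes_per_compact_list.iter().sum();
    let flattened = vec![0u64; total_lwes as usize * lwe_size_in_words];

    // Walk the concatenated buffer, recovering each original list's span.
    let mut offset = 0usize;
    for (list_idx, &n) in num_lwes_per_compact_list.iter().enumerate() {
        let words = n as usize * lwe_size_in_words;
        println!("list {list_idx}: {n} LWEs at words {offset}..{}", offset + words);
        offset += words;
    }
    assert_eq!(offset, flattened.len());
}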