From 7986e0bf1d33898953f1c5ca7ad92f4ceeadaaf0 Mon Sep 17 00:00:00 2001 From: Andrei Stoian Date: Thu, 12 Jun 2025 09:41:17 +0200 Subject: [PATCH] chore(gpu): skip packing ks test if it needs more ram than available --- .../benches/core_crypto/ks_bench.rs | 53 ++++++++++++++++--- tfhe/src/core_crypto/gpu/mod.rs | 31 +++++++++++ 2 files changed, 77 insertions(+), 7 deletions(-) diff --git a/tfhe-benchmark/benches/core_crypto/ks_bench.rs b/tfhe-benchmark/benches/core_crypto/ks_bench.rs index cf34069e5..316427d2f 100644 --- a/tfhe-benchmark/benches/core_crypto/ks_bench.rs +++ b/tfhe-benchmark/benches/core_crypto/ks_bench.rs @@ -331,10 +331,13 @@ mod cuda { use serde::Serialize; use tfhe::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList; use tfhe::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList; + use tfhe::core_crypto::gpu::vec::GpuIndex; use tfhe::core_crypto::gpu::{ - cuda_keyswitch_lwe_ciphertext, cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64, - get_number_of_gpus, CudaStreams, + check_valid_cuda_malloc, cuda_keyswitch_lwe_ciphertext, + cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64, get_number_of_gpus, + get_packing_keyswitch_list_64_size_on_gpu, CudaStreams, }; + use tfhe::core_crypto::prelude::*; fn cuda_keyswitch + CastFrom + Serialize>( @@ -588,10 +591,24 @@ mod cuda { let cpu_keys: CpuKeys<_> = CpuKeysBuilder::new().packing_keyswitch_key(pksk).build(); let bench_id; - match get_bench_type() { BenchmarkType::Latency => { let streams = CudaStreams::new_multi_gpu(); + + let mem_size = get_packing_keyswitch_list_64_size_on_gpu( + &streams, + lwe_sk.lwe_dimension(), + glwe_sk.glwe_dimension(), + glwe_sk.polynomial_size(), + LweCiphertextCount(glwe_sk.polynomial_size().0), + ); + + let skip_bench = !check_valid_cuda_malloc(mem_size, GpuIndex::new(0)); + + if skip_bench { + continue; + } + let gpu_keys = CudaLocalKeys::from_cpu_keys(&cpu_keys, None, &streams); let mut input_ct_list = LweCiphertextList::new( @@ -647,9 +664,30 @@ mod cuda { let gpu_count = get_number_of_gpus() as usize; bench_id = format!("{bench_name}::throughput::{name}"); + + let mem_size = get_packing_keyswitch_list_64_size_on_gpu( + &CudaStreams::new_single_gpu(GpuIndex::new(0)), + lwe_sk.lwe_dimension(), + glwe_sk.glwe_dimension(), + glwe_sk.polynomial_size(), + LweCiphertextCount(glwe_sk.polynomial_size().0), + ); + + let mut skip_test = false; + for gpu_index in 0..gpu_count { + if !check_valid_cuda_malloc(mem_size, GpuIndex::new(gpu_index as u32)) { + skip_test = true; + } + } + + if skip_test { + continue; + } + let blocks: usize = 1; let elements = throughput_num_threads(blocks, 1); - let elements_per_stream = elements as usize / gpu_count; + let elements_per_stream = + std::cmp::min(elements as usize / gpu_count, glwe_sk.polynomial_size().0); bench_group.throughput(Throughput::Elements(elements)); bench_group.sample_size(50); bench_group.bench_function(&bench_id, |b| { @@ -666,7 +704,7 @@ mod cuda { let mut input_ct_list = LweCiphertextList::new( Scalar::ZERO, lwe_sk.lwe_dimension().to_lwe_size(), - LweCiphertextCount(glwe_sk.polynomial_size().0), + LweCiphertextCount(elements_per_stream), ciphertext_modulus, ); encrypt_lwe_ciphertext_list( @@ -743,8 +781,9 @@ mod cuda { } pub fn cuda_ks_group() { - let mut criterion: Criterion<_> = - (Criterion::default().sample_size(2000)).configure_from_args(); + let mut criterion: Criterion<_> = (Criterion::default().sample_size(15)) + .measurement_time(std::time::Duration::from_secs(60)) + .configure_from_args(); cuda_keyswitch(&mut criterion, &benchmark_parameters()); cuda_packing_keyswitch(&mut criterion, &benchmark_parameters()); } diff --git a/tfhe/src/core_crypto/gpu/mod.rs b/tfhe/src/core_crypto/gpu/mod.rs index db7f38478..d02de7151 100644 --- a/tfhe/src/core_crypto/gpu/mod.rs +++ b/tfhe/src/core_crypto/gpu/mod.rs @@ -1062,6 +1062,37 @@ pub fn is_cuda_available() -> bool { result == 1u32 } +pub fn get_packing_keyswitch_list_64_size_on_gpu( + streams: &CudaStreams, + input_lwe_dimension: LweDimension, + output_glwe_dimension: GlweDimension, + output_polynomial_size: PolynomialSize, + num_lwes: LweCiphertextCount, +) -> u64 { + let mut fp_ks_buffer: *mut i8 = std::ptr::null_mut(); + let size_tracker = unsafe { + scratch_packing_keyswitch_lwe_list_to_glwe_64( + streams.ptr[0], + streams.gpu_indexes[0].get(), + std::ptr::addr_of_mut!(fp_ks_buffer), + input_lwe_dimension.0 as u32, + output_glwe_dimension.0 as u32, + output_polynomial_size.0 as u32, + num_lwes.0 as u32, + false, + ) + }; + unsafe { + cleanup_packing_keyswitch_lwe_list_to_glwe( + streams.ptr[0], + streams.gpu_indexes[0].get(), + std::ptr::addr_of_mut!(fp_ks_buffer), + false, + ); + } + size_tracker +} + #[cfg(test)] mod tests { use super::*;