chore(gpu): no crash with aes benches if oom error

This commit is contained in:
Enzo Di Maria
2025-11-14 10:07:13 +01:00
committed by Agnès Leroy
parent 164fc26025
commit 54c8c5e020
2 changed files with 85 additions and 56 deletions

View File

@@ -3,7 +3,7 @@ pub mod cuda {
use benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
use benchmark::utilities::{write_to_json, OperatorType};
use criterion::{black_box, Criterion};
use tfhe::core_crypto::gpu::CudaStreams;
use tfhe::core_crypto::gpu::{check_valid_cuda_malloc, CudaStreams};
use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use tfhe::integer::gpu::CudaServerKey;
use tfhe::integer::keycache::KEY_CACHE;
@@ -102,6 +102,18 @@ pub mod cuda {
let sks = CudaServerKey::new(&cpu_cks, &streams);
let cks = RadixClientKey::from((cpu_cks, 1));
//
// Memory checks
//
let gpu_index = streams.gpu_indexes[0];
let key_expansion_size = sks.get_key_expansion_size_on_gpu(&streams);
let aes_encrypt_size =
sks.get_aes_encrypt_size_on_gpu(NUM_AES_INPUTS, SBOX_PARALLELISM, &streams);
if check_valid_cuda_malloc(key_expansion_size, gpu_index)
&& check_valid_cuda_malloc(aes_encrypt_size, gpu_index)
{
let ct_key = cks.encrypt_u128_for_aes_ctr(key);
let ct_iv = cks.encrypt_u128_for_aes_ctr(iv);
@@ -133,6 +145,9 @@ pub mod cuda {
aes_op_bit_size,
vec![atomic_param.message_modulus().0.ilog2(); aes_op_bit_size as usize],
);
} else {
println!("{} skipped: Not enough memory in GPU", bench_id);
}
}
bench_group.finish();

View File

@@ -3,7 +3,7 @@ pub mod cuda {
use benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
use benchmark::utilities::{write_to_json, OperatorType};
use criterion::{black_box, Criterion};
use tfhe::core_crypto::gpu::CudaStreams;
use tfhe::core_crypto::gpu::{check_valid_cuda_malloc, CudaStreams};
use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use tfhe::integer::gpu::CudaServerKey;
use tfhe::integer::keycache::KEY_CACHE;
@@ -106,8 +106,19 @@ pub mod cuda {
let sks = CudaServerKey::new(&cpu_cks, &streams);
let cks = RadixClientKey::from((cpu_cks, 1));
let ct_key = cks.encrypt_2u128_for_aes_ctr_256(key_hi, key_lo);
//
// Memory checks
//
let gpu_index = streams.gpu_indexes[0];
let key_expansion_size = sks.get_key_expansion_256_size_on_gpu(&streams);
let aes_encrypt_size =
sks.get_aes_encrypt_size_on_gpu(NUM_AES_INPUTS, SBOX_PARALLELISM, &streams);
if check_valid_cuda_malloc(key_expansion_size, gpu_index)
&& check_valid_cuda_malloc(aes_encrypt_size, gpu_index)
{
let ct_key = cks.encrypt_2u128_for_aes_ctr_256(key_hi, key_lo);
let ct_iv = cks.encrypt_u128_for_aes_ctr(iv);
let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
@@ -138,6 +149,9 @@ pub mod cuda {
aes_block_op_bit_size,
vec![atomic_param.message_modulus().0.ilog2(); aes_block_op_bit_size as usize],
);
} else {
println!("{} skipped: Not enough memory in GPU", bench_id);
}
}
bench_group.finish();