chore(gpu): refactor the gpu oom checker

Author: Andrei Stoian
Date: 2025-06-20 14:39:13 +02:00
Committed by: Andrei Stoian
Parent: 981083360e
Commit: d0743e9d3d

10 changed files with 158 additions and 438 deletions
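In short, this change replaces the repeated call-site pattern assert!(check_valid_cuda_malloc(size, gpu_index)) with a single panicking helper, check_valid_cuda_malloc_assert_oom, which reports the requested size and the device's total memory when the allocation cannot fit. A minimal usage sketch of the new helper follows; the caller function is hypothetical, while the import paths and signature match the ones introduced in the diffs below.

use tfhe::core_crypto::gpu::check_valid_cuda_malloc_assert_oom;
use tfhe::GpuIndex;

// Hypothetical caller: check that a planned allocation of `size` bytes fits on GPU 0
// before launching work; panics with a descriptive OOM message otherwise.
fn ensure_fits_on_gpu(size: u64) {
    check_valid_cuda_malloc_assert_oom(size, GpuIndex::new(0));
}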

View File

@@ -55,6 +55,7 @@ void *cuda_malloc_with_size_tracking_async(uint64_t size, cudaStream_t stream,
void *cuda_malloc_async(uint64_t size, cudaStream_t stream, uint32_t gpu_index);
bool cuda_check_valid_malloc(uint64_t size, uint32_t gpu_index);
uint64_t cuda_device_total_memory(uint32_t gpu_index);
void cuda_memcpy_with_size_tracking_async_to_gpu(void *dest, const void *src,
                                                 uint64_t size,

View File

@@ -122,6 +122,13 @@ bool cuda_check_valid_malloc(uint64_t size, uint32_t gpu_index) {
}
}
uint64_t cuda_device_total_memory(uint32_t gpu_index) {
  cuda_set_device(gpu_index);
  size_t total_mem = 0, free_mem = 0;
  check_cuda_error(cudaMemGetInfo(&free_mem, &total_mem));
  return total_mem;
}
/// Returns
/// false if Cooperative Groups is not supported.
/// true otherwise

View File

@@ -23,6 +23,7 @@ extern "C" {
    pub fn cuda_malloc_async(size: u64, stream: *mut c_void, gpu_index: u32) -> *mut c_void;
    pub fn cuda_check_valid_malloc(size: u64, gpu_index: u32) -> bool;
    pub fn cuda_device_total_memory(gpu_index: u32) -> u64;
    pub fn cuda_memcpy_with_size_tracking_async_to_gpu(
        dest: *mut c_void,

View File

@@ -1124,6 +1124,24 @@ pub fn synchronize_devices(gpu_count: u32) {
pub fn check_valid_cuda_malloc(size: u64, gpu_index: GpuIndex) -> bool {
    unsafe { cuda_check_valid_malloc(size, gpu_index.get()) }
}

/// Check if a memory allocation fits in GPU memory. If it doesn't fit, panic with
/// a helpful message.
pub fn check_valid_cuda_malloc_assert_oom(size: u64, gpu_index: GpuIndex) {
    if !check_valid_cuda_malloc(size, gpu_index) {
        let total_memory;
        unsafe {
            total_memory = cuda_device_total_memory(gpu_index.get());
        }
        panic!(
            "Not enough memory on GPU {}. Allocating {} bytes exceeds total memory: {} bytes",
            gpu_index.get(),
            size,
            total_memory
        );
    }
}

// Determine if a cuda device is available, at runtime
pub fn is_cuda_available() -> bool {
    let result = unsafe { cuda_is_available() };

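For reference, the call-site migration applied throughout the doc examples and tests in the remaining files follows one before/after pattern (the buffer-size variable is taken from the existing tests; surrounding setup is elided):

// Before: assert on the boolean returned by check_valid_cuda_malloc.
assert!(check_valid_cuda_malloc(add_tmp_buffer_size, GpuIndex::new(0)));

// After: the helper panics on its own, reporting the requested size and the
// GPU's total memory, so each call site shrinks to a single line.
check_valid_cuda_malloc_assert_oom(add_tmp_buffer_size, GpuIndex::new(0));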
View File

@@ -1034,50 +1034,26 @@ mod gpu {
let ttrue = FheBool::encrypt(true, &keys);
let ffalse = FheBool::encrypt(false, &keys);
let bitand_size_on_gpu = ttrue.get_bitand_size_on_gpu(&ffalse);
assert!(check_valid_cuda_malloc(
bitand_size_on_gpu,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(bitand_size_on_gpu, GpuIndex::new(0));
let scalar_bitand_size_on_gpu = ttrue.get_bitand_size_on_gpu(false);
assert!(check_valid_cuda_malloc(
scalar_bitand_size_on_gpu,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(scalar_bitand_size_on_gpu, GpuIndex::new(0));
let bitxor_size_on_gpu = ttrue.get_bitxor_size_on_gpu(&ffalse);
assert!(check_valid_cuda_malloc(
bitxor_size_on_gpu,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(bitxor_size_on_gpu, GpuIndex::new(0));
let scalar_bitxor_size_on_gpu = ttrue.get_bitxor_size_on_gpu(false);
assert!(check_valid_cuda_malloc(
scalar_bitxor_size_on_gpu,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(scalar_bitxor_size_on_gpu, GpuIndex::new(0));
let bitor_size_on_gpu = ttrue.get_bitor_size_on_gpu(&ffalse);
assert!(check_valid_cuda_malloc(bitor_size_on_gpu, GpuIndex::new(0)));
check_valid_cuda_malloc_assert_oom(bitor_size_on_gpu, GpuIndex::new(0));
let scalar_bitor_size_on_gpu = ttrue.get_bitor_size_on_gpu(false);
assert!(check_valid_cuda_malloc(
scalar_bitor_size_on_gpu,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(scalar_bitor_size_on_gpu, GpuIndex::new(0));
let bitnot_size_on_gpu = ttrue.get_bitnot_size_on_gpu();
assert!(check_valid_cuda_malloc(
bitnot_size_on_gpu,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(bitnot_size_on_gpu, GpuIndex::new(0));
let eq_size_on_gpu = ttrue.get_eq_size_on_gpu(&ffalse);
assert!(check_valid_cuda_malloc(eq_size_on_gpu, GpuIndex::new(0)));
check_valid_cuda_malloc_assert_oom(eq_size_on_gpu, GpuIndex::new(0));
let scalar_eq_size_on_gpu = ttrue.get_eq_size_on_gpu(false);
assert!(check_valid_cuda_malloc(
scalar_eq_size_on_gpu,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(scalar_eq_size_on_gpu, GpuIndex::new(0));
let ne_size_on_gpu = ttrue.get_ne_size_on_gpu(&ffalse);
assert!(check_valid_cuda_malloc(ne_size_on_gpu, GpuIndex::new(0)));
check_valid_cuda_malloc_assert_oom(ne_size_on_gpu, GpuIndex::new(0));
let scalar_ne_size_on_gpu = ttrue.get_ne_size_on_gpu(false);
assert!(check_valid_cuda_malloc(
scalar_ne_size_on_gpu,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(scalar_ne_size_on_gpu, GpuIndex::new(0));
}
}

View File

@@ -1094,18 +1094,12 @@ mod tests {
.get_decompression_size_on_gpu(0)
.unwrap()
.unwrap();
assert!(check_valid_cuda_malloc(
decompress_ct1_size_on_gpu,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(decompress_ct1_size_on_gpu, GpuIndex::new(0));
let decompress_ct2_size_on_gpu = compressed_list
.get_decompression_size_on_gpu(1)
.unwrap()
.unwrap();
assert!(check_valid_cuda_malloc(
decompress_ct2_size_on_gpu,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(decompress_ct2_size_on_gpu, GpuIndex::new(0));
}
}
}

View File

@@ -22,15 +22,19 @@ use crate::GpuIndex;
/// let mut a = FheInt32::try_encrypt(clear_a, &client_key).unwrap();
/// let mut b = FheInt32::try_encrypt(clear_b, &client_key).unwrap();
/// let ciphertexts_size = a.get_size_on_gpu() + b.get_size_on_gpu();
/// assert!(check_valid_cuda_malloc(ciphertexts_size, GpuIndex::new(0)));
/// check_valid_cuda_malloc_assert_oom(ciphertexts_size, GpuIndex::new(0));
///
/// a.move_to_current_device();
/// b.move_to_current_device();
///
/// let tmp_buffer_size = a.get_add_size_on_gpu(&b);
/// assert!(check_valid_cuda_malloc(tmp_buffer_size, GpuIndex::new(0)));
/// check_valid_cuda_malloc_assert_oom(tmp_buffer_size, GpuIndex::new(0));
/// a += &b;
/// ```
pub fn check_valid_cuda_malloc(size: u64, gpu_index: GpuIndex) -> bool {
    crate::core_crypto::gpu::check_valid_cuda_malloc(size, gpu_index)
}

pub fn check_valid_cuda_malloc_assert_oom(size: u64, gpu_index: GpuIndex) {
    crate::core_crypto::gpu::check_valid_cuda_malloc_assert_oom(size, gpu_index);
}

View File

@@ -60,7 +60,7 @@ impl<Id: FheUintId> FheUint<Id> {
/// Returns the amount of memory required to execute generate_oblivious_pseudo_random
///
/// ```rust
/// use tfhe::core_crypto::gpu::check_valid_cuda_malloc;
/// use tfhe::core_crypto::gpu::check_valid_cuda_malloc_assert_oom;
/// use tfhe::prelude::FheDecrypt;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint8, GpuIndex, Seed};
///
@@ -73,7 +73,7 @@ impl<Id: FheUintId> FheUint<Id> {
///
/// let size = FheUint8::get_generate_oblivious_pseudo_random_size_on_gpu();
///
/// assert!(check_valid_cuda_malloc(size, GpuIndex::new(0)));
/// check_valid_cuda_malloc_assert_oom(size, GpuIndex::new(0));
/// ```
pub fn get_generate_oblivious_pseudo_random_size_on_gpu() -> u64 {
global_state::with_internal_keys(|key| {
@@ -146,7 +146,7 @@ impl<Id: FheUintId> FheUint<Id> {
/// Returns the amount of memory required to execute generate_oblivious_pseudo_random_bounded
///
/// ```rust
/// use tfhe::core_crypto::gpu::check_valid_cuda_malloc;
/// use tfhe::core_crypto::gpu::check_valid_cuda_malloc_assert_oom;
/// use tfhe::prelude::FheDecrypt;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint8, GpuIndex, Seed};
///
@@ -159,7 +159,7 @@ impl<Id: FheUintId> FheUint<Id> {
///
/// let size = FheUint8::get_generate_oblivious_pseudo_random_bounded_size_on_gpu();
///
/// assert!(check_valid_cuda_malloc(size, GpuIndex::new(0)));
/// check_valid_cuda_malloc_assert_oom(size, GpuIndex::new(0));
/// ```
pub fn get_generate_oblivious_pseudo_random_bounded_size_on_gpu() -> u64 {
global_state::with_internal_keys(|key| {
@@ -235,7 +235,7 @@ impl<Id: FheIntId> FheInt<Id> {
/// Returns the amount of memory required to execute generate_oblivious_pseudo_random
///
/// ```rust
/// use tfhe::core_crypto::gpu::check_valid_cuda_malloc;
/// use tfhe::core_crypto::gpu::check_valid_cuda_malloc_assert_oom;
/// use tfhe::prelude::FheDecrypt;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheInt8, GpuIndex, Seed};
///
@@ -248,7 +248,7 @@ impl<Id: FheIntId> FheInt<Id> {
///
/// let size = FheInt8::get_generate_oblivious_pseudo_random_size_on_gpu();
///
/// assert!(check_valid_cuda_malloc(size, GpuIndex::new(0)));
/// check_valid_cuda_malloc_assert_oom(size, GpuIndex::new(0));
/// ```
pub fn get_generate_oblivious_pseudo_random_size_on_gpu() -> u64 {
global_state::with_internal_keys(|key| {
@@ -322,7 +322,7 @@ impl<Id: FheIntId> FheInt<Id> {
/// Returns the amount of memory required to execute generate_oblivious_pseudo_random_bounded
///
/// ```rust
/// use tfhe::core_crypto::gpu::check_valid_cuda_malloc;
/// use tfhe::core_crypto::gpu::check_valid_cuda_malloc_assert_oom;
/// use tfhe::prelude::FheDecrypt;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheInt8, GpuIndex, Seed};
///
@@ -335,7 +335,7 @@ impl<Id: FheIntId> FheInt<Id> {
///
/// let size = FheInt8::get_generate_oblivious_pseudo_random_bounded_size_on_gpu();
///
/// assert!(check_valid_cuda_malloc(size, GpuIndex::new(0)));
/// check_valid_cuda_malloc_assert_oom(size, GpuIndex::new(0));
/// ```
pub fn get_generate_oblivious_pseudo_random_bounded_size_on_gpu() -> u64 {
global_state::with_internal_keys(|key| {

View File

@@ -3,11 +3,11 @@ use crate::high_level_api::integers::signed::tests::{
};
use crate::high_level_api::integers::unsigned::tests::gpu::setup_gpu;
use crate::prelude::{
check_valid_cuda_malloc, AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu,
BitXorSizeOnGpu, DivRemSizeOnGpu, DivSizeOnGpu, FheEncrypt, FheEqSizeOnGpu, FheMaxSizeOnGpu,
FheMinSizeOnGpu, FheOrdSizeOnGpu, FheTryEncrypt, IfThenElseSizeOnGpu, MulSizeOnGpu,
NegSizeOnGpu, RemSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu, ShlSizeOnGpu,
ShrSizeOnGpu, SubSizeOnGpu,
check_valid_cuda_malloc_assert_oom, AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu,
BitOrSizeOnGpu, BitXorSizeOnGpu, DivRemSizeOnGpu, DivSizeOnGpu, FheEncrypt, FheEqSizeOnGpu,
FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, FheTryEncrypt, IfThenElseSizeOnGpu,
MulSizeOnGpu, NegSizeOnGpu, RemSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu,
ShlSizeOnGpu, ShrSizeOnGpu, SubSizeOnGpu,
};
use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS;
use crate::{FheBool, FheInt32, FheUint32, GpuIndex};
@@ -94,30 +94,15 @@ fn test_gpu_get_add_sub_size_on_gpu() {
let sub_tmp_buffer_size = a.get_sub_size_on_gpu(b);
let scalar_add_tmp_buffer_size = clear_a.get_add_size_on_gpu(b);
let scalar_sub_tmp_buffer_size = clear_a.get_sub_size_on_gpu(b);
assert!(check_valid_cuda_malloc(
add_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
sub_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_add_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_sub_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(add_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(sub_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_add_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_sub_tmp_buffer_size, GpuIndex::new(0));
assert_eq!(add_tmp_buffer_size, sub_tmp_buffer_size);
assert_eq!(add_tmp_buffer_size, scalar_add_tmp_buffer_size);
assert_eq!(add_tmp_buffer_size, scalar_sub_tmp_buffer_size);
let neg_tmp_buffer_size = a.get_neg_size_on_gpu();
assert!(check_valid_cuda_malloc(
neg_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(neg_tmp_buffer_size, GpuIndex::new(0));
}
#[test]
@@ -135,39 +120,18 @@ fn test_gpu_get_bitops_size_on_gpu() {
let bitand_tmp_buffer_size = a.get_bitand_size_on_gpu(b);
let scalar_bitand_tmp_buffer_size = clear_a.get_bitand_size_on_gpu(b);
assert!(check_valid_cuda_malloc(
bitand_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_bitand_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(bitand_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_bitand_tmp_buffer_size, GpuIndex::new(0));
let bitor_tmp_buffer_size = a.get_bitor_size_on_gpu(b);
let scalar_bitor_tmp_buffer_size = clear_a.get_bitor_size_on_gpu(b);
assert!(check_valid_cuda_malloc(
bitor_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_bitor_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(bitor_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_bitor_tmp_buffer_size, GpuIndex::new(0));
let bitxor_tmp_buffer_size = a.get_bitxor_size_on_gpu(b);
let scalar_bitxor_tmp_buffer_size = clear_a.get_bitxor_size_on_gpu(b);
assert!(check_valid_cuda_malloc(
bitxor_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_bitxor_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(bitxor_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_bitxor_tmp_buffer_size, GpuIndex::new(0));
let bitnot_tmp_buffer_size = a.get_bitnot_size_on_gpu();
assert!(check_valid_cuda_malloc(
bitnot_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(bitnot_tmp_buffer_size, GpuIndex::new(0));
}
#[test]
fn test_gpu_get_comparisons_size_on_gpu() {
@@ -184,84 +148,36 @@ fn test_gpu_get_comparisons_size_on_gpu() {
let gt_tmp_buffer_size = a.get_gt_size_on_gpu(b);
let scalar_gt_tmp_buffer_size = a.get_gt_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
gt_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_gt_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(gt_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_gt_tmp_buffer_size, GpuIndex::new(0));
let ge_tmp_buffer_size = a.get_ge_size_on_gpu(b);
let scalar_ge_tmp_buffer_size = a.get_ge_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
ge_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_ge_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(ge_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_ge_tmp_buffer_size, GpuIndex::new(0));
let lt_tmp_buffer_size = a.get_lt_size_on_gpu(b);
let scalar_lt_tmp_buffer_size = a.get_lt_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
lt_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_lt_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(lt_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_lt_tmp_buffer_size, GpuIndex::new(0));
let le_tmp_buffer_size = a.get_le_size_on_gpu(b);
let scalar_le_tmp_buffer_size = a.get_le_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
le_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_le_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(le_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_le_tmp_buffer_size, GpuIndex::new(0));
let max_tmp_buffer_size = a.get_max_size_on_gpu(b);
let scalar_max_tmp_buffer_size = a.get_max_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
max_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_max_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(max_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_max_tmp_buffer_size, GpuIndex::new(0));
let min_tmp_buffer_size = a.get_min_size_on_gpu(b);
let scalar_min_tmp_buffer_size = a.get_min_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
min_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_min_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(min_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_min_tmp_buffer_size, GpuIndex::new(0));
let eq_tmp_buffer_size = a.get_eq_size_on_gpu(b);
let scalar_eq_tmp_buffer_size = a.get_eq_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
eq_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_eq_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(eq_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_eq_tmp_buffer_size, GpuIndex::new(0));
let ne_tmp_buffer_size = a.get_ne_size_on_gpu(b);
let scalar_ne_tmp_buffer_size = a.get_ne_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
ne_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_ne_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(ne_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_ne_tmp_buffer_size, GpuIndex::new(0));
}
#[test]
@@ -279,44 +195,20 @@ fn test_gpu_get_shift_rotate_size_on_gpu() {
let left_shift_tmp_buffer_size = a.get_left_shift_size_on_gpu(b);
let scalar_left_shift_tmp_buffer_size = a.get_left_shift_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
left_shift_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_left_shift_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(left_shift_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_left_shift_tmp_buffer_size, GpuIndex::new(0));
let right_shift_tmp_buffer_size = a.get_right_shift_size_on_gpu(b);
let scalar_right_shift_tmp_buffer_size = a.get_right_shift_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
right_shift_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_right_shift_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(right_shift_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_right_shift_tmp_buffer_size, GpuIndex::new(0));
let rotate_left_tmp_buffer_size = a.get_rotate_left_size_on_gpu(b);
let scalar_rotate_left_tmp_buffer_size = a.get_rotate_left_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
rotate_left_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_rotate_left_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(rotate_left_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_rotate_left_tmp_buffer_size, GpuIndex::new(0));
let rotate_right_tmp_buffer_size = a.get_rotate_right_size_on_gpu(b);
let scalar_rotate_right_tmp_buffer_size = a.get_rotate_right_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
rotate_right_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_rotate_right_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(rotate_right_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_rotate_right_tmp_buffer_size, GpuIndex::new(0));
}
#[test]
@@ -335,20 +227,11 @@ fn test_gpu_get_if_then_else_size_on_gpu() {
let b = &b;
let if_then_else_tmp_buffer_size = c.get_if_then_else_size_on_gpu(a, b);
assert!(check_valid_cuda_malloc(
if_then_else_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(if_then_else_tmp_buffer_size, GpuIndex::new(0));
let select_tmp_buffer_size = c.get_select_size_on_gpu(a, b);
assert!(check_valid_cuda_malloc(
select_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(select_tmp_buffer_size, GpuIndex::new(0));
let cmux_tmp_buffer_size = c.get_cmux_size_on_gpu(a, b);
assert!(check_valid_cuda_malloc(
cmux_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(cmux_tmp_buffer_size, GpuIndex::new(0));
}
#[test]
fn test_gpu_get_mul_size_on_gpu() {
@@ -365,14 +248,8 @@ fn test_gpu_get_mul_size_on_gpu() {
let mul_tmp_buffer_size = a.get_mul_size_on_gpu(b);
let scalar_mul_tmp_buffer_size = b.get_mul_size_on_gpu(clear_a);
assert!(check_valid_cuda_malloc(
mul_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_mul_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(mul_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_mul_tmp_buffer_size, GpuIndex::new(0));
}
#[test]
fn test_gpu_get_div_size_on_gpu() {
@@ -390,31 +267,13 @@ fn test_gpu_get_div_size_on_gpu() {
let div_tmp_buffer_size = a.get_div_size_on_gpu(b);
let rem_tmp_buffer_size = a.get_rem_size_on_gpu(b);
let div_rem_tmp_buffer_size = a.get_div_rem_size_on_gpu(b);
assert!(check_valid_cuda_malloc(
div_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
rem_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
div_rem_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(div_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(rem_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(div_rem_tmp_buffer_size, GpuIndex::new(0));
let scalar_div_tmp_buffer_size = a.get_div_size_on_gpu(clear_b);
let scalar_rem_tmp_buffer_size = a.get_rem_size_on_gpu(clear_b);
let scalar_div_rem_tmp_buffer_size = a.get_div_rem_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
scalar_div_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_rem_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_div_rem_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(scalar_div_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_rem_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_div_rem_tmp_buffer_size, GpuIndex::new(0));
}

View File

@@ -1,9 +1,10 @@
use crate::high_level_api::traits::AddSizeOnGpu;
use crate::prelude::{
check_valid_cuda_malloc, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
DivRemSizeOnGpu, DivSizeOnGpu, FheEncrypt, FheEqSizeOnGpu, FheMaxSizeOnGpu, FheMinSizeOnGpu,
FheOrdSizeOnGpu, FheTryEncrypt, IfThenElseSizeOnGpu, MulSizeOnGpu, NegSizeOnGpu, RemSizeOnGpu,
RotateLeftSizeOnGpu, RotateRightSizeOnGpu, ShlSizeOnGpu, ShrSizeOnGpu, SubSizeOnGpu,
check_valid_cuda_malloc_assert_oom, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu,
BitXorSizeOnGpu, DivRemSizeOnGpu, DivSizeOnGpu, FheEncrypt, FheEqSizeOnGpu, FheMaxSizeOnGpu,
FheMinSizeOnGpu, FheOrdSizeOnGpu, FheTryEncrypt, IfThenElseSizeOnGpu, MulSizeOnGpu,
NegSizeOnGpu, RemSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu, ShlSizeOnGpu,
ShrSizeOnGpu, SubSizeOnGpu,
};
use crate::shortint::parameters::{
TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS,
@@ -186,30 +187,15 @@ fn test_gpu_get_add_and_sub_size_on_gpu() {
let sub_tmp_buffer_size = a.get_sub_size_on_gpu(b);
let scalar_add_tmp_buffer_size = clear_a.get_add_size_on_gpu(b);
let scalar_sub_tmp_buffer_size = clear_a.get_sub_size_on_gpu(b);
assert!(check_valid_cuda_malloc(
add_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
sub_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_add_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_sub_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(add_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(sub_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_add_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_sub_tmp_buffer_size, GpuIndex::new(0));
assert_eq!(add_tmp_buffer_size, sub_tmp_buffer_size);
assert_eq!(add_tmp_buffer_size, scalar_add_tmp_buffer_size);
assert_eq!(add_tmp_buffer_size, scalar_sub_tmp_buffer_size);
let neg_tmp_buffer_size = a.get_neg_size_on_gpu();
assert!(check_valid_cuda_malloc(
neg_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(neg_tmp_buffer_size, GpuIndex::new(0));
}
#[test]
fn test_gpu_get_bitops_size_on_gpu() {
@@ -227,39 +213,18 @@ fn test_gpu_get_bitops_size_on_gpu() {
let bitand_tmp_buffer_size = a.get_bitand_size_on_gpu(b);
let scalar_bitand_tmp_buffer_size = clear_a.get_bitand_size_on_gpu(b);
assert!(check_valid_cuda_malloc(
bitand_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_bitand_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(bitand_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_bitand_tmp_buffer_size, GpuIndex::new(0));
let bitor_tmp_buffer_size = a.get_bitor_size_on_gpu(b);
let scalar_bitor_tmp_buffer_size = clear_a.get_bitor_size_on_gpu(b);
assert!(check_valid_cuda_malloc(
bitor_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_bitor_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(bitor_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_bitor_tmp_buffer_size, GpuIndex::new(0));
let bitxor_tmp_buffer_size = a.get_bitxor_size_on_gpu(b);
let scalar_bitxor_tmp_buffer_size = clear_a.get_bitxor_size_on_gpu(b);
assert!(check_valid_cuda_malloc(
bitxor_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_bitxor_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(bitxor_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_bitxor_tmp_buffer_size, GpuIndex::new(0));
let bitnot_tmp_buffer_size = a.get_bitnot_size_on_gpu();
assert!(check_valid_cuda_malloc(
bitnot_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(bitnot_tmp_buffer_size, GpuIndex::new(0));
}
#[test]
fn test_gpu_get_comparisons_size_on_gpu() {
@@ -276,84 +241,36 @@ fn test_gpu_get_comparisons_size_on_gpu() {
let gt_tmp_buffer_size = a.get_gt_size_on_gpu(b);
let scalar_gt_tmp_buffer_size = a.get_gt_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
gt_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_gt_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(gt_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_gt_tmp_buffer_size, GpuIndex::new(0));
let ge_tmp_buffer_size = a.get_ge_size_on_gpu(b);
let scalar_ge_tmp_buffer_size = a.get_ge_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
ge_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_ge_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(ge_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_ge_tmp_buffer_size, GpuIndex::new(0));
let lt_tmp_buffer_size = a.get_lt_size_on_gpu(b);
let scalar_lt_tmp_buffer_size = a.get_lt_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
lt_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_lt_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(lt_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_lt_tmp_buffer_size, GpuIndex::new(0));
let le_tmp_buffer_size = a.get_le_size_on_gpu(b);
let scalar_le_tmp_buffer_size = a.get_le_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
le_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_le_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(le_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_le_tmp_buffer_size, GpuIndex::new(0));
let max_tmp_buffer_size = a.get_max_size_on_gpu(b);
let scalar_max_tmp_buffer_size = a.get_max_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
max_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_max_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(max_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_max_tmp_buffer_size, GpuIndex::new(0));
let min_tmp_buffer_size = a.get_min_size_on_gpu(b);
let scalar_min_tmp_buffer_size = a.get_min_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
min_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_min_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(min_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_min_tmp_buffer_size, GpuIndex::new(0));
let eq_tmp_buffer_size = a.get_eq_size_on_gpu(b);
let scalar_eq_tmp_buffer_size = a.get_eq_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
eq_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_eq_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(eq_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_eq_tmp_buffer_size, GpuIndex::new(0));
let ne_tmp_buffer_size = a.get_ne_size_on_gpu(b);
let scalar_ne_tmp_buffer_size = a.get_ne_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
ne_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_ne_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(ne_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_ne_tmp_buffer_size, GpuIndex::new(0));
}
#[test]
@@ -371,44 +288,20 @@ fn test_gpu_get_shift_rotate_size_on_gpu() {
let left_shift_tmp_buffer_size = a.get_left_shift_size_on_gpu(b);
let scalar_left_shift_tmp_buffer_size = a.get_left_shift_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
left_shift_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_left_shift_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(left_shift_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_left_shift_tmp_buffer_size, GpuIndex::new(0));
let right_shift_tmp_buffer_size = a.get_right_shift_size_on_gpu(b);
let scalar_right_shift_tmp_buffer_size = a.get_right_shift_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
right_shift_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_right_shift_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(right_shift_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_right_shift_tmp_buffer_size, GpuIndex::new(0));
let rotate_left_tmp_buffer_size = a.get_rotate_left_size_on_gpu(b);
let scalar_rotate_left_tmp_buffer_size = a.get_rotate_left_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
rotate_left_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_rotate_left_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(rotate_left_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_rotate_left_tmp_buffer_size, GpuIndex::new(0));
let rotate_right_tmp_buffer_size = a.get_rotate_right_size_on_gpu(b);
let scalar_rotate_right_tmp_buffer_size = a.get_rotate_right_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
rotate_right_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_rotate_right_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(rotate_right_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_rotate_right_tmp_buffer_size, GpuIndex::new(0));
}
#[test]
@@ -427,20 +320,11 @@ fn test_gpu_get_if_then_else_size_on_gpu() {
let b = &b;
let if_then_else_tmp_buffer_size = c.get_if_then_else_size_on_gpu(a, b);
assert!(check_valid_cuda_malloc(
if_then_else_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(if_then_else_tmp_buffer_size, GpuIndex::new(0));
let select_tmp_buffer_size = c.get_select_size_on_gpu(a, b);
assert!(check_valid_cuda_malloc(
select_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(select_tmp_buffer_size, GpuIndex::new(0));
let cmux_tmp_buffer_size = c.get_cmux_size_on_gpu(a, b);
assert!(check_valid_cuda_malloc(
cmux_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(cmux_tmp_buffer_size, GpuIndex::new(0));
}
#[test]
fn test_gpu_get_mul_size_on_gpu() {
@@ -458,14 +342,8 @@ fn test_gpu_get_mul_size_on_gpu() {
let mul_tmp_buffer_size = a.get_mul_size_on_gpu(b);
let scalar_mul_tmp_buffer_size = b.get_mul_size_on_gpu(clear_a);
assert!(check_valid_cuda_malloc(
mul_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_mul_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(mul_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_mul_tmp_buffer_size, GpuIndex::new(0));
}
#[test]
fn test_gpu_get_div_size_on_gpu() {
@@ -484,31 +362,13 @@ fn test_gpu_get_div_size_on_gpu() {
let div_tmp_buffer_size = a.get_div_size_on_gpu(b);
let rem_tmp_buffer_size = a.get_rem_size_on_gpu(b);
let div_rem_tmp_buffer_size = a.get_div_rem_size_on_gpu(b);
assert!(check_valid_cuda_malloc(
div_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
rem_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
div_rem_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(div_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(rem_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(div_rem_tmp_buffer_size, GpuIndex::new(0));
let scalar_div_tmp_buffer_size = a.get_div_size_on_gpu(clear_b);
let scalar_rem_tmp_buffer_size = a.get_rem_size_on_gpu(clear_b);
let scalar_div_rem_tmp_buffer_size = a.get_div_rem_size_on_gpu(clear_b);
assert!(check_valid_cuda_malloc(
scalar_div_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_rem_tmp_buffer_size,
GpuIndex::new(0)
));
assert!(check_valid_cuda_malloc(
scalar_div_rem_tmp_buffer_size,
GpuIndex::new(0)
));
check_valid_cuda_malloc_assert_oom(scalar_div_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_rem_tmp_buffer_size, GpuIndex::new(0));
check_valid_cuda_malloc_assert_oom(scalar_div_rem_tmp_buffer_size, GpuIndex::new(0));
}