chore(gpu): remove device synchronize in drop for cudavec

This commit is contained in:
Agnes Leroy
2025-10-20 11:35:14 +02:00
committed by Agnès Leroy
parent 42644349ef
commit b4b6275ca5
3 changed files with 13 additions and 11 deletions

View File

@@ -3,15 +3,19 @@
void extend_radix_with_trivial_zero_blocks_msb_64(
CudaRadixCiphertextFFI *output, CudaRadixCiphertextFFI const *input,
CudaStreamsFFI streams) {
host_extend_radix_with_trivial_zero_blocks_msb<uint64_t>(
output, input, CudaStreams(streams));
auto cuda_streams = CudaStreams(streams);
host_extend_radix_with_trivial_zero_blocks_msb<uint64_t>(output, input,
cuda_streams);
cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
}
void trim_radix_blocks_lsb_64(CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI const *input,
CudaStreamsFFI streams) {
host_trim_radix_blocks_lsb<uint64_t>(output, input, CudaStreams(streams));
auto cuda_streams = CudaStreams(streams);
host_trim_radix_blocks_lsb<uint64_t>(output, input, cuda_streams);
cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
}
uint64_t scratch_cuda_extend_radix_with_sign_msb_64(

View File

@@ -660,17 +660,17 @@ mod cuda_utils {
impl<T: Numeric> CudaIndexes<T> {
pub fn new(indexes: &[T], stream: &CudaStreams, stream_index: u32) -> Self {
let length = indexes.len();
let mut d_input = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
let mut d_output = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
let mut d_lut = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
let mut d_input = CudaVec::<T>::new(length, stream, stream_index);
let mut d_output = CudaVec::<T>::new(length, stream, stream_index);
let mut d_lut = CudaVec::<T>::new(length, stream, stream_index);
let zeros = vec![T::ZERO; length];
unsafe {
d_input.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
d_output.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
d_lut.copy_from_cpu_async(zeros.as_ref(), stream, stream_index);
stream.synchronize();
}
stream.synchronize();
Self {
d_input,

View File

@@ -1,6 +1,6 @@
use super::get_number_of_gpus;
use crate::core_crypto::gpu::slice::{CudaSlice, CudaSliceMut};
use crate::core_crypto::gpu::{synchronize_device, CudaStreams};
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::Numeric;
use std::collections::Bound::{Excluded, Included, Unbounded};
use std::ffi::c_void;
@@ -485,12 +485,10 @@ unsafe impl<T> Send for CudaVec<T> where T: Send + Numeric {}
unsafe impl<T> Sync for CudaVec<T> where T: Sync + Numeric {}
impl<T: Numeric> Drop for CudaVec<T> {
/// Free memory for pointer `ptr` synchronously
/// Free memory on GPU for pointers in `ptr`
fn drop(&mut self) {
for (i, &ptr) in self.ptr.iter().enumerate() {
// Synchronizes the device to be sure no stream is still using this pointer
let gpu_index = self.gpu_indexes[i];
synchronize_device(gpu_index.0);
unsafe { cuda_drop(ptr, gpu_index.0) };
}
}