diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/cast.cu b/backends/tfhe-cuda-backend/cuda/src/integer/cast.cu index d742b91ca..f9042b23e 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/cast.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/cast.cu @@ -3,15 +3,19 @@ void extend_radix_with_trivial_zero_blocks_msb_64( CudaRadixCiphertextFFI *output, CudaRadixCiphertextFFI const *input, CudaStreamsFFI streams) { - host_extend_radix_with_trivial_zero_blocks_msb( - output, input, CudaStreams(streams)); + auto cuda_streams = CudaStreams(streams); + host_extend_radix_with_trivial_zero_blocks_msb(output, input, + cuda_streams); + cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0)); } void trim_radix_blocks_lsb_64(CudaRadixCiphertextFFI *output, CudaRadixCiphertextFFI const *input, CudaStreamsFFI streams) { - host_trim_radix_blocks_lsb(output, input, CudaStreams(streams)); + auto cuda_streams = CudaStreams(streams); + host_trim_radix_blocks_lsb(output, input, cuda_streams); + cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0)); } uint64_t scratch_cuda_extend_radix_with_sign_msb_64( diff --git a/tfhe-benchmark/src/utilities.rs b/tfhe-benchmark/src/utilities.rs index 10c1939b3..6bf2e8711 100644 --- a/tfhe-benchmark/src/utilities.rs +++ b/tfhe-benchmark/src/utilities.rs @@ -660,17 +660,17 @@ mod cuda_utils { impl CudaIndexes { pub fn new(indexes: &[T], stream: &CudaStreams, stream_index: u32) -> Self { let length = indexes.len(); - let mut d_input = unsafe { CudaVec::::new_async(length, stream, stream_index) }; - let mut d_output = unsafe { CudaVec::::new_async(length, stream, stream_index) }; - let mut d_lut = unsafe { CudaVec::::new_async(length, stream, stream_index) }; + let mut d_input = CudaVec::::new(length, stream, stream_index); + let mut d_output = CudaVec::::new(length, stream, stream_index); + let mut d_lut = CudaVec::::new(length, stream, stream_index); let zeros = vec![T::ZERO; length]; unsafe { d_input.copy_from_cpu_async(indexes.as_ref(), stream, stream_index); d_output.copy_from_cpu_async(indexes.as_ref(), stream, stream_index); d_lut.copy_from_cpu_async(zeros.as_ref(), stream, stream_index); + stream.synchronize(); } - stream.synchronize(); Self { d_input, diff --git a/tfhe/src/core_crypto/gpu/vec.rs b/tfhe/src/core_crypto/gpu/vec.rs index 8283fe3f8..264df364e 100644 --- a/tfhe/src/core_crypto/gpu/vec.rs +++ b/tfhe/src/core_crypto/gpu/vec.rs @@ -1,6 +1,6 @@ use super::get_number_of_gpus; use crate::core_crypto::gpu::slice::{CudaSlice, CudaSliceMut}; -use crate::core_crypto::gpu::{synchronize_device, CudaStreams}; +use crate::core_crypto::gpu::CudaStreams; use crate::core_crypto::prelude::Numeric; use std::collections::Bound::{Excluded, Included, Unbounded}; use std::ffi::c_void; @@ -485,12 +485,10 @@ unsafe impl Send for CudaVec where T: Send + Numeric {} unsafe impl Sync for CudaVec where T: Sync + Numeric {} impl Drop for CudaVec { - /// Free memory for pointer `ptr` synchronously + /// Free memory on GPU for pointers in `ptr` fn drop(&mut self) { for (i, &ptr) in self.ptr.iter().enumerate() { - // Synchronizes the device to be sure no stream is still using this pointer let gpu_index = self.gpu_indexes[i]; - synchronize_device(gpu_index.0); unsafe { cuda_drop(ptr, gpu_index.0) }; } }