mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
chore(gpu): remove device synchronize in drop for cudavec
This commit is contained in:
@@ -3,15 +3,19 @@
|
||||
void extend_radix_with_trivial_zero_blocks_msb_64(
|
||||
CudaRadixCiphertextFFI *output, CudaRadixCiphertextFFI const *input,
|
||||
CudaStreamsFFI streams) {
|
||||
host_extend_radix_with_trivial_zero_blocks_msb<uint64_t>(
|
||||
output, input, CudaStreams(streams));
|
||||
auto cuda_streams = CudaStreams(streams);
|
||||
host_extend_radix_with_trivial_zero_blocks_msb<uint64_t>(output, input,
|
||||
cuda_streams);
|
||||
cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
|
||||
}
|
||||
|
||||
void trim_radix_blocks_lsb_64(CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI const *input,
|
||||
CudaStreamsFFI streams) {
|
||||
|
||||
host_trim_radix_blocks_lsb<uint64_t>(output, input, CudaStreams(streams));
|
||||
auto cuda_streams = CudaStreams(streams);
|
||||
host_trim_radix_blocks_lsb<uint64_t>(output, input, cuda_streams);
|
||||
cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_extend_radix_with_sign_msb_64(
|
||||
|
||||
@@ -660,17 +660,17 @@ mod cuda_utils {
|
||||
impl<T: Numeric> CudaIndexes<T> {
|
||||
pub fn new(indexes: &[T], stream: &CudaStreams, stream_index: u32) -> Self {
|
||||
let length = indexes.len();
|
||||
let mut d_input = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
|
||||
let mut d_output = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
|
||||
let mut d_lut = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
|
||||
let mut d_input = CudaVec::<T>::new(length, stream, stream_index);
|
||||
let mut d_output = CudaVec::<T>::new(length, stream, stream_index);
|
||||
let mut d_lut = CudaVec::<T>::new(length, stream, stream_index);
|
||||
let zeros = vec![T::ZERO; length];
|
||||
|
||||
unsafe {
|
||||
d_input.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
|
||||
d_output.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
|
||||
d_lut.copy_from_cpu_async(zeros.as_ref(), stream, stream_index);
|
||||
stream.synchronize();
|
||||
}
|
||||
stream.synchronize();
|
||||
|
||||
Self {
|
||||
d_input,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::get_number_of_gpus;
|
||||
use crate::core_crypto::gpu::slice::{CudaSlice, CudaSliceMut};
|
||||
use crate::core_crypto::gpu::{synchronize_device, CudaStreams};
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::Numeric;
|
||||
use std::collections::Bound::{Excluded, Included, Unbounded};
|
||||
use std::ffi::c_void;
|
||||
@@ -485,12 +485,10 @@ unsafe impl<T> Send for CudaVec<T> where T: Send + Numeric {}
|
||||
unsafe impl<T> Sync for CudaVec<T> where T: Sync + Numeric {}
|
||||
|
||||
impl<T: Numeric> Drop for CudaVec<T> {
|
||||
/// Free memory for pointer `ptr` synchronously
|
||||
/// Free memory on GPU for pointers in `ptr`
|
||||
fn drop(&mut self) {
|
||||
for (i, &ptr) in self.ptr.iter().enumerate() {
|
||||
// Synchronizes the device to be sure no stream is still using this pointer
|
||||
let gpu_index = self.gpu_indexes[i];
|
||||
synchronize_device(gpu_index.0);
|
||||
unsafe { cuda_drop(ptr, gpu_index.0) };
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user