mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
chore(gpu): remove device synchronize in drop for cudavec
This commit is contained in:
@@ -3,15 +3,19 @@
|
|||||||
void extend_radix_with_trivial_zero_blocks_msb_64(
|
void extend_radix_with_trivial_zero_blocks_msb_64(
|
||||||
CudaRadixCiphertextFFI *output, CudaRadixCiphertextFFI const *input,
|
CudaRadixCiphertextFFI *output, CudaRadixCiphertextFFI const *input,
|
||||||
CudaStreamsFFI streams) {
|
CudaStreamsFFI streams) {
|
||||||
host_extend_radix_with_trivial_zero_blocks_msb<uint64_t>(
|
auto cuda_streams = CudaStreams(streams);
|
||||||
output, input, CudaStreams(streams));
|
host_extend_radix_with_trivial_zero_blocks_msb<uint64_t>(output, input,
|
||||||
|
cuda_streams);
|
||||||
|
cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
void trim_radix_blocks_lsb_64(CudaRadixCiphertextFFI *output,
|
void trim_radix_blocks_lsb_64(CudaRadixCiphertextFFI *output,
|
||||||
CudaRadixCiphertextFFI const *input,
|
CudaRadixCiphertextFFI const *input,
|
||||||
CudaStreamsFFI streams) {
|
CudaStreamsFFI streams) {
|
||||||
|
|
||||||
host_trim_radix_blocks_lsb<uint64_t>(output, input, CudaStreams(streams));
|
auto cuda_streams = CudaStreams(streams);
|
||||||
|
host_trim_radix_blocks_lsb<uint64_t>(output, input, cuda_streams);
|
||||||
|
cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t scratch_cuda_extend_radix_with_sign_msb_64(
|
uint64_t scratch_cuda_extend_radix_with_sign_msb_64(
|
||||||
|
|||||||
@@ -660,17 +660,17 @@ mod cuda_utils {
|
|||||||
impl<T: Numeric> CudaIndexes<T> {
|
impl<T: Numeric> CudaIndexes<T> {
|
||||||
pub fn new(indexes: &[T], stream: &CudaStreams, stream_index: u32) -> Self {
|
pub fn new(indexes: &[T], stream: &CudaStreams, stream_index: u32) -> Self {
|
||||||
let length = indexes.len();
|
let length = indexes.len();
|
||||||
let mut d_input = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
|
let mut d_input = CudaVec::<T>::new(length, stream, stream_index);
|
||||||
let mut d_output = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
|
let mut d_output = CudaVec::<T>::new(length, stream, stream_index);
|
||||||
let mut d_lut = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
|
let mut d_lut = CudaVec::<T>::new(length, stream, stream_index);
|
||||||
let zeros = vec![T::ZERO; length];
|
let zeros = vec![T::ZERO; length];
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
d_input.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
|
d_input.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
|
||||||
d_output.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
|
d_output.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
|
||||||
d_lut.copy_from_cpu_async(zeros.as_ref(), stream, stream_index);
|
d_lut.copy_from_cpu_async(zeros.as_ref(), stream, stream_index);
|
||||||
|
stream.synchronize();
|
||||||
}
|
}
|
||||||
stream.synchronize();
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
d_input,
|
d_input,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use super::get_number_of_gpus;
|
use super::get_number_of_gpus;
|
||||||
use crate::core_crypto::gpu::slice::{CudaSlice, CudaSliceMut};
|
use crate::core_crypto::gpu::slice::{CudaSlice, CudaSliceMut};
|
||||||
use crate::core_crypto::gpu::{synchronize_device, CudaStreams};
|
use crate::core_crypto::gpu::CudaStreams;
|
||||||
use crate::core_crypto::prelude::Numeric;
|
use crate::core_crypto::prelude::Numeric;
|
||||||
use std::collections::Bound::{Excluded, Included, Unbounded};
|
use std::collections::Bound::{Excluded, Included, Unbounded};
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
@@ -485,12 +485,10 @@ unsafe impl<T> Send for CudaVec<T> where T: Send + Numeric {}
|
|||||||
unsafe impl<T> Sync for CudaVec<T> where T: Sync + Numeric {}
|
unsafe impl<T> Sync for CudaVec<T> where T: Sync + Numeric {}
|
||||||
|
|
||||||
impl<T: Numeric> Drop for CudaVec<T> {
|
impl<T: Numeric> Drop for CudaVec<T> {
|
||||||
/// Free memory for pointer `ptr` synchronously
|
/// Free memory on GPU for pointers in `ptr`
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
for (i, &ptr) in self.ptr.iter().enumerate() {
|
for (i, &ptr) in self.ptr.iter().enumerate() {
|
||||||
// Synchronizes the device to be sure no stream is still using this pointer
|
|
||||||
let gpu_index = self.gpu_indexes[i];
|
let gpu_index = self.gpu_indexes[i];
|
||||||
synchronize_device(gpu_index.0);
|
|
||||||
unsafe { cuda_drop(ptr, gpu_index.0) };
|
unsafe { cuda_drop(ptr, gpu_index.0) };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user