diff --git a/backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h b/backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h index b25eb7e7e..788fe416b 100644 --- a/backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h +++ b/backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h @@ -8,7 +8,7 @@ extern std::mutex m; extern bool p2p_enabled; extern "C" { -int32_t cuda_setup_multi_gpu(int device_0_id); +int32_t cuda_setup_multi_gpu(); } // Define a variant type that can be either a vector or a single pointer diff --git a/backends/tfhe-cuda-backend/cuda/src/utils/helper_multi_gpu.cu b/backends/tfhe-cuda-backend/cuda/src/utils/helper_multi_gpu.cu index 97bd04f10..c7f0fdeab 100644 --- a/backends/tfhe-cuda-backend/cuda/src/utils/helper_multi_gpu.cu +++ b/backends/tfhe-cuda-backend/cuda/src/utils/helper_multi_gpu.cu @@ -6,8 +6,7 @@ std::mutex m; bool p2p_enabled = false; -// Enable bidirectional p2p access between all available GPUs and device_0_id -int32_t cuda_setup_multi_gpu(int device_0_id) { +int32_t cuda_setup_multi_gpu() { int num_gpus = cuda_get_number_of_gpus(); if (num_gpus == 0) PANIC("GPU error: the number of GPUs should be > 0.") @@ -19,13 +18,11 @@ int32_t cuda_setup_multi_gpu(int device_0_id) { omp_set_nested(1); int has_peer_access_to_device_0; for (int i = 1; i < num_gpus; i++) { - check_cuda_error(cudaDeviceCanAccessPeer(&has_peer_access_to_device_0, - i, device_0_id)); + check_cuda_error( + cudaDeviceCanAccessPeer(&has_peer_access_to_device_0, i, 0)); if (has_peer_access_to_device_0) { cuda_set_device(i); - check_cuda_error(cudaDeviceEnablePeerAccess(device_0_id, 0)); - cuda_set_device(device_0_id); - check_cuda_error(cudaDeviceEnablePeerAccess(i, 0)); + check_cuda_error(cudaDeviceEnablePeerAccess(0, 0)); } num_used_gpus += 1; } diff --git a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_keyswitch.cpp b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_keyswitch.cpp index 3fea7138d..7fb1f7533 100644 --- a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_keyswitch.cpp +++ b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_keyswitch.cpp @@ -65,7 +65,7 @@ public: number_of_inputs = (int)GetParam().number_of_inputs; // Enable Multi-GPU logic - gpu_count = cuda_setup_multi_gpu(0); + gpu_count = cuda_setup_multi_gpu(); active_gpu_count = std::min((uint)number_of_inputs, gpu_count); for (uint gpu_i = 0; gpu_i < active_gpu_count; gpu_i++) { streams.push_back(cuda_create_stream(gpu_i)); diff --git a/backends/tfhe-cuda-backend/src/cuda_bind.rs b/backends/tfhe-cuda-backend/src/cuda_bind.rs index 114f2dc05..31e04d12b 100644 --- a/backends/tfhe-cuda-backend/src/cuda_bind.rs +++ b/backends/tfhe-cuda-backend/src/cuda_bind.rs @@ -101,6 +101,6 @@ extern "C" { pub fn cuda_drop_async(ptr: *mut c_void, stream: *mut c_void, gpu_index: u32); - pub fn cuda_setup_multi_gpu(gpu_index: u32) -> i32; + pub fn cuda_setup_multi_gpu() -> i32; } // extern "C" diff --git a/tfhe/src/core_crypto/gpu/mod.rs b/tfhe/src/core_crypto/gpu/mod.rs index d454bc9cd..100e72b97 100644 --- a/tfhe/src/core_crypto/gpu/mod.rs +++ b/tfhe/src/core_crypto/gpu/mod.rs @@ -17,6 +17,7 @@ pub use entities::*; use std::ffi::c_void; use tfhe_cuda_backend::bindings::*; use tfhe_cuda_backend::cuda_bind::*; + pub struct CudaStreams { pub ptr: Vec<*mut c_void>, pub gpu_indexes: Vec, @@ -29,7 +30,7 @@ unsafe impl Sync for CudaStreams {} impl CudaStreams { /// Create a new `CudaStreams` structure with as many GPUs as there are on the machine pub fn new_multi_gpu() -> Self { - let gpu_count = setup_multi_gpu(GpuIndex::new(0)); + let gpu_count = setup_multi_gpu(); let mut gpu_indexes = Vec::with_capacity(gpu_count as usize); let mut ptr_array = Vec::with_capacity(gpu_count as usize); @@ -42,22 +43,6 @@ impl CudaStreams { gpu_indexes, } } - /// Create a new `CudaStreams` structure with the GPUs with id provided in a list - pub fn new_multi_gpu_with_indexes(indexes: &[GpuIndex]) -> Self { - let _gpu_count = setup_multi_gpu(indexes[0]); - - let mut gpu_indexes = Vec::with_capacity(indexes.len()); - let mut ptr_array = Vec::with_capacity(indexes.len()); - - for &i in indexes { - ptr_array.push(unsafe { cuda_create_stream(i.get()) }); - gpu_indexes.push(i); - } - Self { - ptr: ptr_array, - gpu_indexes, - } - } /// Create a new `CudaStreams` structure with one GPU, whose index corresponds to the one given /// as input pub fn new_single_gpu(gpu_index: GpuIndex) -> Self { @@ -103,14 +88,6 @@ impl CudaStreams { } } -impl Clone for CudaStreams { - fn clone(&self) -> Self { - // The `new_multi_gpu_with_indexes()` function is used here to adapt to any specific type of - // streams being cloned (single, multi, or custom) - Self::new_multi_gpu_with_indexes(self.gpu_indexes.as_slice()) - } -} - impl Drop for CudaStreams { fn drop(&mut self) { for (i, &s) in self.ptr.iter().enumerate() { @@ -1059,8 +1036,8 @@ pub fn get_number_of_gpus() -> u32 { } /// Setup multi-GPU and return the number of GPUs used -pub fn setup_multi_gpu(device_0_id: GpuIndex) -> u32 { - unsafe { cuda_setup_multi_gpu(device_0_id.get()) as u32 } +pub fn setup_multi_gpu() -> u32 { + unsafe { cuda_setup_multi_gpu() as u32 } } /// Synchronize device diff --git a/tfhe/src/high_level_api/array/gpu/booleans.rs b/tfhe/src/high_level_api/array/gpu/booleans.rs index a858d1e3c..7d0b9cf17 100644 --- a/tfhe/src/high_level_api/array/gpu/booleans.rs +++ b/tfhe/src/high_level_api/array/gpu/booleans.rs @@ -156,13 +156,14 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend { rhs: TensorSlice<'_, Self::Slice<'a>>, ) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - let streams = &cuda_key.streams; - lhs.par_iter() - .zip(rhs.par_iter()) - .map(|(lhs, rhs)| { - CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams)) - }) - .collect::>() + with_thread_local_cuda_streams(|streams| { + lhs.par_iter() + .zip(rhs.par_iter()) + .map(|(lhs, rhs)| { + CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams)) + }) + .collect::>() + }) })) } @@ -171,13 +172,14 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend { rhs: TensorSlice<'_, Self::Slice<'a>>, ) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - let streams = &cuda_key.streams; - lhs.par_iter() - .zip(rhs.par_iter()) - .map(|(lhs, rhs)| { - CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams)) - }) - .collect::>() + with_thread_local_cuda_streams(|streams| { + lhs.par_iter() + .zip(rhs.par_iter()) + .map(|(lhs, rhs)| { + CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams)) + }) + .collect::>() + }) })) } @@ -186,22 +188,24 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend { rhs: TensorSlice<'_, Self::Slice<'a>>, ) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - let streams = &cuda_key.streams; - lhs.par_iter() - .zip(rhs.par_iter()) - .map(|(lhs, rhs)| { - CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams)) - }) - .collect::>() + with_thread_local_cuda_streams(|streams| { + lhs.par_iter() + .zip(rhs.par_iter()) + .map(|(lhs, rhs)| { + CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams)) + }) + .collect::>() + }) })) } fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - let streams = &cuda_key.streams; - lhs.par_iter() - .map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams))) - .collect::>() + with_thread_local_cuda_streams(|streams| { + lhs.par_iter() + .map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams))) + .collect::>() + }) })) } } @@ -212,13 +216,16 @@ impl ClearBitwiseArrayBackend for GpuFheBoolArrayBackend { rhs: TensorSlice<'_, &'_ [bool]>, ) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - let streams = &cuda_key.streams; - lhs.par_iter() - .zip(rhs.par_iter().copied()) - .map(|(lhs, rhs)| { - CudaBooleanBlock(cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams)) - }) - .collect::>() + with_thread_local_cuda_streams(|streams| { + lhs.par_iter() + .zip(rhs.par_iter().copied()) + .map(|(lhs, rhs)| { + CudaBooleanBlock( + cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams), + ) + }) + .collect::>() + }) })) } @@ -227,13 +234,16 @@ impl ClearBitwiseArrayBackend for GpuFheBoolArrayBackend { rhs: TensorSlice<'_, &'_ [bool]>, ) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - let streams = &cuda_key.streams; - lhs.par_iter() - .zip(rhs.par_iter().copied()) - .map(|(lhs, rhs)| { - CudaBooleanBlock(cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams)) - }) - .collect::>() + with_thread_local_cuda_streams(|streams| { + lhs.par_iter() + .zip(rhs.par_iter().copied()) + .map(|(lhs, rhs)| { + CudaBooleanBlock( + cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams), + ) + }) + .collect::>() + }) })) } @@ -242,13 +252,16 @@ impl ClearBitwiseArrayBackend for GpuFheBoolArrayBackend { rhs: TensorSlice<'_, &'_ [bool]>, ) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - let streams = &cuda_key.streams; - lhs.par_iter() - .zip(rhs.par_iter().copied()) - .map(|(lhs, rhs)| { - CudaBooleanBlock(cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams)) - }) - .collect::>() + with_thread_local_cuda_streams(|streams| { + lhs.par_iter() + .zip(rhs.par_iter().copied()) + .map(|(lhs, rhs)| { + CudaBooleanBlock( + cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams), + ) + }) + .collect::>() + }) })) } } diff --git a/tfhe/src/high_level_api/array/gpu/integers.rs b/tfhe/src/high_level_api/array/gpu/integers.rs index 6a2b6ed84..1cf12a260 100644 --- a/tfhe/src/high_level_api/array/gpu/integers.rs +++ b/tfhe/src/high_level_api/array/gpu/integers.rs @@ -108,11 +108,12 @@ where F: Send + Sync + Fn(&crate::integer::gpu::CudaServerKey, &T, &T, &CudaStreams) -> T, { GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| { - let streams = &cuda_key.streams; - lhs.par_iter() - .zip(rhs.par_iter()) - .map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams)) - .collect::>() + with_thread_local_cuda_streams(|streams| { + lhs.par_iter() + .zip(rhs.par_iter()) + .map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams)) + .collect::>() + }) })) } @@ -169,11 +170,12 @@ where F: Send + Sync + Fn(&crate::integer::gpu::CudaServerKey, &T, Clear, &CudaStreams) -> T, { GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| { - let streams = &cuda_key.streams; - lhs.par_iter() - .zip(rhs.par_iter()) - .map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams)) - .collect::>() + with_thread_local_cuda_streams(|streams| { + lhs.par_iter() + .zip(rhs.par_iter()) + .map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams)) + .collect::>() + }) })) } @@ -334,10 +336,11 @@ where fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned { GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| { - let streams = &cuda_key.streams; - lhs.par_iter() - .map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams)) - .collect::>() + with_thread_local_cuda_streams(|streams| { + lhs.par_iter() + .map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams)) + .collect::>() + }) })) } } diff --git a/tfhe/src/high_level_api/booleans/base.rs b/tfhe/src/high_level_api/booleans/base.rs index 2cba6f610..db05aac24 100644 --- a/tfhe/src/high_level_api/booleans/base.rs +++ b/tfhe/src/high_level_api/booleans/base.rs @@ -3,6 +3,8 @@ use crate::backward_compatibility::booleans::FheBoolVersions; use crate::conformance::ParameterSetConformant; use crate::core_crypto::prelude::{SignedNumeric, UnsignedNumeric}; use crate::high_level_api::global_state; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::{FheInt, FheIntId, FheUint, FheUintId}; use crate::high_level_api::keys::InternalServerKey; use crate::high_level_api::traits::{FheEq, IfThenElse, ScalarIfThenElse, Tagged}; @@ -383,8 +385,7 @@ impl ScalarIfThenElse<&Self, &Self> for FheBool { (InnerBoolean::Cpu(new_ct), key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.if_then_else( &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)), &*ct_then.ciphertext.on_gpu(streams), @@ -393,7 +394,7 @@ impl ScalarIfThenElse<&Self, &Self> for FheBool { ); let boolean_inner = CudaBooleanBlock(inner); (InnerBoolean::Cuda(boolean_inner), cuda_key.tag.clone()) - } + }), }); Self::new(ciphertext, tag) } @@ -421,8 +422,7 @@ where FheUint::new(inner, cpu_sks.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.if_then_else( &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)), &*ct_then.ciphertext.on_gpu(streams), @@ -431,7 +431,7 @@ where ); FheUint::new(inner, cuda_key.tag.clone()) - } + }), }) } } @@ -455,8 +455,7 @@ impl IfThenElse> for FheBool { FheInt::new(new_ct, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.if_then_else( &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)), &*ct_then.ciphertext.on_gpu(streams), @@ -465,7 +464,7 @@ impl IfThenElse> for FheBool { ); FheInt::new(inner, cuda_key.tag.clone()) - } + }), }) } } @@ -483,8 +482,7 @@ impl IfThenElse for FheBool { (InnerBoolean::Cpu(new_ct), key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.if_then_else( &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)), &*ct_then.ciphertext.on_gpu(streams), @@ -493,7 +491,7 @@ impl IfThenElse for FheBool { ); let boolean_inner = CudaBooleanBlock(inner); (InnerBoolean::Cuda(boolean_inner), cuda_key.tag.clone()) - } + }), }); Self::new(ciphertext, tag) } @@ -543,8 +541,7 @@ where Self::new(ciphertext, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.eq( &*self.ciphertext.on_gpu(streams), &other.borrow().ciphertext.on_gpu(streams), @@ -552,7 +549,7 @@ where ); let ciphertext = InnerBoolean::Cuda(inner); Self::new(ciphertext, cuda_key.tag.clone()) - } + }), }) } @@ -586,8 +583,7 @@ where Self::new(ciphertext, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.ne( &*self.ciphertext.on_gpu(streams), &other.borrow().ciphertext.on_gpu(streams), @@ -595,7 +591,7 @@ where ); let ciphertext = InnerBoolean::Cuda(inner); Self::new(ciphertext, cuda_key.tag.clone()) - } + }), }) } } @@ -632,15 +628,14 @@ impl FheEq for FheBool { ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.scalar_eq( &*self.ciphertext.on_gpu(streams), u8::from(other), streams, ); (InnerBoolean::Cuda(inner), cuda_key.tag.clone()) - } + }), }); Self::new(ciphertext, tag) } @@ -676,15 +671,14 @@ impl FheEq for FheBool { ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.scalar_ne( &*self.ciphertext.on_gpu(streams), u8::from(other), streams, ); (InnerBoolean::Cuda(inner), cuda_key.tag.clone()) - } + }), }); Self::new(ciphertext, tag) } @@ -751,8 +745,7 @@ where (InnerBoolean::Cpu(inner_ct), key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_ct = cuda_key.key.key.bitand( &*self.ciphertext.on_gpu(streams), &rhs.borrow().ciphertext.on_gpu(streams), @@ -765,7 +758,7 @@ where )), cuda_key.tag.clone(), ) - } + }), }); FheBool::new(ciphertext, tag) } @@ -837,8 +830,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_ct = cuda_key.key.key.bitor( &*self.ciphertext.on_gpu(streams), &rhs.borrow().ciphertext.on_gpu(streams), @@ -850,7 +842,7 @@ where )), cuda_key.tag.clone(), ) - } + }), }); FheBool::new(ciphertext, tag) } @@ -922,8 +914,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_ct = cuda_key.key.key.bitxor( &*self.ciphertext.on_gpu(streams), &rhs.borrow().ciphertext.on_gpu(streams), @@ -935,7 +926,7 @@ where )), cuda_key.tag.clone(), ) - } + }), }); FheBool::new(ciphertext, tag) } @@ -999,8 +990,7 @@ impl BitAnd for &FheBool { ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_ct = cuda_key.key.key.scalar_bitand( &*self.ciphertext.on_gpu(streams), u8::from(rhs), @@ -1012,7 +1002,7 @@ impl BitAnd for &FheBool { )), cuda_key.tag.clone(), ) - } + }), }); FheBool::new(ciphertext, tag) } @@ -1076,8 +1066,7 @@ impl BitOr for &FheBool { ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_ct = cuda_key.key.key.scalar_bitor( &*self.ciphertext.on_gpu(streams), u8::from(rhs), @@ -1089,7 +1078,7 @@ impl BitOr for &FheBool { )), cuda_key.tag.clone(), ) - } + }), }); FheBool::new(ciphertext, tag) } @@ -1153,8 +1142,7 @@ impl BitXor for &FheBool { ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_ct = cuda_key.key.key.scalar_bitxor( &*self.ciphertext.on_gpu(streams), u8::from(rhs), @@ -1166,7 +1154,7 @@ impl BitXor for &FheBool { )), cuda_key.tag.clone(), ) - } + }), }); FheBool::new(ciphertext, tag) } @@ -1358,14 +1346,13 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.bitand_assign( self.ciphertext.as_gpu_mut(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); - } + }), }); } } @@ -1402,14 +1389,13 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.bitor_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - } + }), }); } } @@ -1446,14 +1432,13 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.bitxor_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - } + }), }); } } @@ -1484,14 +1469,13 @@ impl BitAndAssign for FheBool { .scalar_bitand_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs)); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitand_assign( self.ciphertext.as_gpu_mut(streams), u8::from(rhs), streams, ); - } + }), }); } } @@ -1522,14 +1506,13 @@ impl BitOrAssign for FheBool { .scalar_bitor_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs)); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitor_assign( self.ciphertext.as_gpu_mut(streams), u8::from(rhs), streams, ); - } + }), }); } } @@ -1560,14 +1543,13 @@ impl BitXorAssign for FheBool { .scalar_bitxor_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs)); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitxor_assign( self.ciphertext.as_gpu_mut(streams), u8::from(rhs), streams, ); - } + }), }); } } @@ -1624,8 +1606,7 @@ impl std::ops::Not for &FheBool { (InnerBoolean::Cpu(inner), key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key .key @@ -1637,7 +1618,7 @@ impl std::ops::Not for &FheBool { )), cuda_key.tag.clone(), ) - } + }), }); FheBool::new(ciphertext, tag) } diff --git a/tfhe/src/high_level_api/booleans/encrypt.rs b/tfhe/src/high_level_api/booleans/encrypt.rs index 56ef8b29b..c339a2b8d 100644 --- a/tfhe/src/high_level_api/booleans/encrypt.rs +++ b/tfhe/src/high_level_api/booleans/encrypt.rs @@ -1,6 +1,8 @@ use super::base::FheBool; use crate::high_level_api::booleans::inner::InnerBoolean; use crate::high_level_api::global_state; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::keys::InternalServerKey; #[cfg(feature = "gpu")] use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; @@ -88,8 +90,7 @@ impl FheTryTrivialEncrypt for FheBool { (ct, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner: CudaUnsignedRadixCiphertext = cuda_key .key @@ -99,7 +100,7 @@ impl FheTryTrivialEncrypt for FheBool { inner.into_inner(), )); (ct, cuda_key.tag.clone()) - } + }), }); Ok(Self::new(ciphertext, tag)) } diff --git a/tfhe/src/high_level_api/booleans/oprf.rs b/tfhe/src/high_level_api/booleans/oprf.rs index ebca57790..270ef554a 100644 --- a/tfhe/src/high_level_api/booleans/oprf.rs +++ b/tfhe/src/high_level_api/booleans/oprf.rs @@ -1,5 +1,7 @@ use super::{FheBool, InnerBoolean}; use crate::high_level_api::global_state; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::keys::InternalServerKey; #[cfg(feature = "gpu")] use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; @@ -39,8 +41,7 @@ impl FheBool { ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let d_ct: CudaUnsignedRadixCiphertext = cuda_key .key .key @@ -51,7 +52,7 @@ impl FheBool { )), cuda_key.tag.clone(), ) - } + }), }); Self::new(ciphertext, tag) } diff --git a/tfhe/src/high_level_api/compressed_ciphertext_list.rs b/tfhe/src/high_level_api/compressed_ciphertext_list.rs index cfc37216c..45849923c 100644 --- a/tfhe/src/high_level_api/compressed_ciphertext_list.rs +++ b/tfhe/src/high_level_api/compressed_ciphertext_list.rs @@ -172,16 +172,17 @@ impl CompressedCiphertextListBuilder { for (element, _) in &self.inner { match element { ToBeCompressed::Cpu(cpu_blocks) => { - let streams = &cuda_key.streams; - cuda_radixes - .push(CudaRadixCiphertext::from_cpu_blocks(cpu_blocks, streams)); + with_thread_local_cuda_streams(|streams| { + cuda_radixes.push(CudaRadixCiphertext::from_cpu_blocks( + cpu_blocks, streams, + )); + }) } #[cfg(feature = "gpu")] ToBeCompressed::Cuda(cuda_radix) => { - { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_radixes.push(cuda_radix.duplicate(streams)); - }; + }); } } } @@ -194,11 +195,10 @@ impl CompressedCiphertextListBuilder { crate::Error::new("Compression key not set in server key".to_owned()) }) .map(|compression_key| { - let packed_list = { - let streams = &cuda_key.streams; + let packed_list = with_thread_local_cuda_streams(|streams| { compression_key .compress_ciphertexts_into_list(cuda_radixes.as_slice(), streams) - }; + }); let info = self.inner.iter().map(|(_, kind)| *kind).collect(); let compressed_list = CudaCompressedCiphertextList { packed_list, info }; @@ -458,12 +458,11 @@ impl CiphertextList for CompressedCiphertextList { crate::Error::new("Compression key not set in server key".to_owned()) }) .and_then(|decompression_key| { - let mut ct = { - let streams = &cuda_key.streams; + let mut ct = with_thread_local_cuda_streams(|streams| { self.inner .on_gpu(streams) .get::(index, decompression_key, streams) - }; + }); if let Ok(Some(ct_ref)) = &mut ct { ct_ref.tag_mut().set_data(cuda_key.tag.data()) } diff --git a/tfhe/src/high_level_api/global_state.rs b/tfhe/src/high_level_api/global_state.rs index 52f3b1896..f4f3e9a3d 100644 --- a/tfhe/src/high_level_api/global_state.rs +++ b/tfhe/src/high_level_api/global_state.rs @@ -201,15 +201,12 @@ pub(in crate::high_level_api) use gpu::{ #[cfg(feature = "gpu")] pub use gpu::CudaGpuChoice; -#[cfg(feature = "gpu")] -pub struct CustomMultiGpuIndexes(Vec); #[cfg(feature = "gpu")] mod gpu { use crate::core_crypto::gpu::get_number_of_gpus; use super::*; - use itertools::Itertools; use std::cell::LazyCell; thread_local! { @@ -226,14 +223,14 @@ mod gpu { } struct CudaStreamPool { - custom: Option, + multi: LazyCell, single: Vec CudaStreams>>>, } impl CudaStreamPool { fn new() -> Self { Self { - custom: None, + multi: LazyCell::new(CudaStreams::new_multi_gpu), single: (0..get_number_of_gpus()) .map(|index| { let ctor = @@ -245,6 +242,29 @@ mod gpu { } } + impl<'a> std::ops::Index<&'a [GpuIndex]> for CudaStreamPool { + type Output = CudaStreams; + + fn index(&self, indexes: &'a [GpuIndex]) -> &Self::Output { + match indexes.len() { + 0 => panic!("Internal error: Gpu indexes must not be empty"), + 1 => &self.single[indexes[0].get() as usize], + _ => &self.multi, + } + } + } + + impl std::ops::Index for CudaStreamPool { + type Output = CudaStreams; + + fn index(&self, choice: CudaGpuChoice) -> &Self::Output { + match choice { + CudaGpuChoice::Multi => &self.multi, + CudaGpuChoice::Single(index) => &self.single[index.get() as usize], + } + } + } + pub(in crate::high_level_api) fn with_thread_local_cuda_streams_for_gpu_indexes< R, F: for<'a> FnOnce(&'a CudaStreams) -> R, @@ -255,44 +275,15 @@ mod gpu { thread_local! { static POOL: RefCell = RefCell::new(CudaStreamPool::new()); } - - if gpu_indexes.len() == 1 { - POOL.with_borrow(|pool| func(&pool.single[gpu_indexes[0].get() as usize])) - } else { - POOL.with_borrow_mut(|pool| match &pool.custom { - Some(streams) if streams.gpu_indexes != gpu_indexes => { - pool.custom = Some(CudaStreams::new_multi_gpu_with_indexes(gpu_indexes)); - } - None => { - pool.custom = Some(CudaStreams::new_multi_gpu_with_indexes(gpu_indexes)); - } - _ => {} - }); - - POOL.with_borrow(|pool| func(pool.custom.as_ref().unwrap())) - } + POOL.with_borrow(|stream_pool| { + let stream = &stream_pool[gpu_indexes]; + func(stream) + }) } - - impl Clone for CustomMultiGpuIndexes { - fn clone(&self) -> Self { - self.0.iter().copied().collect_vec().into() - } - } - - impl CustomMultiGpuIndexes { - pub fn new(indexes: Vec) -> Self { - Self(indexes) - } - pub fn gpu_indexes(&self) -> &[GpuIndex] { - self.0.as_slice() - } - } - - #[derive(Clone)] + #[derive(Copy, Clone)] pub enum CudaGpuChoice { Single(GpuIndex), Multi, - Custom(CustomMultiGpuIndexes), } impl From for CudaGpuChoice { @@ -301,24 +292,11 @@ mod gpu { } } - impl From> for CustomMultiGpuIndexes { - fn from(value: Vec) -> Self { - Self(value) - } - } - - impl From for CudaGpuChoice { - fn from(values: CustomMultiGpuIndexes) -> Self { - Self::Custom(values) - } - } - impl CudaGpuChoice { pub(in crate::high_level_api) fn build_streams(self) -> CudaStreams { match self { Self::Single(idx) => CudaStreams::new_single_gpu(idx), Self::Multi => CudaStreams::new_multi_gpu(), - Self::Custom(idxs) => CudaStreams::new_multi_gpu_with_indexes(idxs.gpu_indexes()), } } } diff --git a/tfhe/src/high_level_api/integers/oprf.rs b/tfhe/src/high_level_api/integers/oprf.rs index 4280a5994..1c40f5e70 100644 --- a/tfhe/src/high_level_api/integers/oprf.rs +++ b/tfhe/src/high_level_api/integers/oprf.rs @@ -1,5 +1,7 @@ use super::{FheIntId, FheUint, FheUintId}; use crate::high_level_api::global_state; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::keys::InternalServerKey; #[cfg(feature = "gpu")] use crate::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext}; @@ -36,8 +38,7 @@ impl FheUint { Self::new(ct, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let d_ct: CudaUnsignedRadixCiphertext = cuda_key .key .key @@ -48,7 +49,7 @@ impl FheUint { ); Self::new(d_ct, cuda_key.tag.clone()) - } + }), }) } /// Generates an encrypted `num_block` blocks unsigned integer @@ -86,8 +87,7 @@ impl FheUint { Self::new(ct, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let d_ct: CudaUnsignedRadixCiphertext = cuda_key .key .key @@ -98,7 +98,7 @@ impl FheUint { streams, ); Self::new(d_ct, cuda_key.tag.clone()) - } + }), }) } } @@ -136,8 +136,7 @@ impl FheInt { Self::new(ct, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let d_ct: CudaSignedRadixCiphertext = cuda_key .key .key @@ -148,7 +147,7 @@ impl FheInt { ); Self::new(d_ct, cuda_key.tag.clone()) - } + }), }) } @@ -188,8 +187,7 @@ impl FheInt { Self::new(ct, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let d_ct: CudaSignedRadixCiphertext = cuda_key .key .key @@ -200,7 +198,7 @@ impl FheInt { streams, ); Self::new(d_ct, cuda_key.tag.clone()) - } + }), }) } } diff --git a/tfhe/src/high_level_api/integers/signed/base.rs b/tfhe/src/high_level_api/integers/signed/base.rs index 5f28ed35f..675b19e94 100644 --- a/tfhe/src/high_level_api/integers/signed/base.rs +++ b/tfhe/src/high_level_api/integers/signed/base.rs @@ -17,6 +17,8 @@ use crate::shortint::PBSParameters; use crate::{Device, FheBool, ServerKey, Tag}; use std::marker::PhantomData; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; pub trait FheIntId: IntegerId {} /// A Generic FHE signed integer @@ -195,14 +197,13 @@ where Self::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key .abs(&*self.ciphertext.on_gpu(streams), streams); Self::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -232,14 +233,13 @@ where FheBool::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key .is_even(&*self.ciphertext.on_gpu(streams), streams); FheBool::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -269,14 +269,13 @@ where FheBool::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key .is_odd(&*self.ciphertext.on_gpu(streams), streams); FheBool::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -310,8 +309,7 @@ where crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key @@ -322,7 +320,7 @@ where streams, ); crate::FheUint32::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -356,8 +354,7 @@ where crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key @@ -368,7 +365,7 @@ where streams, ); crate::FheUint32::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -402,8 +399,7 @@ where crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key @@ -414,7 +410,7 @@ where streams, ); crate::FheUint32::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -448,8 +444,7 @@ where crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key @@ -460,7 +455,7 @@ where streams, ); crate::FheUint32::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -570,8 +565,7 @@ where crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key @@ -582,7 +576,7 @@ where streams, ); crate::FheUint32::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -625,8 +619,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let (result, is_ok) = cuda_key .key .key @@ -640,7 +633,7 @@ where crate::FheUint32::new(result, cuda_key.tag.clone()), FheBool::new(is_ok, cuda_key.tag.clone()), ) - } + }), }) } @@ -814,8 +807,7 @@ where Self::new(new_ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let target_num_blocks = IntoId::num_blocks(cuda_key.message_modulus()); let new_ciphertext = cuda_key.key.key.cast_to_signed( input.ciphertext.into_gpu(streams), @@ -823,7 +815,7 @@ where streams, ); Self::new(new_ciphertext, cuda_key.tag.clone()) - } + }), }) } } @@ -860,15 +852,14 @@ where Self::new(new_ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let new_ciphertext = cuda_key.key.key.cast_to_signed( input.ciphertext.into_gpu(streams), IntoId::num_blocks(cuda_key.message_modulus()), streams, ); Self::new(new_ciphertext, cuda_key.tag.clone()) - } + }), }) } } @@ -908,15 +899,14 @@ where Self::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.cast_to_signed( input.ciphertext.into_gpu(streams).0, Id::num_blocks(cuda_key.message_modulus()), streams, ); Self::new(inner, cuda_key.tag.clone()) - } + }), }) } } diff --git a/tfhe/src/high_level_api/integers/signed/encrypt.rs b/tfhe/src/high_level_api/integers/signed/encrypt.rs index 2f2a3fc86..d4e284fd5 100644 --- a/tfhe/src/high_level_api/integers/signed/encrypt.rs +++ b/tfhe/src/high_level_api/integers/signed/encrypt.rs @@ -1,5 +1,7 @@ use crate::core_crypto::prelude::SignedNumeric; use crate::high_level_api::global_state; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::FheIntId; use crate::high_level_api::keys::InternalServerKey; use crate::integer::block_decomposition::{DecomposableInto, RecomposableSignedInteger}; @@ -111,15 +113,14 @@ where Ok(Self::new(ciphertext, key.tag.clone())) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner: CudaSignedRadixCiphertext = cuda_key.key.key.create_trivial_radix( value, Id::num_blocks(cuda_key.key.key.message_modulus), streams, ); Ok(Self::new(inner, cuda_key.tag.clone())) - } + }), }) } } diff --git a/tfhe/src/high_level_api/integers/signed/ops.rs b/tfhe/src/high_level_api/integers/signed/ops.rs index 07497d6b1..ebac40a4d 100644 --- a/tfhe/src/high_level_api/integers/signed/ops.rs +++ b/tfhe/src/high_level_api/integers/signed/ops.rs @@ -18,6 +18,9 @@ use std::ops::{ Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, }; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; + impl<'a, Id> std::iter::Sum<&'a Self> for FheInt where Id: FheIntId, @@ -71,8 +74,7 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let cts = iter .map(|fhe_uint| { match fhe_uint.ciphertext.on_gpu(streams) { @@ -102,7 +104,7 @@ where ) }); Self::new(inner, cuda_key.tag.clone()) - } + }) } }) } @@ -142,15 +144,14 @@ where Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.max( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); Self::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -189,15 +190,14 @@ where Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.min( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); Self::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -247,15 +247,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.eq( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } @@ -287,15 +286,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.ne( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -353,15 +351,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.lt( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } @@ -393,15 +390,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.le( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } @@ -433,15 +429,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.gt( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } @@ -473,15 +468,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.ge( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -556,8 +550,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let (q, r) = cuda_key.key.key.div_rem( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), @@ -567,7 +560,7 @@ where FheInt::::new(q, cuda_key.tag.clone()), FheInt::::new(r, cuda_key.tag.clone()), ) - } + }), }) } } @@ -641,11 +634,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - {let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .add(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - } + }) } }) } @@ -684,11 +677,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - {let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .sub(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - } + }) } }) } @@ -727,11 +720,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - {let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .mul(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - } + }) } }) } @@ -768,11 +761,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - {let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .bitand(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - } + }) } }) } @@ -809,11 +802,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - {let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .bitor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - } + }) } }) } @@ -850,11 +843,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - {let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .bitxor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - } + }) } }) } @@ -898,14 +891,14 @@ generic_integer_impl_operation!( FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .div(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }, + }), }) } }, @@ -949,14 +942,14 @@ generic_integer_impl_operation!( FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .rem(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }, + }), }) } }, @@ -1064,11 +1057,11 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - {let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .left_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - } + }) } } }) @@ -1108,11 +1101,11 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - {let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .right_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - } + }) } } }) @@ -1152,11 +1145,11 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - {let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .rotate_left(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - } + }) } } }) @@ -1196,11 +1189,11 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - {let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .rotate_right(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - } + }) } } }) @@ -1246,12 +1239,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - cuda_key.key.key.add_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + cuda_key.key.key.add_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); + }) } }) } @@ -1292,12 +1286,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - cuda_key.key.key.sub_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + cuda_key.key.key.sub_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); + }) } }) } @@ -1338,12 +1333,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - cuda_key.key.key.mul_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + cuda_key.key.key.mul_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); + }) } }) } @@ -1382,12 +1378,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - cuda_key.key.key.bitand_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + cuda_key.key.key.bitand_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); + }) } }) } @@ -1426,12 +1423,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - cuda_key.key.key.bitor_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + cuda_key.key.key.bitor_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); + }) } }) } @@ -1470,12 +1468,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - cuda_key.key.key.bitxor_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + cuda_key.key.key.bitxor_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); + }) } }) } @@ -1519,8 +1518,7 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let cuda_lhs = self.ciphertext.as_gpu_mut(streams); let cuda_result = cuda_key.pbs_key().div( &*cuda_lhs, @@ -1528,7 +1526,7 @@ where streams, ); *cuda_lhs = cuda_result; - }; + }); } }) } @@ -1572,13 +1570,15 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - let cuda_lhs = self.ciphertext.as_gpu_mut(streams); - let cuda_result = - cuda_key - .pbs_key() - .rem(&*cuda_lhs, &rhs.ciphertext.on_gpu(streams), streams); - *cuda_lhs = cuda_result; + with_thread_local_cuda_streams(|streams| { + let cuda_lhs = self.ciphertext.as_gpu_mut(streams); + let cuda_result = cuda_key.pbs_key().rem( + &*cuda_lhs, + &rhs.ciphertext.on_gpu(streams), + streams, + ); + *cuda_lhs = cuda_result; + }); } }) } @@ -1627,12 +1627,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - cuda_key.key.key.left_shift_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); + with_thread_local_cuda_streams(|streams| { + cuda_key.key.key.left_shift_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); + }); } }) } @@ -1680,12 +1681,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - cuda_key.key.key.right_shift_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); + with_thread_local_cuda_streams(|streams| { + cuda_key.key.key.right_shift_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); + }); } }) } @@ -1734,12 +1736,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - cuda_key.key.key.rotate_left_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); + with_thread_local_cuda_streams(|streams| { + cuda_key.key.key.rotate_left_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); + }); } }) } @@ -1788,12 +1791,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - cuda_key.key.key.rotate_right_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); + with_thread_local_cuda_streams(|streams| { + cuda_key.key.key.rotate_right_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); + }); } }) } @@ -1859,14 +1863,13 @@ where FheInt::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .neg(&*self.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -1929,14 +1932,13 @@ where FheInt::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .bitnot(&*self.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -1954,14 +1956,13 @@ where InternalServerKey::Cpu(_) => { tmp_buffer_size = 0; } - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { tmp_buffer_size = cuda_key.key.key.get_add_assign_size_on_gpu( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - } + }), }); tmp_buffer_size } @@ -1974,13 +1975,12 @@ where { fn get_size_on_gpu(&self) -> u64 { global_state::with_internal_keys(|key| match key { - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key .key .key .get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams)) - } + }), InternalServerKey::Cpu(_) => 0, }) } diff --git a/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs b/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs index 01c2b34ce..840d1c443 100644 --- a/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs +++ b/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs @@ -1,5 +1,7 @@ use crate::core_crypto::prelude::SignedNumeric; use crate::high_level_api::global_state; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::FheIntId; use crate::high_level_api::keys::InternalServerKey; use crate::integer::block_decomposition::DecomposableInto; @@ -51,8 +53,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let (result, overflow) = cuda_key.key.key.signed_overflowing_add( &self.ciphertext.on_gpu(streams), &other.ciphertext.on_gpu(streams), @@ -62,7 +63,7 @@ where FheInt::new(result, cuda_key.tag.clone()), FheBool::new(overflow, cuda_key.tag.clone()), ) - } + }), }) } } @@ -148,8 +149,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let (result, overflow) = cuda_key.key.key.signed_overflowing_scalar_add( &self.ciphertext.on_gpu(streams), other, @@ -159,7 +159,7 @@ where FheInt::new(result, cuda_key.tag.clone()), FheBool::new(overflow, cuda_key.tag.clone()), ) - } + }), }) } } @@ -283,8 +283,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let (result, overflow) = cuda_key.key.key.signed_overflowing_sub( &self.ciphertext.on_gpu(streams), &other.ciphertext.on_gpu(streams), @@ -294,7 +293,7 @@ where FheInt::new(result, cuda_key.tag.clone()), FheBool::new(overflow, cuda_key.tag.clone()), ) - } + }), }) } } @@ -379,8 +378,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let (result, overflow) = cuda_key.key.key.signed_overflowing_scalar_sub( &self.ciphertext.on_gpu(streams), other, @@ -390,7 +388,7 @@ where FheInt::new(result, cuda_key.tag.clone()), FheBool::new(overflow, cuda_key.tag.clone()), ) - } + }), }) } } diff --git a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs index 311e888a3..6dc5aa41c 100644 --- a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs @@ -2,6 +2,8 @@ use crate::core_crypto::commons::numeric::CastFrom; use crate::high_level_api::errors::UnwrapResultExt; use crate::high_level_api::global_state; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::signed::inner::SignedRadixCiphertext; use crate::high_level_api::integers::FheIntId; use crate::high_level_api::keys::InternalServerKey; @@ -53,13 +55,14 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - let inner_result = - cuda_key - .key - .key - .scalar_max(&*self.ciphertext.on_gpu(streams), rhs, streams); - Self::new(inner_result, cuda_key.tag.clone()) + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + let inner_result = cuda_key.key.key.scalar_max( + &*self.ciphertext.on_gpu(streams), + rhs, + streams, + ); + Self::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -100,13 +103,14 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - let inner_result = - cuda_key - .key - .key - .scalar_min(&*self.ciphertext.on_gpu(streams), rhs, streams); - Self::new(inner_result, cuda_key.tag.clone()) + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + let inner_result = cuda_key.key.key.scalar_min( + &*self.ciphertext.on_gpu(streams), + rhs, + streams, + ); + Self::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -146,13 +150,14 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - let inner_result = - cuda_key - .key - .key - .scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams); - FheBool::new(inner_result, cuda_key.tag.clone()) + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + let inner_result = + cuda_key + .key + .key + .scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams); + FheBool::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -186,13 +191,14 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - let inner_result = - cuda_key - .key - .key - .scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams); - FheBool::new(inner_result, cuda_key.tag.clone()) + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + let inner_result = + cuda_key + .key + .key + .scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams); + FheBool::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -231,13 +237,14 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - let inner_result = - cuda_key - .key - .key - .scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams); - FheBool::new(inner_result, cuda_key.tag.clone()) + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + let inner_result = + cuda_key + .key + .key + .scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams); + FheBool::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -270,13 +277,14 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - let inner_result = - cuda_key - .key - .key - .scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams); - FheBool::new(inner_result, cuda_key.tag.clone()) + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + let inner_result = + cuda_key + .key + .key + .scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams); + FheBool::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -309,13 +317,14 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - let inner_result = - cuda_key - .key - .key - .scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams); - FheBool::new(inner_result, cuda_key.tag.clone()) + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + let inner_result = + cuda_key + .key + .key + .scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams); + FheBool::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -348,13 +357,14 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - let inner_result = - cuda_key - .key - .key - .scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams); - FheBool::new(inner_result, cuda_key.tag.clone()) + crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { + let inner_result = + cuda_key + .key + .key + .scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams); + FheBool::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -398,11 +408,11 @@ macro_rules! generic_integer_impl_scalar_div_rem { } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let (inner_q, inner_r) = {let streams = &cuda_key.streams; + let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.signed_scalar_div_rem( &*self.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); let (q, r) = ( SignedRadixCiphertext::Cuda(inner_q), SignedRadixCiphertext::Cuda(inner_r), @@ -446,11 +456,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_left_shift( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }; + }); SignedRadixCiphertext::Cuda(inner_result) } }) @@ -475,11 +485,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_right_shift( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }; + }); SignedRadixCiphertext::Cuda(inner_result) } }) @@ -504,11 +514,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_rotate_left( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }; + }); SignedRadixCiphertext::Cuda(inner_result) } }) @@ -533,11 +543,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_rotate_right( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }; + }); SignedRadixCiphertext::Cuda(inner_result) } }) @@ -560,10 +570,11 @@ macro_rules! define_scalar_rotate_shifts { .scalar_left_shift_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs); }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => - {let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => { + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -586,9 +597,10 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -610,10 +622,11 @@ macro_rules! define_scalar_rotate_shifts { .scalar_rotate_left_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs); }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => - {let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => { + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -636,10 +649,10 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - {let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - } + }) } }) } @@ -749,11 +762,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_add( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); SignedRadixCiphertext::Cuda(inner_result) } }) @@ -778,11 +791,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_sub( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); SignedRadixCiphertext::Cuda(inner_result) } }) @@ -807,11 +820,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_mul( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); SignedRadixCiphertext::Cuda(inner_result) } }) @@ -837,11 +850,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitand( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); SignedRadixCiphertext::Cuda(inner_result) } }) @@ -866,11 +879,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitor( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); SignedRadixCiphertext::Cuda(inner_result) } }) @@ -896,11 +909,11 @@ macro_rules! define_scalar_ops { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitxor( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); SignedRadixCiphertext::Cuda(inner_result) } }) @@ -925,11 +938,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.signed_scalar_div( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); SignedRadixCiphertext::Cuda(inner_result) } }) @@ -955,11 +968,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.signed_scalar_rem( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); SignedRadixCiphertext::Cuda(inner_result) } }) @@ -1001,11 +1014,12 @@ macro_rules! define_scalar_ops { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { use crate::integer::gpu::ciphertext::CudaSignedRadixCiphertext; - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let mut result: CudaSignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix( lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams); cuda_key.pbs_key().sub_assign(&mut result, &*rhs.ciphertext.on_gpu(streams), streams); SignedRadixCiphertext::Cuda(result) + }) } }) } @@ -1096,9 +1110,10 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_add_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -1126,9 +1141,10 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_sub_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -1152,9 +1168,10 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_mul_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -1179,9 +1196,10 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -1205,9 +1223,10 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -1230,9 +1249,10 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -1254,11 +1274,11 @@ macro_rules! define_scalar_ops { .signed_scalar_div_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs); }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| { let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams); let cuda_result = cuda_key.pbs_key().signed_scalar_div(&cuda_lhs, rhs, streams); *cuda_lhs = cuda_result; - } + }) }) } }, @@ -1279,11 +1299,11 @@ macro_rules! define_scalar_ops { .signed_scalar_rem_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs); }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| { let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams); let cuda_result = cuda_key.pbs_key().signed_scalar_rem(&cuda_lhs, rhs, streams); *cuda_lhs = cuda_result; - } + }) }) } }, diff --git a/tfhe/src/high_level_api/integers/unsigned/base.rs b/tfhe/src/high_level_api/integers/unsigned/base.rs index 00b8f35d6..972ff5423 100644 --- a/tfhe/src/high_level_api/integers/unsigned/base.rs +++ b/tfhe/src/high_level_api/integers/unsigned/base.rs @@ -4,6 +4,8 @@ use super::inner::RadixCiphertext; use crate::backward_compatibility::integers::FheUintVersions; use crate::conformance::ParameterSetConformant; use crate::core_crypto::prelude::{CastFrom, UnsignedInteger, UnsignedNumeric}; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::signed::{FheInt, FheIntId}; use crate::high_level_api::integers::IntegerId; use crate::high_level_api::keys::InternalServerKey; @@ -247,14 +249,13 @@ where FheBool::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key .is_even(&*self.ciphertext.on_gpu(streams), streams); FheBool::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -284,14 +285,13 @@ where FheBool::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key .is_odd(&*self.ciphertext.on_gpu(streams), streams); FheBool::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -418,8 +418,7 @@ where super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key @@ -430,7 +429,7 @@ where streams, ); super::FheUint32::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -464,8 +463,7 @@ where super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key @@ -476,7 +474,7 @@ where streams, ); super::FheUint32::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -510,8 +508,7 @@ where super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key @@ -522,7 +519,7 @@ where streams, ); super::FheUint32::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -556,8 +553,7 @@ where super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key @@ -568,7 +564,7 @@ where streams, ); super::FheUint32::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -678,8 +674,7 @@ where super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key .key .key @@ -690,7 +685,7 @@ where streams, ); super::FheUint32::new(result, cuda_key.tag.clone()) - } + }), }) } @@ -733,8 +728,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let (result, is_ok) = cuda_key .key .key @@ -748,7 +742,7 @@ where super::FheUint32::new(result, cuda_key.tag.clone()), FheBool::new(is_ok, cuda_key.tag.clone()), ) - } + }), }) } @@ -816,8 +810,7 @@ where } } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let (result, matched) = cuda_key.key.key.match_value( &self.ciphertext.on_gpu(streams), matches, @@ -832,7 +825,7 @@ where } else { Err(crate::Error::new("Output type does not have enough bits to represent all possible output values".to_string())) } - } + }), }) } @@ -895,8 +888,7 @@ where } } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key.key.key.match_value_or( &self.ciphertext.on_gpu(streams), matches, @@ -909,7 +901,7 @@ where } else { Err(crate::Error::new("Output type does not have enough bits to represent all possible output values".to_string())) } - } + }), }) } @@ -1109,15 +1101,14 @@ where Self::new(casted, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let casted = cuda_key.key.key.cast_to_unsigned( input.ciphertext.into_gpu(streams), IntoId::num_blocks(cuda_key.message_modulus()), streams, ); Self::new(casted, cuda_key.tag.clone()) - } + }), }) } } @@ -1154,15 +1145,14 @@ where Self::new(casted, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let casted = cuda_key.key.key.cast_to_unsigned( input.ciphertext.into_gpu(streams), IntoId::num_blocks(cuda_key.message_modulus()), streams, ); Self::new(casted, cuda_key.tag.clone()) - } + }), }) } } @@ -1199,15 +1189,14 @@ where Self::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.cast_to_unsigned( input.ciphertext.into_gpu(streams).0, Id::num_blocks(cuda_key.message_modulus()), streams, ); Self::new(inner, cuda_key.tag.clone()) - } + }), }) } } diff --git a/tfhe/src/high_level_api/integers/unsigned/encrypt.rs b/tfhe/src/high_level_api/integers/unsigned/encrypt.rs index 135cf8c0a..593e580e6 100644 --- a/tfhe/src/high_level_api/integers/unsigned/encrypt.rs +++ b/tfhe/src/high_level_api/integers/unsigned/encrypt.rs @@ -1,5 +1,7 @@ use crate::core_crypto::prelude::UnsignedNumeric; use crate::high_level_api::global_state; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::FheUintId; use crate::high_level_api::keys::InternalServerKey; use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom}; @@ -113,15 +115,14 @@ where Ok(Self::new(ciphertext, key.tag.clone())) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner: CudaUnsignedRadixCiphertext = cuda_key.key.key.create_trivial_radix( value, Id::num_blocks(cuda_key.key.key.message_modulus), streams, ); Ok(Self::new(inner, cuda_key.tag.clone())) - } + }), }) } } diff --git a/tfhe/src/high_level_api/integers/unsigned/ops.rs b/tfhe/src/high_level_api/integers/unsigned/ops.rs index 38d54179d..1efe9cce8 100644 --- a/tfhe/src/high_level_api/integers/unsigned/ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/ops.rs @@ -5,6 +5,8 @@ use super::inner::RadixCiphertext; #[cfg(feature = "gpu")] use crate::high_level_api::details::MaybeCloned; use crate::high_level_api::global_state; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::FheUintId; use crate::high_level_api::keys::InternalServerKey; #[cfg(feature = "gpu")] @@ -73,8 +75,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let cts = iter .map(|fhe_uint| fhe_uint.ciphertext.into_gpu(streams)) .collect::>(); @@ -91,7 +92,7 @@ where ) }); Self::new(inner, cuda_key.tag.clone()) - } + }), }) } } @@ -153,8 +154,7 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let cts = iter .map(|fhe_uint| { match fhe_uint.ciphertext.on_gpu(streams) { @@ -184,7 +184,7 @@ where ) }); Self::new(inner, cuda_key.tag.clone()) - } + }) } }) } @@ -224,15 +224,14 @@ where Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.max( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); Self::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -271,15 +270,14 @@ where Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.min( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); Self::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -329,15 +327,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.eq( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } @@ -369,15 +366,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.ne( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -435,15 +431,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.lt( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } @@ -475,15 +470,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.le( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } @@ -515,15 +509,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.gt( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } @@ -555,15 +548,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.ge( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -639,8 +631,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.div_rem( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), @@ -650,7 +641,7 @@ where FheUint::::new(inner_result.0, cuda_key.tag.clone()), FheUint::::new(inner_result.1, cuda_key.tag.clone()), ) - } + }), }) } } @@ -726,10 +717,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .add(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -768,10 +760,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .sub(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -810,10 +803,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .mul(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -850,10 +844,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .bitand(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -890,10 +885,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .bitor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -930,10 +926,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .bitxor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) + }) } }) } @@ -977,15 +974,14 @@ generic_integer_impl_operation!( FheUint::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .div(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }, + }), }) } }, @@ -1029,16 +1025,14 @@ generic_integer_impl_operation!( FheUint::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => - { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .rem(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }, + }), }) } }, @@ -1146,10 +1140,11 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .left_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) + }) } } }) @@ -1189,10 +1184,11 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .right_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) + }) } } }) @@ -1232,10 +1228,11 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .rotate_left(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) + }) } } }) @@ -1275,10 +1272,11 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key .rotate_right(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) + }) } } }) @@ -1323,14 +1321,13 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.add_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - } + }), }) } } @@ -1369,14 +1366,13 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.sub_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - } + }), }) } } @@ -1415,14 +1411,13 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.mul_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - } + }), }) } } @@ -1459,14 +1454,13 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.bitand_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - } + }), }) } } @@ -1503,14 +1497,13 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.bitor_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - } + }), }) } } @@ -1547,14 +1540,13 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.bitxor_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - } + }), }) } } @@ -1596,14 +1588,13 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.div_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - } + }), }) } } @@ -1645,14 +1636,13 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.rem_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - } + }), }) } } @@ -1700,14 +1690,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key.left_shift_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }; + }); } }) } @@ -1755,14 +1744,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key.right_shift_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }; + }); } }) } @@ -1811,14 +1799,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key.rotate_left_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }; + }); } }) } @@ -1867,12 +1854,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; - cuda_key.key.key.rotate_right_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); + with_thread_local_cuda_streams(|streams| { + cuda_key.key.key.rotate_right_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); + }); } }) } @@ -1946,14 +1934,13 @@ where FheUint::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .neg(&*self.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -2016,14 +2003,13 @@ where FheUint::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .bitnot(&*self.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -2040,14 +2026,13 @@ where InternalServerKey::Cpu(_) => { tmp_buffer_size = 0; } - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { tmp_buffer_size = cuda_key.key.key.get_add_assign_size_on_gpu( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - } + }), }); tmp_buffer_size } @@ -2059,13 +2044,12 @@ where { fn get_size_on_gpu(&self) -> u64 { global_state::with_internal_keys(|key| match key { - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key .key .key .get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams)) - } + }), InternalServerKey::Cpu(_) => 0, }) } diff --git a/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs b/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs index 74e803263..9b2e274d4 100644 --- a/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs @@ -1,5 +1,7 @@ use crate::core_crypto::prelude::UnsignedNumeric; use crate::high_level_api::global_state; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::FheUintId; use crate::high_level_api::keys::InternalServerKey; use crate::integer::block_decomposition::DecomposableInto; @@ -51,8 +53,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.unsigned_overflowing_add( &self.ciphertext.on_gpu(streams), &other.ciphertext.on_gpu(streams), @@ -62,7 +63,7 @@ where FheUint::::new(inner_result.0, cuda_key.tag.clone()), FheBool::new(inner_result.1, cuda_key.tag.clone()), ) - } + }), }) } } @@ -148,8 +149,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.unsigned_overflowing_scalar_add( &self.ciphertext.on_gpu(streams), other, @@ -159,7 +159,7 @@ where FheUint::::new(inner_result.0, cuda_key.tag.clone()), FheBool::new(inner_result.1, cuda_key.tag.clone()), ) - } + }), }) } } @@ -285,8 +285,7 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.unsigned_overflowing_sub( &self.ciphertext.on_gpu(streams), &other.ciphertext.on_gpu(streams), @@ -296,7 +295,7 @@ where FheUint::::new(inner_result.0, cuda_key.tag.clone()), FheBool::new(inner_result.1, cuda_key.tag.clone()), ) - } + }), }) } } diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs index 0e9150bb8..6f93356ef 100644 --- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs @@ -7,6 +7,8 @@ use super::inner::RadixCiphertext; use crate::error::InvalidRangeError; use crate::high_level_api::errors::UnwrapResultExt; use crate::high_level_api::global_state; +#[cfg(feature = "gpu")] +use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::FheUintId; use crate::high_level_api::keys::InternalServerKey; use crate::high_level_api::traits::{ @@ -58,15 +60,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } @@ -98,15 +99,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -144,15 +144,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } @@ -184,15 +183,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } @@ -224,15 +222,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } @@ -264,15 +261,14 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -312,15 +308,14 @@ where Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .scalar_max(&*self.ciphertext.on_gpu(streams), rhs, streams); Self::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -360,15 +355,14 @@ where Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .key .scalar_min(&*self.ciphertext.on_gpu(streams), rhs, streams); Self::new(inner_result, cuda_key.tag.clone()) - } + }), }) } } @@ -499,11 +493,11 @@ macro_rules! generic_integer_impl_scalar_div_rem { } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let (inner_q, inner_r) = {let streams = &cuda_key.streams; + let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_div_rem( &*self.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r)); ( <$concrete_type>::new(q, cuda_key.tag.clone()), @@ -679,11 +673,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_left_shift( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }; + }); RadixCiphertext::Cuda(inner_result) } }) @@ -708,11 +702,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_right_shift( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }; + }); RadixCiphertext::Cuda(inner_result) } }) @@ -737,11 +731,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_rotate_left( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }; + }); RadixCiphertext::Cuda(inner_result) } }) @@ -766,11 +760,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_rotate_right( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }; + }); RadixCiphertext::Cuda(inner_result) } }) @@ -794,10 +788,10 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - {let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - } + }) } }) } @@ -820,10 +814,10 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - {let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - } + }) } }) } @@ -846,9 +840,10 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -871,9 +866,10 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -984,11 +980,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_add( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); RadixCiphertext::Cuda(inner_result) } }) @@ -1013,11 +1009,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_sub( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); RadixCiphertext::Cuda(inner_result) } }) @@ -1042,11 +1038,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_mul( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); RadixCiphertext::Cuda(inner_result) } }) @@ -1072,11 +1068,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitand( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); RadixCiphertext::Cuda(inner_result) } }) @@ -1101,11 +1097,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitor( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); RadixCiphertext::Cuda(inner_result) } }) @@ -1131,11 +1127,11 @@ macro_rules! define_scalar_ops { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitxor( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); RadixCiphertext::Cuda(inner_result) } }) @@ -1160,11 +1156,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_div( &lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); RadixCiphertext::Cuda(inner_result) } }) @@ -1190,11 +1186,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = {let streams = &cuda_key.streams; + let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_rem( &lhs.ciphertext.on_gpu(streams), rhs, streams ) - }; + }); RadixCiphertext::Cuda(inner_result) } }) @@ -1235,11 +1231,12 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { let mut result: CudaUnsignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix( lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams); cuda_key.pbs_key().sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams); RadixCiphertext::Cuda(result) + }) } }) } @@ -1330,9 +1327,10 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_add_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -1360,9 +1358,10 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_sub_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -1386,9 +1385,10 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_mul_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -1413,9 +1413,10 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -1439,9 +1440,10 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -1464,9 +1466,10 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let streams = &cuda_key.streams; + with_thread_local_cuda_streams(|streams| { cuda_key.key.key .scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); + }) } }) } @@ -1488,11 +1491,11 @@ macro_rules! define_scalar_ops { .scalar_div_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs); }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| { let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams); let cuda_result = cuda_key.pbs_key().scalar_div(&cuda_lhs, rhs, streams); *cuda_lhs = cuda_result; - } + }) }) } }, @@ -1513,11 +1516,11 @@ macro_rules! define_scalar_ops { .scalar_rem_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs); }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams; + InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| { let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams); let cuda_result = cuda_key.pbs_key().scalar_rem(&cuda_lhs, rhs, streams); *cuda_lhs = cuda_result; - } + }) }) } }, diff --git a/tfhe/src/high_level_api/keys/server.rs b/tfhe/src/high_level_api/keys/server.rs index 8cc8f0dd4..26bb0703e 100644 --- a/tfhe/src/high_level_api/keys/server.rs +++ b/tfhe/src/high_level_api/keys/server.rs @@ -328,7 +328,6 @@ impl CompressedServerKey { decompression_key, }), tag: self.tag.clone(), - streams, } } } @@ -352,7 +351,6 @@ impl Named for CompressedServerKey { pub struct CudaServerKey { pub(crate) key: Arc, pub(crate) tag: Tag, - pub(crate) streams: CudaStreams, } #[cfg(feature = "gpu")] diff --git a/tfhe/src/high_level_api/tests/gpu_selection.rs b/tfhe/src/high_level_api/tests/gpu_selection.rs index e22caa1c1..c1d1e8b49 100644 --- a/tfhe/src/high_level_api/tests/gpu_selection.rs +++ b/tfhe/src/high_level_api/tests/gpu_selection.rs @@ -1,7 +1,6 @@ use rand::Rng; use crate::core_crypto::gpu::get_number_of_gpus; -use crate::high_level_api::global_state::CustomMultiGpuIndexes; use crate::prelude::*; use crate::{ set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device, FheUint32, GpuIndex, @@ -118,68 +117,3 @@ fn test_gpu_selection_2() { assert_eq!(c.gpu_indexes(), &[first_gpu]); assert_eq!(decrypted, clear_a.wrapping_add(clear_b)); } - -#[test] -fn test_specific_gpu_selection() { - let config = ConfigBuilder::default().build(); - let keys = ClientKey::generate(config); - let compressed_server_keys = CompressedServerKey::new(&keys); - - let mut rng = rand::thread_rng(); - - let total_gpus = get_number_of_gpus() as usize; - let num_gpus_to_use = rng.gen_range(1..=get_number_of_gpus()) as usize; - - // Randomly sample num_gpus_to_use indices - let selected_indices = rand::seq::index::sample(&mut rng, total_gpus, num_gpus_to_use); - - // Convert the selected indices to GpuIndex objects - let gpus_to_be_used = CustomMultiGpuIndexes::new( - selected_indices - .iter() - .map(|idx| GpuIndex::new(idx as u32)) - .collect(), - ); - - let clear_a: u32 = rng.gen(); - let clear_b: u32 = rng.gen(); - - let mut a = FheUint32::try_encrypt(clear_a, &keys).unwrap(); - let mut b = FheUint32::try_encrypt(clear_b, &keys).unwrap(); - - assert_eq!(a.current_device(), Device::Cpu); - assert_eq!(b.current_device(), Device::Cpu); - assert_eq!(a.gpu_indexes(), &[]); - assert_eq!(b.gpu_indexes(), &[]); - - let cuda_key = compressed_server_keys.decompress_to_specific_gpu(gpus_to_be_used); - - let first_gpu = GpuIndex::new(0); - - set_server_key(cuda_key); - let c = &a + &b; - let decrypted: u32 = c.decrypt(&keys); - assert_eq!(c.current_device(), Device::CudaGpu); - assert_eq!(c.gpu_indexes(), &[first_gpu]); - assert_eq!(decrypted, clear_a.wrapping_add(clear_b)); - - // Check explicit move, but first make sure input are on Cpu still - assert_eq!(a.current_device(), Device::Cpu); - assert_eq!(b.current_device(), Device::Cpu); - assert_eq!(a.gpu_indexes(), &[]); - assert_eq!(b.gpu_indexes(), &[]); - - a.move_to_current_device(); - b.move_to_current_device(); - - assert_eq!(a.current_device(), Device::CudaGpu); - assert_eq!(b.current_device(), Device::CudaGpu); - assert_eq!(a.gpu_indexes(), &[first_gpu]); - assert_eq!(b.gpu_indexes(), &[first_gpu]); - - let c = &a + &b; - let decrypted: u32 = c.decrypt(&keys); - assert_eq!(c.current_device(), Device::CudaGpu); - assert_eq!(c.gpu_indexes(), &[first_gpu]); - assert_eq!(decrypted, clear_a.wrapping_add(clear_b)); -}