From 0280dbeb41c1b339a01948d57595f9cd44d8c77c Mon Sep 17 00:00:00 2001 From: Pedro Alves Date: Tue, 6 May 2025 12:04:16 -0300 Subject: [PATCH] feat(gpu): enables the user to perform computation on multi-gpu using a custom selection of GPUs --- .../cuda/include/helper_multi_gpu.h | 2 +- .../cuda/src/utils/helper_multi_gpu.cu | 11 +- .../tests/test_keyswitch.cpp | 2 +- backends/tfhe-cuda-backend/src/cuda_bind.rs | 2 +- tfhe/src/core_crypto/gpu/mod.rs | 31 +- tfhe/src/high_level_api/array/gpu/booleans.rs | 105 +++---- tfhe/src/high_level_api/array/gpu/integers.rs | 31 +- tfhe/src/high_level_api/booleans/base.rs | 107 ++++--- tfhe/src/high_level_api/booleans/encrypt.rs | 7 +- tfhe/src/high_level_api/booleans/oprf.rs | 7 +- .../compressed_ciphertext_list.rs | 23 +- tfhe/src/high_level_api/global_state.rs | 82 ++++-- tfhe/src/high_level_api/integers/oprf.rs | 22 +- .../high_level_api/integers/signed/base.rs | 62 ++-- .../high_level_api/integers/signed/encrypt.rs | 7 +- .../src/high_level_api/integers/signed/ops.rs | 272 +++++++++--------- .../integers/signed/overflowing_ops.rs | 22 +- .../integers/signed/scalar_ops.rs | 220 +++++++------- .../high_level_api/integers/unsigned/base.rs | 67 +++-- .../integers/unsigned/encrypt.rs | 7 +- .../high_level_api/integers/unsigned/ops.rs | 186 ++++++------ .../integers/unsigned/overflowing_ops.rs | 17 +- .../integers/unsigned/scalar_ops.rs | 137 +++++---- tfhe/src/high_level_api/keys/server.rs | 2 + .../src/high_level_api/tests/gpu_selection.rs | 66 +++++ 25 files changed, 816 insertions(+), 681 deletions(-) diff --git a/backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h b/backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h index 788fe416b..b25eb7e7e 100644 --- a/backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h +++ b/backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h @@ -8,7 +8,7 @@ extern std::mutex m; extern bool p2p_enabled; extern "C" { -int32_t cuda_setup_multi_gpu(); +int32_t cuda_setup_multi_gpu(int device_0_id); } // Define a variant type that can be either a vector or a single pointer diff --git a/backends/tfhe-cuda-backend/cuda/src/utils/helper_multi_gpu.cu b/backends/tfhe-cuda-backend/cuda/src/utils/helper_multi_gpu.cu index c7f0fdeab..97bd04f10 100644 --- a/backends/tfhe-cuda-backend/cuda/src/utils/helper_multi_gpu.cu +++ b/backends/tfhe-cuda-backend/cuda/src/utils/helper_multi_gpu.cu @@ -6,7 +6,8 @@ std::mutex m; bool p2p_enabled = false; -int32_t cuda_setup_multi_gpu() { +// Enable bidirectional p2p access between all available GPUs and device_0_id +int32_t cuda_setup_multi_gpu(int device_0_id) { int num_gpus = cuda_get_number_of_gpus(); if (num_gpus == 0) PANIC("GPU error: the number of GPUs should be > 0.") @@ -18,11 +19,13 @@ int32_t cuda_setup_multi_gpu() { omp_set_nested(1); int has_peer_access_to_device_0; for (int i = 1; i < num_gpus; i++) { - check_cuda_error( - cudaDeviceCanAccessPeer(&has_peer_access_to_device_0, i, 0)); + check_cuda_error(cudaDeviceCanAccessPeer(&has_peer_access_to_device_0, + i, device_0_id)); if (has_peer_access_to_device_0) { cuda_set_device(i); - check_cuda_error(cudaDeviceEnablePeerAccess(0, 0)); + check_cuda_error(cudaDeviceEnablePeerAccess(device_0_id, 0)); + cuda_set_device(device_0_id); + check_cuda_error(cudaDeviceEnablePeerAccess(i, 0)); } num_used_gpus += 1; } diff --git a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_keyswitch.cpp b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_keyswitch.cpp index 7fb1f7533..3fea7138d 100644 --- 
a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_keyswitch.cpp +++ b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_keyswitch.cpp @@ -65,7 +65,7 @@ public: number_of_inputs = (int)GetParam().number_of_inputs; // Enable Multi-GPU logic - gpu_count = cuda_setup_multi_gpu(); + gpu_count = cuda_setup_multi_gpu(0); active_gpu_count = std::min((uint)number_of_inputs, gpu_count); for (uint gpu_i = 0; gpu_i < active_gpu_count; gpu_i++) { streams.push_back(cuda_create_stream(gpu_i)); diff --git a/backends/tfhe-cuda-backend/src/cuda_bind.rs b/backends/tfhe-cuda-backend/src/cuda_bind.rs index 31e04d12b..114f2dc05 100644 --- a/backends/tfhe-cuda-backend/src/cuda_bind.rs +++ b/backends/tfhe-cuda-backend/src/cuda_bind.rs @@ -101,6 +101,6 @@ extern "C" { pub fn cuda_drop_async(ptr: *mut c_void, stream: *mut c_void, gpu_index: u32); - pub fn cuda_setup_multi_gpu() -> i32; + pub fn cuda_setup_multi_gpu(gpu_index: u32) -> i32; } // extern "C" diff --git a/tfhe/src/core_crypto/gpu/mod.rs b/tfhe/src/core_crypto/gpu/mod.rs index 100e72b97..d454bc9cd 100644 --- a/tfhe/src/core_crypto/gpu/mod.rs +++ b/tfhe/src/core_crypto/gpu/mod.rs @@ -17,7 +17,6 @@ pub use entities::*; use std::ffi::c_void; use tfhe_cuda_backend::bindings::*; use tfhe_cuda_backend::cuda_bind::*; - pub struct CudaStreams { pub ptr: Vec<*mut c_void>, pub gpu_indexes: Vec, @@ -30,7 +29,7 @@ unsafe impl Sync for CudaStreams {} impl CudaStreams { /// Create a new `CudaStreams` structure with as many GPUs as there are on the machine pub fn new_multi_gpu() -> Self { - let gpu_count = setup_multi_gpu(); + let gpu_count = setup_multi_gpu(GpuIndex::new(0)); let mut gpu_indexes = Vec::with_capacity(gpu_count as usize); let mut ptr_array = Vec::with_capacity(gpu_count as usize); @@ -43,6 +42,22 @@ impl CudaStreams { gpu_indexes, } } + /// Create a new `CudaStreams` structure with the GPUs with id provided in a list + pub fn new_multi_gpu_with_indexes(indexes: &[GpuIndex]) -> Self { + let _gpu_count = setup_multi_gpu(indexes[0]); + + let mut gpu_indexes = Vec::with_capacity(indexes.len()); + let mut ptr_array = Vec::with_capacity(indexes.len()); + + for &i in indexes { + ptr_array.push(unsafe { cuda_create_stream(i.get()) }); + gpu_indexes.push(i); + } + Self { + ptr: ptr_array, + gpu_indexes, + } + } /// Create a new `CudaStreams` structure with one GPU, whose index corresponds to the one given /// as input pub fn new_single_gpu(gpu_index: GpuIndex) -> Self { @@ -88,6 +103,14 @@ impl CudaStreams { } } +impl Clone for CudaStreams { + fn clone(&self) -> Self { + // The `new_multi_gpu_with_indexes()` function is used here to adapt to any specific type of + // streams being cloned (single, multi, or custom) + Self::new_multi_gpu_with_indexes(self.gpu_indexes.as_slice()) + } +} + impl Drop for CudaStreams { fn drop(&mut self) { for (i, &s) in self.ptr.iter().enumerate() { @@ -1036,8 +1059,8 @@ pub fn get_number_of_gpus() -> u32 { } /// Setup multi-GPU and return the number of GPUs used -pub fn setup_multi_gpu() -> u32 { - unsafe { cuda_setup_multi_gpu() as u32 } +pub fn setup_multi_gpu(device_0_id: GpuIndex) -> u32 { + unsafe { cuda_setup_multi_gpu(device_0_id.get()) as u32 } } /// Synchronize device diff --git a/tfhe/src/high_level_api/array/gpu/booleans.rs b/tfhe/src/high_level_api/array/gpu/booleans.rs index 7d0b9cf17..a858d1e3c 100644 --- a/tfhe/src/high_level_api/array/gpu/booleans.rs +++ b/tfhe/src/high_level_api/array/gpu/booleans.rs @@ -156,14 +156,13 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend { rhs: 
TensorSlice<'_, Self::Slice<'a>>, ) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - with_thread_local_cuda_streams(|streams| { - lhs.par_iter() - .zip(rhs.par_iter()) - .map(|(lhs, rhs)| { - CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams)) - }) - .collect::>() - }) + let streams = &cuda_key.streams; + lhs.par_iter() + .zip(rhs.par_iter()) + .map(|(lhs, rhs)| { + CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams)) + }) + .collect::>() })) } @@ -172,14 +171,13 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend { rhs: TensorSlice<'_, Self::Slice<'a>>, ) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - with_thread_local_cuda_streams(|streams| { - lhs.par_iter() - .zip(rhs.par_iter()) - .map(|(lhs, rhs)| { - CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams)) - }) - .collect::>() - }) + let streams = &cuda_key.streams; + lhs.par_iter() + .zip(rhs.par_iter()) + .map(|(lhs, rhs)| { + CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams)) + }) + .collect::>() })) } @@ -188,24 +186,22 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend { rhs: TensorSlice<'_, Self::Slice<'a>>, ) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - with_thread_local_cuda_streams(|streams| { - lhs.par_iter() - .zip(rhs.par_iter()) - .map(|(lhs, rhs)| { - CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams)) - }) - .collect::>() - }) + let streams = &cuda_key.streams; + lhs.par_iter() + .zip(rhs.par_iter()) + .map(|(lhs, rhs)| { + CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams)) + }) + .collect::>() })) } fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - with_thread_local_cuda_streams(|streams| { - lhs.par_iter() - .map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams))) - .collect::>() - }) + let streams = &cuda_key.streams; + lhs.par_iter() + .map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams))) + .collect::>() })) } } @@ -216,16 +212,13 @@ impl ClearBitwiseArrayBackend for GpuFheBoolArrayBackend { rhs: TensorSlice<'_, &'_ [bool]>, ) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - with_thread_local_cuda_streams(|streams| { - lhs.par_iter() - .zip(rhs.par_iter().copied()) - .map(|(lhs, rhs)| { - CudaBooleanBlock( - cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams), - ) - }) - .collect::>() - }) + let streams = &cuda_key.streams; + lhs.par_iter() + .zip(rhs.par_iter().copied()) + .map(|(lhs, rhs)| { + CudaBooleanBlock(cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams)) + }) + .collect::>() })) } @@ -234,16 +227,13 @@ impl ClearBitwiseArrayBackend for GpuFheBoolArrayBackend { rhs: TensorSlice<'_, &'_ [bool]>, ) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - with_thread_local_cuda_streams(|streams| { - lhs.par_iter() - .zip(rhs.par_iter().copied()) - .map(|(lhs, rhs)| { - CudaBooleanBlock( - cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams), - ) - }) - .collect::>() - }) + let streams = &cuda_key.streams; + lhs.par_iter() + .zip(rhs.par_iter().copied()) + .map(|(lhs, rhs)| { + CudaBooleanBlock(cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams)) + }) + .collect::>() })) } @@ -252,16 +242,13 @@ impl ClearBitwiseArrayBackend for GpuFheBoolArrayBackend { rhs: TensorSlice<'_, &'_ 
[bool]>, ) -> Self::Owned { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { - with_thread_local_cuda_streams(|streams| { - lhs.par_iter() - .zip(rhs.par_iter().copied()) - .map(|(lhs, rhs)| { - CudaBooleanBlock( - cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams), - ) - }) - .collect::>() - }) + let streams = &cuda_key.streams; + lhs.par_iter() + .zip(rhs.par_iter().copied()) + .map(|(lhs, rhs)| { + CudaBooleanBlock(cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams)) + }) + .collect::>() })) } } diff --git a/tfhe/src/high_level_api/array/gpu/integers.rs b/tfhe/src/high_level_api/array/gpu/integers.rs index 1cf12a260..6a2b6ed84 100644 --- a/tfhe/src/high_level_api/array/gpu/integers.rs +++ b/tfhe/src/high_level_api/array/gpu/integers.rs @@ -108,12 +108,11 @@ where F: Send + Sync + Fn(&crate::integer::gpu::CudaServerKey, &T, &T, &CudaStreams) -> T, { GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| { - with_thread_local_cuda_streams(|streams| { - lhs.par_iter() - .zip(rhs.par_iter()) - .map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams)) - .collect::>() - }) + let streams = &cuda_key.streams; + lhs.par_iter() + .zip(rhs.par_iter()) + .map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams)) + .collect::>() })) } @@ -170,12 +169,11 @@ where F: Send + Sync + Fn(&crate::integer::gpu::CudaServerKey, &T, Clear, &CudaStreams) -> T, { GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| { - with_thread_local_cuda_streams(|streams| { - lhs.par_iter() - .zip(rhs.par_iter()) - .map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams)) - .collect::>() - }) + let streams = &cuda_key.streams; + lhs.par_iter() + .zip(rhs.par_iter()) + .map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams)) + .collect::>() })) } @@ -336,11 +334,10 @@ where fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned { GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| { - with_thread_local_cuda_streams(|streams| { - lhs.par_iter() - .map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams)) - .collect::>() - }) + let streams = &cuda_key.streams; + lhs.par_iter() + .map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams)) + .collect::>() })) } } diff --git a/tfhe/src/high_level_api/booleans/base.rs b/tfhe/src/high_level_api/booleans/base.rs index db05aac24..2cba6f610 100644 --- a/tfhe/src/high_level_api/booleans/base.rs +++ b/tfhe/src/high_level_api/booleans/base.rs @@ -3,8 +3,6 @@ use crate::backward_compatibility::booleans::FheBoolVersions; use crate::conformance::ParameterSetConformant; use crate::core_crypto::prelude::{SignedNumeric, UnsignedNumeric}; use crate::high_level_api::global_state; -#[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::{FheInt, FheIntId, FheUint, FheUintId}; use crate::high_level_api::keys::InternalServerKey; use crate::high_level_api::traits::{FheEq, IfThenElse, ScalarIfThenElse, Tagged}; @@ -385,7 +383,8 @@ impl ScalarIfThenElse<&Self, &Self> for FheBool { (InnerBoolean::Cpu(new_ct), key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner = cuda_key.key.key.if_then_else( &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)), &*ct_then.ciphertext.on_gpu(streams), @@ -394,7 +393,7 @@ impl ScalarIfThenElse<&Self, &Self> for FheBool { ); let boolean_inner = 
CudaBooleanBlock(inner); (InnerBoolean::Cuda(boolean_inner), cuda_key.tag.clone()) - }), + } }); Self::new(ciphertext, tag) } @@ -422,7 +421,8 @@ where FheUint::new(inner, cpu_sks.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner = cuda_key.key.key.if_then_else( &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)), &*ct_then.ciphertext.on_gpu(streams), @@ -431,7 +431,7 @@ where ); FheUint::new(inner, cuda_key.tag.clone()) - }), + } }) } } @@ -455,7 +455,8 @@ impl IfThenElse> for FheBool { FheInt::new(new_ct, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner = cuda_key.key.key.if_then_else( &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)), &*ct_then.ciphertext.on_gpu(streams), @@ -464,7 +465,7 @@ impl IfThenElse> for FheBool { ); FheInt::new(inner, cuda_key.tag.clone()) - }), + } }) } } @@ -482,7 +483,8 @@ impl IfThenElse for FheBool { (InnerBoolean::Cpu(new_ct), key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner = cuda_key.key.key.if_then_else( &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)), &*ct_then.ciphertext.on_gpu(streams), @@ -491,7 +493,7 @@ impl IfThenElse for FheBool { ); let boolean_inner = CudaBooleanBlock(inner); (InnerBoolean::Cuda(boolean_inner), cuda_key.tag.clone()) - }), + } }); Self::new(ciphertext, tag) } @@ -541,7 +543,8 @@ where Self::new(ciphertext, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner = cuda_key.key.key.eq( &*self.ciphertext.on_gpu(streams), &other.borrow().ciphertext.on_gpu(streams), @@ -549,7 +552,7 @@ where ); let ciphertext = InnerBoolean::Cuda(inner); Self::new(ciphertext, cuda_key.tag.clone()) - }), + } }) } @@ -583,7 +586,8 @@ where Self::new(ciphertext, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner = cuda_key.key.key.ne( &*self.ciphertext.on_gpu(streams), &other.borrow().ciphertext.on_gpu(streams), @@ -591,7 +595,7 @@ where ); let ciphertext = InnerBoolean::Cuda(inner); Self::new(ciphertext, cuda_key.tag.clone()) - }), + } }) } } @@ -628,14 +632,15 @@ impl FheEq for FheBool { ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner = cuda_key.key.key.scalar_eq( &*self.ciphertext.on_gpu(streams), u8::from(other), streams, ); (InnerBoolean::Cuda(inner), cuda_key.tag.clone()) - }), + } }); Self::new(ciphertext, tag) } @@ -671,14 +676,15 @@ impl FheEq for FheBool { ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner = cuda_key.key.key.scalar_ne( &*self.ciphertext.on_gpu(streams), u8::from(other), streams, ); 
(InnerBoolean::Cuda(inner), cuda_key.tag.clone()) - }), + } }); Self::new(ciphertext, tag) } @@ -745,7 +751,8 @@ where (InnerBoolean::Cpu(inner_ct), key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_ct = cuda_key.key.key.bitand( &*self.ciphertext.on_gpu(streams), &rhs.borrow().ciphertext.on_gpu(streams), @@ -758,7 +765,7 @@ where )), cuda_key.tag.clone(), ) - }), + } }); FheBool::new(ciphertext, tag) } @@ -830,7 +837,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_ct = cuda_key.key.key.bitor( &*self.ciphertext.on_gpu(streams), &rhs.borrow().ciphertext.on_gpu(streams), @@ -842,7 +850,7 @@ where )), cuda_key.tag.clone(), ) - }), + } }); FheBool::new(ciphertext, tag) } @@ -914,7 +922,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_ct = cuda_key.key.key.bitxor( &*self.ciphertext.on_gpu(streams), &rhs.borrow().ciphertext.on_gpu(streams), @@ -926,7 +935,7 @@ where )), cuda_key.tag.clone(), ) - }), + } }); FheBool::new(ciphertext, tag) } @@ -990,7 +999,8 @@ impl BitAnd for &FheBool { ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_ct = cuda_key.key.key.scalar_bitand( &*self.ciphertext.on_gpu(streams), u8::from(rhs), @@ -1002,7 +1012,7 @@ impl BitAnd for &FheBool { )), cuda_key.tag.clone(), ) - }), + } }); FheBool::new(ciphertext, tag) } @@ -1066,7 +1076,8 @@ impl BitOr for &FheBool { ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_ct = cuda_key.key.key.scalar_bitor( &*self.ciphertext.on_gpu(streams), u8::from(rhs), @@ -1078,7 +1089,7 @@ impl BitOr for &FheBool { )), cuda_key.tag.clone(), ) - }), + } }); FheBool::new(ciphertext, tag) } @@ -1142,7 +1153,8 @@ impl BitXor for &FheBool { ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_ct = cuda_key.key.key.scalar_bitxor( &*self.ciphertext.on_gpu(streams), u8::from(rhs), @@ -1154,7 +1166,7 @@ impl BitXor for &FheBool { )), cuda_key.tag.clone(), ) - }), + } }); FheBool::new(ciphertext, tag) } @@ -1346,13 +1358,14 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.bitand_assign( self.ciphertext.as_gpu_mut(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); - }), + } }); } } @@ -1389,13 +1402,14 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.bitor_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }), + } }); } } @@ -1432,13 +1446,14 @@ where ); } #[cfg(feature = "gpu")] - 
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.bitxor_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }), + } }); } } @@ -1469,13 +1484,14 @@ impl BitAndAssign for FheBool { .scalar_bitand_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs)); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.scalar_bitand_assign( self.ciphertext.as_gpu_mut(streams), u8::from(rhs), streams, ); - }), + } }); } } @@ -1506,13 +1522,14 @@ impl BitOrAssign for FheBool { .scalar_bitor_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs)); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.scalar_bitor_assign( self.ciphertext.as_gpu_mut(streams), u8::from(rhs), streams, ); - }), + } }); } } @@ -1543,13 +1560,14 @@ impl BitXorAssign for FheBool { .scalar_bitxor_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs)); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.scalar_bitxor_assign( self.ciphertext.as_gpu_mut(streams), u8::from(rhs), streams, ); - }), + } }); } } @@ -1606,7 +1624,8 @@ impl std::ops::Not for &FheBool { (InnerBoolean::Cpu(inner), key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner = cuda_key .key @@ -1618,7 +1637,7 @@ impl std::ops::Not for &FheBool { )), cuda_key.tag.clone(), ) - }), + } }); FheBool::new(ciphertext, tag) } diff --git a/tfhe/src/high_level_api/booleans/encrypt.rs b/tfhe/src/high_level_api/booleans/encrypt.rs index c339a2b8d..56ef8b29b 100644 --- a/tfhe/src/high_level_api/booleans/encrypt.rs +++ b/tfhe/src/high_level_api/booleans/encrypt.rs @@ -1,8 +1,6 @@ use super::base::FheBool; use crate::high_level_api::booleans::inner::InnerBoolean; use crate::high_level_api::global_state; -#[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::keys::InternalServerKey; #[cfg(feature = "gpu")] use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; @@ -90,7 +88,8 @@ impl FheTryTrivialEncrypt for FheBool { (ct, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner: CudaUnsignedRadixCiphertext = cuda_key .key @@ -100,7 +99,7 @@ impl FheTryTrivialEncrypt for FheBool { inner.into_inner(), )); (ct, cuda_key.tag.clone()) - }), + } }); Ok(Self::new(ciphertext, tag)) } diff --git a/tfhe/src/high_level_api/booleans/oprf.rs b/tfhe/src/high_level_api/booleans/oprf.rs index 270ef554a..ebca57790 100644 --- a/tfhe/src/high_level_api/booleans/oprf.rs +++ b/tfhe/src/high_level_api/booleans/oprf.rs @@ -1,7 +1,5 @@ use super::{FheBool, InnerBoolean}; use crate::high_level_api::global_state; -#[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; use 
crate::high_level_api::keys::InternalServerKey; #[cfg(feature = "gpu")] use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; @@ -41,7 +39,8 @@ impl FheBool { ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let d_ct: CudaUnsignedRadixCiphertext = cuda_key .key .key @@ -52,7 +51,7 @@ impl FheBool { )), cuda_key.tag.clone(), ) - }), + } }); Self::new(ciphertext, tag) } diff --git a/tfhe/src/high_level_api/compressed_ciphertext_list.rs b/tfhe/src/high_level_api/compressed_ciphertext_list.rs index 45849923c..cfc37216c 100644 --- a/tfhe/src/high_level_api/compressed_ciphertext_list.rs +++ b/tfhe/src/high_level_api/compressed_ciphertext_list.rs @@ -172,17 +172,16 @@ impl CompressedCiphertextListBuilder { for (element, _) in &self.inner { match element { ToBeCompressed::Cpu(cpu_blocks) => { - with_thread_local_cuda_streams(|streams| { - cuda_radixes.push(CudaRadixCiphertext::from_cpu_blocks( - cpu_blocks, streams, - )); - }) + let streams = &cuda_key.streams; + cuda_radixes + .push(CudaRadixCiphertext::from_cpu_blocks(cpu_blocks, streams)); } #[cfg(feature = "gpu")] ToBeCompressed::Cuda(cuda_radix) => { - with_thread_local_cuda_streams(|streams| { + { + let streams = &cuda_key.streams; cuda_radixes.push(cuda_radix.duplicate(streams)); - }); + }; } } } @@ -195,10 +194,11 @@ impl CompressedCiphertextListBuilder { crate::Error::new("Compression key not set in server key".to_owned()) }) .map(|compression_key| { - let packed_list = with_thread_local_cuda_streams(|streams| { + let packed_list = { + let streams = &cuda_key.streams; compression_key .compress_ciphertexts_into_list(cuda_radixes.as_slice(), streams) - }); + }; let info = self.inner.iter().map(|(_, kind)| *kind).collect(); let compressed_list = CudaCompressedCiphertextList { packed_list, info }; @@ -458,11 +458,12 @@ impl CiphertextList for CompressedCiphertextList { crate::Error::new("Compression key not set in server key".to_owned()) }) .and_then(|decompression_key| { - let mut ct = with_thread_local_cuda_streams(|streams| { + let mut ct = { + let streams = &cuda_key.streams; self.inner .on_gpu(streams) .get::(index, decompression_key, streams) - }); + }; if let Ok(Some(ct_ref)) = &mut ct { ct_ref.tag_mut().set_data(cuda_key.tag.data()) } diff --git a/tfhe/src/high_level_api/global_state.rs b/tfhe/src/high_level_api/global_state.rs index f4f3e9a3d..52f3b1896 100644 --- a/tfhe/src/high_level_api/global_state.rs +++ b/tfhe/src/high_level_api/global_state.rs @@ -201,12 +201,15 @@ pub(in crate::high_level_api) use gpu::{ #[cfg(feature = "gpu")] pub use gpu::CudaGpuChoice; +#[cfg(feature = "gpu")] +pub struct CustomMultiGpuIndexes(Vec); #[cfg(feature = "gpu")] mod gpu { use crate::core_crypto::gpu::get_number_of_gpus; use super::*; + use itertools::Itertools; use std::cell::LazyCell; thread_local! 
{ @@ -223,14 +226,14 @@ mod gpu { } struct CudaStreamPool { - multi: LazyCell, + custom: Option, single: Vec CudaStreams>>>, } impl CudaStreamPool { fn new() -> Self { Self { - multi: LazyCell::new(CudaStreams::new_multi_gpu), + custom: None, single: (0..get_number_of_gpus()) .map(|index| { let ctor = @@ -242,29 +245,6 @@ mod gpu { } } - impl<'a> std::ops::Index<&'a [GpuIndex]> for CudaStreamPool { - type Output = CudaStreams; - - fn index(&self, indexes: &'a [GpuIndex]) -> &Self::Output { - match indexes.len() { - 0 => panic!("Internal error: Gpu indexes must not be empty"), - 1 => &self.single[indexes[0].get() as usize], - _ => &self.multi, - } - } - } - - impl std::ops::Index for CudaStreamPool { - type Output = CudaStreams; - - fn index(&self, choice: CudaGpuChoice) -> &Self::Output { - match choice { - CudaGpuChoice::Multi => &self.multi, - CudaGpuChoice::Single(index) => &self.single[index.get() as usize], - } - } - } - pub(in crate::high_level_api) fn with_thread_local_cuda_streams_for_gpu_indexes< R, F: for<'a> FnOnce(&'a CudaStreams) -> R, @@ -275,15 +255,44 @@ mod gpu { thread_local! { static POOL: RefCell = RefCell::new(CudaStreamPool::new()); } - POOL.with_borrow(|stream_pool| { - let stream = &stream_pool[gpu_indexes]; - func(stream) - }) + + if gpu_indexes.len() == 1 { + POOL.with_borrow(|pool| func(&pool.single[gpu_indexes[0].get() as usize])) + } else { + POOL.with_borrow_mut(|pool| match &pool.custom { + Some(streams) if streams.gpu_indexes != gpu_indexes => { + pool.custom = Some(CudaStreams::new_multi_gpu_with_indexes(gpu_indexes)); + } + None => { + pool.custom = Some(CudaStreams::new_multi_gpu_with_indexes(gpu_indexes)); + } + _ => {} + }); + + POOL.with_borrow(|pool| func(pool.custom.as_ref().unwrap())) + } } - #[derive(Copy, Clone)] + + impl Clone for CustomMultiGpuIndexes { + fn clone(&self) -> Self { + self.0.iter().copied().collect_vec().into() + } + } + + impl CustomMultiGpuIndexes { + pub fn new(indexes: Vec) -> Self { + Self(indexes) + } + pub fn gpu_indexes(&self) -> &[GpuIndex] { + self.0.as_slice() + } + } + + #[derive(Clone)] pub enum CudaGpuChoice { Single(GpuIndex), Multi, + Custom(CustomMultiGpuIndexes), } impl From for CudaGpuChoice { @@ -292,11 +301,24 @@ mod gpu { } } + impl From> for CustomMultiGpuIndexes { + fn from(value: Vec) -> Self { + Self(value) + } + } + + impl From for CudaGpuChoice { + fn from(values: CustomMultiGpuIndexes) -> Self { + Self::Custom(values) + } + } + impl CudaGpuChoice { pub(in crate::high_level_api) fn build_streams(self) -> CudaStreams { match self { Self::Single(idx) => CudaStreams::new_single_gpu(idx), Self::Multi => CudaStreams::new_multi_gpu(), + Self::Custom(idxs) => CudaStreams::new_multi_gpu_with_indexes(idxs.gpu_indexes()), } } } diff --git a/tfhe/src/high_level_api/integers/oprf.rs b/tfhe/src/high_level_api/integers/oprf.rs index 1c40f5e70..4280a5994 100644 --- a/tfhe/src/high_level_api/integers/oprf.rs +++ b/tfhe/src/high_level_api/integers/oprf.rs @@ -1,7 +1,5 @@ use super::{FheIntId, FheUint, FheUintId}; use crate::high_level_api::global_state; -#[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::keys::InternalServerKey; #[cfg(feature = "gpu")] use crate::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext}; @@ -38,7 +36,8 @@ impl FheUint { Self::new(ct, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + 
InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let d_ct: CudaUnsignedRadixCiphertext = cuda_key .key .key @@ -49,7 +48,7 @@ impl FheUint { ); Self::new(d_ct, cuda_key.tag.clone()) - }), + } }) } /// Generates an encrypted `num_block` blocks unsigned integer @@ -87,7 +86,8 @@ impl FheUint { Self::new(ct, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let d_ct: CudaUnsignedRadixCiphertext = cuda_key .key .key @@ -98,7 +98,7 @@ impl FheUint { streams, ); Self::new(d_ct, cuda_key.tag.clone()) - }), + } }) } } @@ -136,7 +136,8 @@ impl FheInt { Self::new(ct, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let d_ct: CudaSignedRadixCiphertext = cuda_key .key .key @@ -147,7 +148,7 @@ impl FheInt { ); Self::new(d_ct, cuda_key.tag.clone()) - }), + } }) } @@ -187,7 +188,8 @@ impl FheInt { Self::new(ct, key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let d_ct: CudaSignedRadixCiphertext = cuda_key .key .key @@ -198,7 +200,7 @@ impl FheInt { streams, ); Self::new(d_ct, cuda_key.tag.clone()) - }), + } }) } } diff --git a/tfhe/src/high_level_api/integers/signed/base.rs b/tfhe/src/high_level_api/integers/signed/base.rs index 675b19e94..5f28ed35f 100644 --- a/tfhe/src/high_level_api/integers/signed/base.rs +++ b/tfhe/src/high_level_api/integers/signed/base.rs @@ -17,8 +17,6 @@ use crate::shortint::PBSParameters; use crate::{Device, FheBool, ServerKey, Tag}; use std::marker::PhantomData; -#[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; pub trait FheIntId: IntegerId {} /// A Generic FHE signed integer @@ -197,13 +195,14 @@ where Self::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key .abs(&*self.ciphertext.on_gpu(streams), streams); Self::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -233,13 +232,14 @@ where FheBool::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key .is_even(&*self.ciphertext.on_gpu(streams), streams); FheBool::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -269,13 +269,14 @@ where FheBool::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key .is_odd(&*self.ciphertext.on_gpu(streams), streams); FheBool::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -309,7 +310,8 @@ where crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key @@ -320,7 +322,7 @@ where 
streams, ); crate::FheUint32::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -354,7 +356,8 @@ where crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key @@ -365,7 +368,7 @@ where streams, ); crate::FheUint32::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -399,7 +402,8 @@ where crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key @@ -410,7 +414,7 @@ where streams, ); crate::FheUint32::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -444,7 +448,8 @@ where crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key @@ -455,7 +460,7 @@ where streams, ); crate::FheUint32::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -565,7 +570,8 @@ where crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key @@ -576,7 +582,7 @@ where streams, ); crate::FheUint32::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -619,7 +625,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let (result, is_ok) = cuda_key .key .key @@ -633,7 +640,7 @@ where crate::FheUint32::new(result, cuda_key.tag.clone()), FheBool::new(is_ok, cuda_key.tag.clone()), ) - }), + } }) } @@ -807,7 +814,8 @@ where Self::new(new_ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let target_num_blocks = IntoId::num_blocks(cuda_key.message_modulus()); let new_ciphertext = cuda_key.key.key.cast_to_signed( input.ciphertext.into_gpu(streams), @@ -815,7 +823,7 @@ where streams, ); Self::new(new_ciphertext, cuda_key.tag.clone()) - }), + } }) } } @@ -852,14 +860,15 @@ where Self::new(new_ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let new_ciphertext = cuda_key.key.key.cast_to_signed( input.ciphertext.into_gpu(streams), IntoId::num_blocks(cuda_key.message_modulus()), streams, ); Self::new(new_ciphertext, cuda_key.tag.clone()) - }), + } }) } } @@ -899,14 +908,15 @@ where Self::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner = cuda_key.key.key.cast_to_signed( input.ciphertext.into_gpu(streams).0, Id::num_blocks(cuda_key.message_modulus()), streams, ); Self::new(inner, cuda_key.tag.clone()) - }), + } }) } } diff --git 
a/tfhe/src/high_level_api/integers/signed/encrypt.rs b/tfhe/src/high_level_api/integers/signed/encrypt.rs index d4e284fd5..2f2a3fc86 100644 --- a/tfhe/src/high_level_api/integers/signed/encrypt.rs +++ b/tfhe/src/high_level_api/integers/signed/encrypt.rs @@ -1,7 +1,5 @@ use crate::core_crypto::prelude::SignedNumeric; use crate::high_level_api::global_state; -#[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::FheIntId; use crate::high_level_api::keys::InternalServerKey; use crate::integer::block_decomposition::{DecomposableInto, RecomposableSignedInteger}; @@ -113,14 +111,15 @@ where Ok(Self::new(ciphertext, key.tag.clone())) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner: CudaSignedRadixCiphertext = cuda_key.key.key.create_trivial_radix( value, Id::num_blocks(cuda_key.key.key.message_modulus), streams, ); Ok(Self::new(inner, cuda_key.tag.clone())) - }), + } }) } } diff --git a/tfhe/src/high_level_api/integers/signed/ops.rs b/tfhe/src/high_level_api/integers/signed/ops.rs index ebac40a4d..07497d6b1 100644 --- a/tfhe/src/high_level_api/integers/signed/ops.rs +++ b/tfhe/src/high_level_api/integers/signed/ops.rs @@ -18,9 +18,6 @@ use std::ops::{ Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, }; -#[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; - impl<'a, Id> std::iter::Sum<&'a Self> for FheInt where Id: FheIntId, @@ -74,7 +71,8 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + { + let streams = &cuda_key.streams; let cts = iter .map(|fhe_uint| { match fhe_uint.ciphertext.on_gpu(streams) { @@ -104,7 +102,7 @@ where ) }); Self::new(inner, cuda_key.tag.clone()) - }) + } } }) } @@ -144,14 +142,15 @@ where Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.max( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); Self::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -190,14 +189,15 @@ where Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.min( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); Self::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -247,14 +247,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.eq( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } @@ -286,14 +287,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let 
streams = &cuda_key.streams; let inner_result = cuda_key.key.key.ne( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -351,14 +353,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.lt( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } @@ -390,14 +393,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.le( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } @@ -429,14 +433,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.gt( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } @@ -468,14 +473,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.ge( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -550,7 +556,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let (q, r) = cuda_key.key.key.div_rem( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), @@ -560,7 +567,7 @@ where FheInt::::new(q, cuda_key.tag.clone()), FheInt::::new(r, cuda_key.tag.clone()), ) - }), + } }) } } @@ -634,11 +641,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + {let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .add(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }) + } } }) } @@ -677,11 +684,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + {let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .sub(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }) + } } }) } @@ -720,11 +727,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + {let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .mul(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); 
FheInt::new(inner_result, cuda_key.tag.clone()) - }) + } } }) } @@ -761,11 +768,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + {let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .bitand(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }) + } } }) } @@ -802,11 +809,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + {let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .bitor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }) + } } }) } @@ -843,11 +850,11 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + {let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .bitxor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }) + } } }) } @@ -891,14 +898,14 @@ generic_integer_impl_operation!( FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .div(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }), + }, }) } }, @@ -942,14 +949,14 @@ generic_integer_impl_operation!( FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .rem(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }), + }, }) } }, @@ -1057,11 +1064,11 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + {let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .left_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }) + } } } }) @@ -1101,11 +1108,11 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + {let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .right_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }) + } } } }) @@ -1145,11 +1152,11 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + {let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .rotate_left(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }) + } } } }) @@ -1189,11 +1196,11 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + {let streams = 
&cuda_key.streams; let inner_result = cuda_key.key.key .rotate_right(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }) + } } } }) @@ -1239,13 +1246,12 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - cuda_key.key.key.add_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); - }) + let streams = &cuda_key.streams; + cuda_key.key.key.add_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); } }) } @@ -1286,13 +1292,12 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - cuda_key.key.key.sub_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); - }) + let streams = &cuda_key.streams; + cuda_key.key.key.sub_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); } }) } @@ -1333,13 +1338,12 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - cuda_key.key.key.mul_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); - }) + let streams = &cuda_key.streams; + cuda_key.key.key.mul_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); } }) } @@ -1378,13 +1382,12 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - cuda_key.key.key.bitand_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); - }) + let streams = &cuda_key.streams; + cuda_key.key.key.bitand_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); } }) } @@ -1423,13 +1426,12 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - cuda_key.key.key.bitor_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); - }) + let streams = &cuda_key.streams; + cuda_key.key.key.bitor_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); } }) } @@ -1468,13 +1470,12 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - cuda_key.key.key.bitxor_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); - }) + let streams = &cuda_key.streams; + cuda_key.key.key.bitxor_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); } }) } @@ -1518,7 +1519,8 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + { + let streams = &cuda_key.streams; let cuda_lhs = self.ciphertext.as_gpu_mut(streams); let cuda_result = cuda_key.pbs_key().div( &*cuda_lhs, @@ -1526,7 +1528,7 @@ where streams, ); *cuda_lhs = cuda_result; - }); + }; } }) } @@ -1570,15 +1572,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { - let cuda_lhs = 
self.ciphertext.as_gpu_mut(streams); - let cuda_result = cuda_key.pbs_key().rem( - &*cuda_lhs, - &rhs.ciphertext.on_gpu(streams), - streams, - ); - *cuda_lhs = cuda_result; - }); + let streams = &cuda_key.streams; + let cuda_lhs = self.ciphertext.as_gpu_mut(streams); + let cuda_result = + cuda_key + .pbs_key() + .rem(&*cuda_lhs, &rhs.ciphertext.on_gpu(streams), streams); + *cuda_lhs = cuda_result; } }) } @@ -1627,13 +1627,12 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { - cuda_key.key.key.left_shift_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); - }); + let streams = &cuda_key.streams; + cuda_key.key.key.left_shift_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); } }) } @@ -1681,13 +1680,12 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { - cuda_key.key.key.right_shift_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); - }); + let streams = &cuda_key.streams; + cuda_key.key.key.right_shift_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); } }) } @@ -1736,13 +1734,12 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { - cuda_key.key.key.rotate_left_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); - }); + let streams = &cuda_key.streams; + cuda_key.key.key.rotate_left_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); } }) } @@ -1791,13 +1788,12 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { - cuda_key.key.key.rotate_right_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); - }); + let streams = &cuda_key.streams; + cuda_key.key.key.rotate_right_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); } }) } @@ -1863,13 +1859,14 @@ where FheInt::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .neg(&*self.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -1932,13 +1929,14 @@ where FheInt::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .bitnot(&*self.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -1956,13 +1954,14 @@ where InternalServerKey::Cpu(_) => { tmp_buffer_size = 0; } - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; tmp_buffer_size = cuda_key.key.key.get_add_assign_size_on_gpu( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }), + } }); tmp_buffer_size } @@ -1975,12 +1974,13 @@ where { fn get_size_on_gpu(&self) -> u64 { global_state::with_internal_keys(|key| match key { 
- InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key .key .key .get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams)) - }), + } InternalServerKey::Cpu(_) => 0, }) } diff --git a/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs b/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs index 840d1c443..01c2b34ce 100644 --- a/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs +++ b/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs @@ -1,7 +1,5 @@ use crate::core_crypto::prelude::SignedNumeric; use crate::high_level_api::global_state; -#[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::FheIntId; use crate::high_level_api::keys::InternalServerKey; use crate::integer::block_decomposition::DecomposableInto; @@ -53,7 +51,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let (result, overflow) = cuda_key.key.key.signed_overflowing_add( &self.ciphertext.on_gpu(streams), &other.ciphertext.on_gpu(streams), @@ -63,7 +62,7 @@ where FheInt::new(result, cuda_key.tag.clone()), FheBool::new(overflow, cuda_key.tag.clone()), ) - }), + } }) } } @@ -149,7 +148,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let (result, overflow) = cuda_key.key.key.signed_overflowing_scalar_add( &self.ciphertext.on_gpu(streams), other, @@ -159,7 +159,7 @@ where FheInt::new(result, cuda_key.tag.clone()), FheBool::new(overflow, cuda_key.tag.clone()), ) - }), + } }) } } @@ -283,7 +283,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let (result, overflow) = cuda_key.key.key.signed_overflowing_sub( &self.ciphertext.on_gpu(streams), &other.ciphertext.on_gpu(streams), @@ -293,7 +294,7 @@ where FheInt::new(result, cuda_key.tag.clone()), FheBool::new(overflow, cuda_key.tag.clone()), ) - }), + } }) } } @@ -378,7 +379,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let (result, overflow) = cuda_key.key.key.signed_overflowing_scalar_sub( &self.ciphertext.on_gpu(streams), other, @@ -388,7 +390,7 @@ where FheInt::new(result, cuda_key.tag.clone()), FheBool::new(overflow, cuda_key.tag.clone()), ) - }), + } }) } } diff --git a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs index 6dc5aa41c..311e888a3 100644 --- a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs @@ -2,8 +2,6 @@ use crate::core_crypto::commons::numeric::CastFrom; use crate::high_level_api::errors::UnwrapResultExt; use crate::high_level_api::global_state; -#[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::signed::inner::SignedRadixCiphertext; use crate::high_level_api::integers::FheIntId; use crate::high_level_api::keys::InternalServerKey; @@ -55,14 +53,13 @@ 
where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - let inner_result = cuda_key.key.key.scalar_max( - &*self.ciphertext.on_gpu(streams), - rhs, - streams, - ); - Self::new(inner_result, cuda_key.tag.clone()) - }) + let streams = &cuda_key.streams; + let inner_result = + cuda_key + .key + .key + .scalar_max(&*self.ciphertext.on_gpu(streams), rhs, streams); + Self::new(inner_result, cuda_key.tag.clone()) } }) } @@ -103,14 +100,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - let inner_result = cuda_key.key.key.scalar_min( - &*self.ciphertext.on_gpu(streams), - rhs, - streams, - ); - Self::new(inner_result, cuda_key.tag.clone()) - }) + let streams = &cuda_key.streams; + let inner_result = + cuda_key + .key + .key + .scalar_min(&*self.ciphertext.on_gpu(streams), rhs, streams); + Self::new(inner_result, cuda_key.tag.clone()) } }) } @@ -150,14 +146,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - let inner_result = - cuda_key - .key - .key - .scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams); - FheBool::new(inner_result, cuda_key.tag.clone()) - }) + let streams = &cuda_key.streams; + let inner_result = + cuda_key + .key + .key + .scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams); + FheBool::new(inner_result, cuda_key.tag.clone()) } }) } @@ -191,14 +186,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - let inner_result = - cuda_key - .key - .key - .scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams); - FheBool::new(inner_result, cuda_key.tag.clone()) - }) + let streams = &cuda_key.streams; + let inner_result = + cuda_key + .key + .key + .scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams); + FheBool::new(inner_result, cuda_key.tag.clone()) } }) } @@ -237,14 +231,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - let inner_result = - cuda_key - .key - .key - .scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams); - FheBool::new(inner_result, cuda_key.tag.clone()) - }) + let streams = &cuda_key.streams; + let inner_result = + cuda_key + .key + .key + .scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams); + FheBool::new(inner_result, cuda_key.tag.clone()) } }) } @@ -277,14 +270,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - let inner_result = - cuda_key - .key - .key - .scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams); - FheBool::new(inner_result, cuda_key.tag.clone()) - }) + let streams = &cuda_key.streams; + let inner_result = + cuda_key + .key + .key + .scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams); + FheBool::new(inner_result, cuda_key.tag.clone()) } }) } @@ -317,14 +309,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - let inner_result = - cuda_key - .key - .key - .scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams); - FheBool::new(inner_result, 
cuda_key.tag.clone()) - }) + let streams = &cuda_key.streams; + let inner_result = + cuda_key + .key + .key + .scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams); + FheBool::new(inner_result, cuda_key.tag.clone()) } }) } @@ -357,14 +348,13 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - let inner_result = - cuda_key - .key - .key - .scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams); - FheBool::new(inner_result, cuda_key.tag.clone()) - }) + let streams = &cuda_key.streams; + let inner_result = + cuda_key + .key + .key + .scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams); + FheBool::new(inner_result, cuda_key.tag.clone()) } }) } @@ -408,11 +398,11 @@ macro_rules! generic_integer_impl_scalar_div_rem { } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| { + let (inner_q, inner_r) = {let streams = &cuda_key.streams; cuda_key.key.key.signed_scalar_div_rem( &*self.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; let (q, r) = ( SignedRadixCiphertext::Cuda(inner_q), SignedRadixCiphertext::Cuda(inner_r), @@ -456,11 +446,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_left_shift( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }); + }; SignedRadixCiphertext::Cuda(inner_result) } }) @@ -485,11 +475,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_right_shift( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }); + }; SignedRadixCiphertext::Cuda(inner_result) } }) @@ -514,11 +504,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_rotate_left( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }); + }; SignedRadixCiphertext::Cuda(inner_result) } }) @@ -543,11 +533,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_rotate_right( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }); + }; SignedRadixCiphertext::Cuda(inner_result) } }) @@ -570,11 +560,10 @@ macro_rules! define_scalar_rotate_shifts { .scalar_left_shift_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs); }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => + {let streams = &cuda_key.streams; cuda_key.key.key .scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -597,10 +586,9 @@ macro_rules! 
define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -622,11 +610,10 @@ macro_rules! define_scalar_rotate_shifts { .scalar_rotate_left_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs); }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => + {let streams = &cuda_key.streams; cuda_key.key.key .scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -649,10 +636,10 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + {let streams = &cuda_key.streams; cuda_key.key.key .scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) + } } }) } @@ -762,11 +749,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_add( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; SignedRadixCiphertext::Cuda(inner_result) } }) @@ -791,11 +778,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_sub( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; SignedRadixCiphertext::Cuda(inner_result) } }) @@ -820,11 +807,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_mul( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; SignedRadixCiphertext::Cuda(inner_result) } }) @@ -850,11 +837,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_bitand( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; SignedRadixCiphertext::Cuda(inner_result) } }) @@ -879,11 +866,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_bitor( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; SignedRadixCiphertext::Cuda(inner_result) } }) @@ -909,11 +896,11 @@ macro_rules! define_scalar_ops { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_bitxor( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; SignedRadixCiphertext::Cuda(inner_result) } }) @@ -938,11 +925,11 @@ macro_rules! 
define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.signed_scalar_div( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; SignedRadixCiphertext::Cuda(inner_result) } }) @@ -968,11 +955,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.signed_scalar_rem( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; SignedRadixCiphertext::Cuda(inner_result) } }) @@ -1014,12 +1001,11 @@ macro_rules! define_scalar_ops { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { use crate::integer::gpu::ciphertext::CudaSignedRadixCiphertext; - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; let mut result: CudaSignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix( lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams); cuda_key.pbs_key().sub_assign(&mut result, &*rhs.ciphertext.on_gpu(streams), streams); SignedRadixCiphertext::Cuda(result) - }) } }) } @@ -1110,10 +1096,9 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_add_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -1141,10 +1126,9 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_sub_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -1168,10 +1152,9 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_mul_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -1196,10 +1179,9 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -1223,10 +1205,9 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -1249,10 +1230,9 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -1274,11 +1254,11 @@ macro_rules! 
define_scalar_ops { .signed_scalar_div_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs); }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams; let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams); let cuda_result = cuda_key.pbs_key().signed_scalar_div(&cuda_lhs, rhs, streams); *cuda_lhs = cuda_result; - }) + } }) } }, @@ -1299,11 +1279,11 @@ macro_rules! define_scalar_ops { .signed_scalar_rem_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs); }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams; let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams); let cuda_result = cuda_key.pbs_key().signed_scalar_rem(&cuda_lhs, rhs, streams); *cuda_lhs = cuda_result; - }) + } }) } }, diff --git a/tfhe/src/high_level_api/integers/unsigned/base.rs b/tfhe/src/high_level_api/integers/unsigned/base.rs index 972ff5423..00b8f35d6 100644 --- a/tfhe/src/high_level_api/integers/unsigned/base.rs +++ b/tfhe/src/high_level_api/integers/unsigned/base.rs @@ -4,8 +4,6 @@ use super::inner::RadixCiphertext; use crate::backward_compatibility::integers::FheUintVersions; use crate::conformance::ParameterSetConformant; use crate::core_crypto::prelude::{CastFrom, UnsignedInteger, UnsignedNumeric}; -#[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::signed::{FheInt, FheIntId}; use crate::high_level_api::integers::IntegerId; use crate::high_level_api::keys::InternalServerKey; @@ -249,13 +247,14 @@ where FheBool::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key .is_even(&*self.ciphertext.on_gpu(streams), streams); FheBool::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -285,13 +284,14 @@ where FheBool::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key .is_odd(&*self.ciphertext.on_gpu(streams), streams); FheBool::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -418,7 +418,8 @@ where super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key @@ -429,7 +430,7 @@ where streams, ); super::FheUint32::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -463,7 +464,8 @@ where super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key @@ -474,7 +476,7 @@ where streams, ); super::FheUint32::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -508,7 +510,8 @@ where super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let 
streams = &cuda_key.streams; let result = cuda_key .key .key @@ -519,7 +522,7 @@ where streams, ); super::FheUint32::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -553,7 +556,8 @@ where super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key @@ -564,7 +568,7 @@ where streams, ); super::FheUint32::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -674,7 +678,8 @@ where super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key .key .key @@ -685,7 +690,7 @@ where streams, ); super::FheUint32::new(result, cuda_key.tag.clone()) - }), + } }) } @@ -728,7 +733,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let (result, is_ok) = cuda_key .key .key @@ -742,7 +748,7 @@ where super::FheUint32::new(result, cuda_key.tag.clone()), FheBool::new(is_ok, cuda_key.tag.clone()), ) - }), + } }) } @@ -810,7 +816,8 @@ where } } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let (result, matched) = cuda_key.key.key.match_value( &self.ciphertext.on_gpu(streams), matches, @@ -825,7 +832,7 @@ where } else { Err(crate::Error::new("Output type does not have enough bits to represent all possible output values".to_string())) } - }), + } }) } @@ -888,7 +895,8 @@ where } } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let result = cuda_key.key.key.match_value_or( &self.ciphertext.on_gpu(streams), matches, @@ -901,7 +909,7 @@ where } else { Err(crate::Error::new("Output type does not have enough bits to represent all possible output values".to_string())) } - }), + } }) } @@ -1101,14 +1109,15 @@ where Self::new(casted, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let casted = cuda_key.key.key.cast_to_unsigned( input.ciphertext.into_gpu(streams), IntoId::num_blocks(cuda_key.message_modulus()), streams, ); Self::new(casted, cuda_key.tag.clone()) - }), + } }) } } @@ -1145,14 +1154,15 @@ where Self::new(casted, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let casted = cuda_key.key.key.cast_to_unsigned( input.ciphertext.into_gpu(streams), IntoId::num_blocks(cuda_key.message_modulus()), streams, ); Self::new(casted, cuda_key.tag.clone()) - }), + } }) } } @@ -1189,14 +1199,15 @@ where Self::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner = cuda_key.key.key.cast_to_unsigned( input.ciphertext.into_gpu(streams).0, 
Id::num_blocks(cuda_key.message_modulus()), streams, ); Self::new(inner, cuda_key.tag.clone()) - }), + } }) } } diff --git a/tfhe/src/high_level_api/integers/unsigned/encrypt.rs b/tfhe/src/high_level_api/integers/unsigned/encrypt.rs index 593e580e6..135cf8c0a 100644 --- a/tfhe/src/high_level_api/integers/unsigned/encrypt.rs +++ b/tfhe/src/high_level_api/integers/unsigned/encrypt.rs @@ -1,7 +1,5 @@ use crate::core_crypto::prelude::UnsignedNumeric; use crate::high_level_api::global_state; -#[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::FheUintId; use crate::high_level_api::keys::InternalServerKey; use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom}; @@ -115,14 +113,15 @@ where Ok(Self::new(ciphertext, key.tag.clone())) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner: CudaUnsignedRadixCiphertext = cuda_key.key.key.create_trivial_radix( value, Id::num_blocks(cuda_key.key.key.message_modulus), streams, ); Ok(Self::new(inner, cuda_key.tag.clone())) - }), + } }) } } diff --git a/tfhe/src/high_level_api/integers/unsigned/ops.rs b/tfhe/src/high_level_api/integers/unsigned/ops.rs index 1efe9cce8..38d54179d 100644 --- a/tfhe/src/high_level_api/integers/unsigned/ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/ops.rs @@ -5,8 +5,6 @@ use super::inner::RadixCiphertext; #[cfg(feature = "gpu")] use crate::high_level_api::details::MaybeCloned; use crate::high_level_api::global_state; -#[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::FheUintId; use crate::high_level_api::keys::InternalServerKey; #[cfg(feature = "gpu")] @@ -75,7 +73,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let cts = iter .map(|fhe_uint| fhe_uint.ciphertext.into_gpu(streams)) .collect::>(); @@ -92,7 +91,7 @@ where ) }); Self::new(inner, cuda_key.tag.clone()) - }), + } }) } } @@ -154,7 +153,8 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + { + let streams = &cuda_key.streams; let cts = iter .map(|fhe_uint| { match fhe_uint.ciphertext.on_gpu(streams) { @@ -184,7 +184,7 @@ where ) }); Self::new(inner, cuda_key.tag.clone()) - }) + } } }) } @@ -224,14 +224,15 @@ where Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.max( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); Self::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -270,14 +271,15 @@ where Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.min( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams, ); Self::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -327,14 +329,15 @@ where FheBool::new(inner_result, 
cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.eq( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } @@ -366,14 +369,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.ne( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -431,14 +435,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.lt( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } @@ -470,14 +475,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.le( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } @@ -509,14 +515,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.gt( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } @@ -548,14 +555,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.ge( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -631,7 +639,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.div_rem( &*self.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), @@ -641,7 +650,7 @@ where FheUint::::new(inner_result.0, cuda_key.tag.clone()), FheUint::::new(inner_result.1, cuda_key.tag.clone()), ) - }), + } }) } } @@ -717,11 +726,10 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .add(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }) } }) } @@ -760,11 +768,10 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] 
InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .sub(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }) } }) } @@ -803,11 +810,10 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .mul(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }) } }) } @@ -844,11 +850,10 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .bitand(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }) } }) } @@ -885,11 +890,10 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .bitor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }) } }) } @@ -926,11 +930,10 @@ generic_integer_impl_operation!( }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .bitxor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }) } }) } @@ -974,14 +977,15 @@ generic_integer_impl_operation!( FheUint::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .div(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }), + }, }) } }, @@ -1025,14 +1029,16 @@ generic_integer_impl_operation!( FheUint::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => + { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .rem(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }), + }, }) } }, @@ -1140,11 +1146,10 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .left_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }) } } }) @@ -1184,11 +1189,10 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .right_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, 
cuda_key.tag.clone()) - }) } } }) @@ -1228,11 +1232,10 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .rotate_left(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }) } } }) @@ -1272,11 +1275,10 @@ generic_integer_impl_shift_rotate!( } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key .rotate_right(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }) } } }) @@ -1321,13 +1323,14 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.add_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }), + } }) } } @@ -1366,13 +1369,14 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.sub_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }), + } }) } } @@ -1411,13 +1415,14 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.mul_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }), + } }) } } @@ -1454,13 +1459,14 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.bitand_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }), + } }) } } @@ -1497,13 +1503,14 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.bitor_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }), + } }) } } @@ -1540,13 +1547,14 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.bitxor_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }), + } }) } } @@ -1588,13 +1596,14 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.div_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }), + } }) } } @@ -1636,13 +1645,14 @@ where ); } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key.key.key.rem_assign( self.ciphertext.as_gpu_mut(streams), 
&rhs.ciphertext.on_gpu(streams), streams, ); - }), + } }) } } @@ -1690,13 +1700,14 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + { + let streams = &cuda_key.streams; cuda_key.key.key.left_shift_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }); + }; } }) } @@ -1744,13 +1755,14 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + { + let streams = &cuda_key.streams; cuda_key.key.key.right_shift_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }); + }; } }) } @@ -1799,13 +1811,14 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + { + let streams = &cuda_key.streams; cuda_key.key.key.rotate_left_assign( self.ciphertext.as_gpu_mut(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }); + }; } }) } @@ -1854,13 +1867,12 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { - cuda_key.key.key.rotate_right_assign( - self.ciphertext.as_gpu_mut(streams), - &rhs.ciphertext.on_gpu(streams), - streams, - ); - }); + let streams = &cuda_key.streams; + cuda_key.key.key.rotate_right_assign( + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), + streams, + ); } }) } @@ -1934,13 +1946,14 @@ where FheUint::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .neg(&*self.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -2003,13 +2016,14 @@ where FheUint::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .bitnot(&*self.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -2026,13 +2040,14 @@ where InternalServerKey::Cpu(_) => { tmp_buffer_size = 0; } - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; tmp_buffer_size = cuda_key.key.key.get_add_assign_size_on_gpu( &*self.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams, ); - }), + } }); tmp_buffer_size } @@ -2044,12 +2059,13 @@ where { fn get_size_on_gpu(&self) -> u64 { global_state::with_internal_keys(|key| match key { - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; cuda_key .key .key .get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams)) - }), + } InternalServerKey::Cpu(_) => 0, }) } diff --git a/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs b/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs index 9b2e274d4..74e803263 100644 --- a/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs @@ -1,7 +1,5 @@ use crate::core_crypto::prelude::UnsignedNumeric; use crate::high_level_api::global_state; -#[cfg(feature = "gpu")] -use 
crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::FheUintId; use crate::high_level_api::keys::InternalServerKey; use crate::integer::block_decomposition::DecomposableInto; @@ -53,7 +51,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.unsigned_overflowing_add( &self.ciphertext.on_gpu(streams), &other.ciphertext.on_gpu(streams), @@ -63,7 +62,7 @@ where FheUint::::new(inner_result.0, cuda_key.tag.clone()), FheBool::new(inner_result.1, cuda_key.tag.clone()), ) - }), + } }) } } @@ -149,7 +148,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.unsigned_overflowing_scalar_add( &self.ciphertext.on_gpu(streams), other, @@ -159,7 +159,7 @@ where FheUint::::new(inner_result.0, cuda_key.tag.clone()), FheBool::new(inner_result.1, cuda_key.tag.clone()), ) - }), + } }) } } @@ -285,7 +285,8 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key.key.key.unsigned_overflowing_sub( &self.ciphertext.on_gpu(streams), &other.ciphertext.on_gpu(streams), @@ -295,7 +296,7 @@ where FheUint::::new(inner_result.0, cuda_key.tag.clone()), FheBool::new(inner_result.1, cuda_key.tag.clone()), ) - }), + } }) } } diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs index 6f93356ef..0e9150bb8 100644 --- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs @@ -7,8 +7,6 @@ use super::inner::RadixCiphertext; use crate::error::InvalidRangeError; use crate::high_level_api::errors::UnwrapResultExt; use crate::high_level_api::global_state; -#[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::FheUintId; use crate::high_level_api::keys::InternalServerKey; use crate::high_level_api::traits::{ @@ -60,14 +58,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } @@ -99,14 +98,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -144,14 +144,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key 
.scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } @@ -183,14 +184,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } @@ -222,14 +224,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } @@ -261,14 +264,15 @@ where FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -308,14 +312,15 @@ where Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .scalar_max(&*self.ciphertext.on_gpu(streams), rhs, streams); Self::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -355,14 +360,15 @@ where Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => { + let streams = &cuda_key.streams; let inner_result = cuda_key .key .key .scalar_min(&*self.ciphertext.on_gpu(streams), rhs, streams); Self::new(inner_result, cuda_key.tag.clone()) - }), + } }) } } @@ -493,11 +499,11 @@ macro_rules! generic_integer_impl_scalar_div_rem { } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| { + let (inner_q, inner_r) = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_div_rem( &*self.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r)); ( <$concrete_type>::new(q, cuda_key.tag.clone()), @@ -673,11 +679,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_left_shift( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }); + }; RadixCiphertext::Cuda(inner_result) } }) @@ -702,11 +708,11 @@ macro_rules! 
define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_right_shift( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }); + }; RadixCiphertext::Cuda(inner_result) } }) @@ -731,11 +737,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_rotate_left( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }); + }; RadixCiphertext::Cuda(inner_result) } }) @@ -760,11 +766,11 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_rotate_right( &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) - }); + }; RadixCiphertext::Cuda(inner_result) } }) @@ -788,10 +794,10 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + {let streams = &cuda_key.streams; cuda_key.key.key .scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) + } } }) } @@ -814,10 +820,10 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + {let streams = &cuda_key.streams; cuda_key.key.key .scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) + } } }) } @@ -840,10 +846,9 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -866,10 +871,9 @@ macro_rules! define_scalar_rotate_shifts { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -980,11 +984,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_add( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; RadixCiphertext::Cuda(inner_result) } }) @@ -1009,11 +1013,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_sub( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; RadixCiphertext::Cuda(inner_result) } }) @@ -1038,11 +1042,11 @@ macro_rules! 
define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_mul( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; RadixCiphertext::Cuda(inner_result) } }) @@ -1068,11 +1072,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_bitand( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; RadixCiphertext::Cuda(inner_result) } }) @@ -1097,11 +1101,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_bitor( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; RadixCiphertext::Cuda(inner_result) } }) @@ -1127,11 +1131,11 @@ macro_rules! define_scalar_ops { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_bitxor( &*lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; RadixCiphertext::Cuda(inner_result) } }) @@ -1156,11 +1160,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_div( &lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; RadixCiphertext::Cuda(inner_result) } }) @@ -1186,11 +1190,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - let inner_result = with_thread_local_cuda_streams(|streams| { + let inner_result = {let streams = &cuda_key.streams; cuda_key.key.key.scalar_rem( &lhs.ciphertext.on_gpu(streams), rhs, streams ) - }); + }; RadixCiphertext::Cuda(inner_result) } }) @@ -1231,12 +1235,11 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; let mut result: CudaUnsignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix( lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams); cuda_key.pbs_key().sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams); RadixCiphertext::Cuda(result) - }) } }) } @@ -1327,10 +1330,9 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_add_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -1358,10 +1360,9 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_sub_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -1385,10 +1386,9 @@ macro_rules! 
define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_mul_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -1413,10 +1413,9 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -1440,10 +1439,9 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -1466,10 +1464,9 @@ macro_rules! define_scalar_ops { }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { - with_thread_local_cuda_streams(|streams| { + let streams = &cuda_key.streams; cuda_key.key.key .scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); - }) } }) } @@ -1491,11 +1488,11 @@ macro_rules! define_scalar_ops { .scalar_div_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs); }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams; let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams); let cuda_result = cuda_key.pbs_key().scalar_div(&cuda_lhs, rhs, streams); *cuda_lhs = cuda_result; - }) + } }) } }, @@ -1516,11 +1513,11 @@ macro_rules! define_scalar_ops { .scalar_rem_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs); }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| { + InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams; let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams); let cuda_result = cuda_key.pbs_key().scalar_rem(&cuda_lhs, rhs, streams); *cuda_lhs = cuda_result; - }) + } }) } }, diff --git a/tfhe/src/high_level_api/keys/server.rs b/tfhe/src/high_level_api/keys/server.rs index 26bb0703e..8cc8f0dd4 100644 --- a/tfhe/src/high_level_api/keys/server.rs +++ b/tfhe/src/high_level_api/keys/server.rs @@ -328,6 +328,7 @@ impl CompressedServerKey { decompression_key, }), tag: self.tag.clone(), + streams, } } } @@ -351,6 +352,7 @@ impl Named for CompressedServerKey { pub struct CudaServerKey { pub(crate) key: Arc, pub(crate) tag: Tag, + pub(crate) streams: CudaStreams, } #[cfg(feature = "gpu")] diff --git a/tfhe/src/high_level_api/tests/gpu_selection.rs b/tfhe/src/high_level_api/tests/gpu_selection.rs index c1d1e8b49..e22caa1c1 100644 --- a/tfhe/src/high_level_api/tests/gpu_selection.rs +++ b/tfhe/src/high_level_api/tests/gpu_selection.rs @@ -1,6 +1,7 @@ use rand::Rng; use crate::core_crypto::gpu::get_number_of_gpus; +use crate::high_level_api::global_state::CustomMultiGpuIndexes; use crate::prelude::*; use crate::{ set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device, FheUint32, GpuIndex, @@ -117,3 +118,68 @@ fn test_gpu_selection_2() { assert_eq!(c.gpu_indexes(), &[first_gpu]); assert_eq!(decrypted, clear_a.wrapping_add(clear_b)); } + +#[test] +fn test_specific_gpu_selection() { + let config = ConfigBuilder::default().build(); + let keys = ClientKey::generate(config); + let compressed_server_keys = CompressedServerKey::new(&keys); + + let mut rng = 
rand::thread_rng();
+
+    let total_gpus = get_number_of_gpus() as usize;
+    let num_gpus_to_use = rng.gen_range(1..=get_number_of_gpus()) as usize;
+
+    // Randomly sample num_gpus_to_use indices
+    let selected_indices = rand::seq::index::sample(&mut rng, total_gpus, num_gpus_to_use);
+
+    // Convert the selected indices to GpuIndex objects
+    let gpus_to_be_used = CustomMultiGpuIndexes::new(
+        selected_indices
+            .iter()
+            .map(|idx| GpuIndex::new(idx as u32))
+            .collect(),
+    );
+
+    let clear_a: u32 = rng.gen();
+    let clear_b: u32 = rng.gen();
+
+    let mut a = FheUint32::try_encrypt(clear_a, &keys).unwrap();
+    let mut b = FheUint32::try_encrypt(clear_b, &keys).unwrap();
+
+    assert_eq!(a.current_device(), Device::Cpu);
+    assert_eq!(b.current_device(), Device::Cpu);
+    assert_eq!(a.gpu_indexes(), &[]);
+    assert_eq!(b.gpu_indexes(), &[]);
+
+    let cuda_key = compressed_server_keys.decompress_to_specific_gpu(gpus_to_be_used);
+
+    let first_gpu = GpuIndex::new(0);
+
+    set_server_key(cuda_key);
+    let c = &a + &b;
+    let decrypted: u32 = c.decrypt(&keys);
+    assert_eq!(c.current_device(), Device::CudaGpu);
+    assert_eq!(c.gpu_indexes(), &[first_gpu]);
+    assert_eq!(decrypted, clear_a.wrapping_add(clear_b));
+
+    // Check the explicit move, but first make sure the inputs are still on the Cpu
+    assert_eq!(a.current_device(), Device::Cpu);
+    assert_eq!(b.current_device(), Device::Cpu);
+    assert_eq!(a.gpu_indexes(), &[]);
+    assert_eq!(b.gpu_indexes(), &[]);
+
+    a.move_to_current_device();
+    b.move_to_current_device();
+
+    assert_eq!(a.current_device(), Device::CudaGpu);
+    assert_eq!(b.current_device(), Device::CudaGpu);
+    assert_eq!(a.gpu_indexes(), &[first_gpu]);
+    assert_eq!(b.gpu_indexes(), &[first_gpu]);
+
+    let c = &a + &b;
+    let decrypted: u32 = c.decrypt(&keys);
+    assert_eq!(c.current_device(), Device::CudaGpu);
+    assert_eq!(c.gpu_indexes(), &[first_gpu]);
+    assert_eq!(decrypted, clear_a.wrapping_add(clear_b));
+}
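
For readers who want the end-to-end flow without the randomized test scaffolding, the following is a minimal usage sketch of the GPU-selection API introduced by this patch. It assumes a build with the `gpu` feature and at least one visible CUDA device; the function name and the every-other-GPU selection are purely illustrative, and the crate-internal import paths mirror those used in gpu_selection.rs above.

    use rand::Rng;

    use crate::core_crypto::gpu::get_number_of_gpus;
    use crate::high_level_api::global_state::CustomMultiGpuIndexes;
    use crate::prelude::*;
    use crate::{set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, FheUint32, GpuIndex};

    /// Minimal sketch: run an encrypted addition on a caller-chosen subset of GPUs.
    fn add_on_selected_gpus() {
        let config = ConfigBuilder::default().build();
        let client_key = ClientKey::generate(config);
        let compressed_server_key = CompressedServerKey::new(&client_key);

        // Illustrative selection: every other visible GPU (indexes 0, 2, 4, ...).
        let selected = CustomMultiGpuIndexes::new(
            (0..get_number_of_gpus()).step_by(2).map(GpuIndex::new).collect(),
        );

        // Decompress the server key onto the chosen GPUs; the resulting
        // CudaServerKey owns the CudaStreams for exactly those devices.
        set_server_key(compressed_server_key.decompress_to_specific_gpu(selected));

        let mut rng = rand::thread_rng();
        let clear_a: u32 = rng.gen();
        let clear_b: u32 = rng.gen();

        let a = FheUint32::try_encrypt(clear_a, &client_key).unwrap();
        let b = FheUint32::try_encrypt(clear_b, &client_key).unwrap();

        // The ciphertexts are copied to the selected GPUs on first use.
        let c = &a + &b;
        let decrypted: u32 = c.decrypt(&client_key);
        assert_eq!(decrypted, clear_a.wrapping_add(clear_b));
    }

Because the decompressed CudaServerKey now owns the CudaStreams for exactly the selected devices, every high-level operation dispatched while that key is set reuses the same custom stream set, rather than the thread-local streams used previously.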