mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
feat(gpu): revert enables the user to perform computation on multi-gpu using a custom selection of GPUs
This reverts commit 0280dbeb41.
This commit is contained in:
@@ -8,7 +8,7 @@ extern std::mutex m;
|
||||
extern bool p2p_enabled;
|
||||
|
||||
extern "C" {
|
||||
int32_t cuda_setup_multi_gpu(int device_0_id);
|
||||
int32_t cuda_setup_multi_gpu();
|
||||
}
|
||||
|
||||
// Define a variant type that can be either a vector or a single pointer
|
||||
|
||||
@@ -6,8 +6,7 @@
|
||||
std::mutex m;
|
||||
bool p2p_enabled = false;
|
||||
|
||||
// Enable bidirectional p2p access between all available GPUs and device_0_id
|
||||
int32_t cuda_setup_multi_gpu(int device_0_id) {
|
||||
int32_t cuda_setup_multi_gpu() {
|
||||
int num_gpus = cuda_get_number_of_gpus();
|
||||
if (num_gpus == 0)
|
||||
PANIC("GPU error: the number of GPUs should be > 0.")
|
||||
@@ -19,13 +18,11 @@ int32_t cuda_setup_multi_gpu(int device_0_id) {
|
||||
omp_set_nested(1);
|
||||
int has_peer_access_to_device_0;
|
||||
for (int i = 1; i < num_gpus; i++) {
|
||||
check_cuda_error(cudaDeviceCanAccessPeer(&has_peer_access_to_device_0,
|
||||
i, device_0_id));
|
||||
check_cuda_error(
|
||||
cudaDeviceCanAccessPeer(&has_peer_access_to_device_0, i, 0));
|
||||
if (has_peer_access_to_device_0) {
|
||||
cuda_set_device(i);
|
||||
check_cuda_error(cudaDeviceEnablePeerAccess(device_0_id, 0));
|
||||
cuda_set_device(device_0_id);
|
||||
check_cuda_error(cudaDeviceEnablePeerAccess(i, 0));
|
||||
check_cuda_error(cudaDeviceEnablePeerAccess(0, 0));
|
||||
}
|
||||
num_used_gpus += 1;
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ public:
|
||||
number_of_inputs = (int)GetParam().number_of_inputs;
|
||||
|
||||
// Enable Multi-GPU logic
|
||||
gpu_count = cuda_setup_multi_gpu(0);
|
||||
gpu_count = cuda_setup_multi_gpu();
|
||||
active_gpu_count = std::min((uint)number_of_inputs, gpu_count);
|
||||
for (uint gpu_i = 0; gpu_i < active_gpu_count; gpu_i++) {
|
||||
streams.push_back(cuda_create_stream(gpu_i));
|
||||
|
||||
@@ -101,6 +101,6 @@ extern "C" {
|
||||
|
||||
pub fn cuda_drop_async(ptr: *mut c_void, stream: *mut c_void, gpu_index: u32);
|
||||
|
||||
pub fn cuda_setup_multi_gpu(gpu_index: u32) -> i32;
|
||||
pub fn cuda_setup_multi_gpu() -> i32;
|
||||
|
||||
} // extern "C"
|
||||
|
||||
@@ -17,6 +17,7 @@ pub use entities::*;
|
||||
use std::ffi::c_void;
|
||||
use tfhe_cuda_backend::bindings::*;
|
||||
use tfhe_cuda_backend::cuda_bind::*;
|
||||
|
||||
pub struct CudaStreams {
|
||||
pub ptr: Vec<*mut c_void>,
|
||||
pub gpu_indexes: Vec<GpuIndex>,
|
||||
@@ -29,7 +30,7 @@ unsafe impl Sync for CudaStreams {}
|
||||
impl CudaStreams {
|
||||
/// Create a new `CudaStreams` structure with as many GPUs as there are on the machine
|
||||
pub fn new_multi_gpu() -> Self {
|
||||
let gpu_count = setup_multi_gpu(GpuIndex::new(0));
|
||||
let gpu_count = setup_multi_gpu();
|
||||
let mut gpu_indexes = Vec::with_capacity(gpu_count as usize);
|
||||
let mut ptr_array = Vec::with_capacity(gpu_count as usize);
|
||||
|
||||
@@ -42,22 +43,6 @@ impl CudaStreams {
|
||||
gpu_indexes,
|
||||
}
|
||||
}
|
||||
/// Create a new `CudaStreams` structure with the GPUs with id provided in a list
|
||||
pub fn new_multi_gpu_with_indexes(indexes: &[GpuIndex]) -> Self {
|
||||
let _gpu_count = setup_multi_gpu(indexes[0]);
|
||||
|
||||
let mut gpu_indexes = Vec::with_capacity(indexes.len());
|
||||
let mut ptr_array = Vec::with_capacity(indexes.len());
|
||||
|
||||
for &i in indexes {
|
||||
ptr_array.push(unsafe { cuda_create_stream(i.get()) });
|
||||
gpu_indexes.push(i);
|
||||
}
|
||||
Self {
|
||||
ptr: ptr_array,
|
||||
gpu_indexes,
|
||||
}
|
||||
}
|
||||
/// Create a new `CudaStreams` structure with one GPU, whose index corresponds to the one given
|
||||
/// as input
|
||||
pub fn new_single_gpu(gpu_index: GpuIndex) -> Self {
|
||||
@@ -103,14 +88,6 @@ impl CudaStreams {
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for CudaStreams {
|
||||
fn clone(&self) -> Self {
|
||||
// The `new_multi_gpu_with_indexes()` function is used here to adapt to any specific type of
|
||||
// streams being cloned (single, multi, or custom)
|
||||
Self::new_multi_gpu_with_indexes(self.gpu_indexes.as_slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CudaStreams {
|
||||
fn drop(&mut self) {
|
||||
for (i, &s) in self.ptr.iter().enumerate() {
|
||||
@@ -1059,8 +1036,8 @@ pub fn get_number_of_gpus() -> u32 {
|
||||
}
|
||||
|
||||
/// Setup multi-GPU and return the number of GPUs used
|
||||
pub fn setup_multi_gpu(device_0_id: GpuIndex) -> u32 {
|
||||
unsafe { cuda_setup_multi_gpu(device_0_id.get()) as u32 }
|
||||
pub fn setup_multi_gpu() -> u32 {
|
||||
unsafe { cuda_setup_multi_gpu() as u32 }
|
||||
}
|
||||
|
||||
/// Synchronize device
|
||||
|
||||
@@ -156,13 +156,14 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
|
||||
rhs: TensorSlice<'_, Self::Slice<'a>>,
|
||||
) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -171,13 +172,14 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
|
||||
rhs: TensorSlice<'_, Self::Slice<'a>>,
|
||||
) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -186,22 +188,24 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
|
||||
rhs: TensorSlice<'_, Self::Slice<'a>>,
|
||||
) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams)))
|
||||
.collect::<Vec<_>>()
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams)))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
@@ -212,13 +216,16 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
|
||||
rhs: TensorSlice<'_, &'_ [bool]>,
|
||||
) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter().copied())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter().copied())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(
|
||||
cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -227,13 +234,16 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
|
||||
rhs: TensorSlice<'_, &'_ [bool]>,
|
||||
) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter().copied())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter().copied())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(
|
||||
cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -242,13 +252,16 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
|
||||
rhs: TensorSlice<'_, &'_ [bool]>,
|
||||
) -> Self::Owned {
|
||||
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter().copied())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter().copied())
|
||||
.map(|(lhs, rhs)| {
|
||||
CudaBooleanBlock(
|
||||
cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,11 +108,12 @@ where
|
||||
F: Send + Sync + Fn(&crate::integer::gpu::CudaServerKey, &T, &T, &CudaStreams) -> T,
|
||||
{
|
||||
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams))
|
||||
.collect::<Vec<_>>()
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -169,11 +170,12 @@ where
|
||||
F: Send + Sync + Fn(&crate::integer::gpu::CudaServerKey, &T, Clear, &CudaStreams) -> T,
|
||||
{
|
||||
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams))
|
||||
.collect::<Vec<_>>()
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.zip(rhs.par_iter())
|
||||
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -334,10 +336,11 @@ where
|
||||
|
||||
fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned {
|
||||
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
|
||||
let streams = &cuda_key.streams;
|
||||
lhs.par_iter()
|
||||
.map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams))
|
||||
.collect::<Vec<_>>()
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
lhs.par_iter()
|
||||
.map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ use crate::backward_compatibility::booleans::FheBoolVersions;
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
use crate::core_crypto::prelude::{SignedNumeric, UnsignedNumeric};
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::{FheInt, FheIntId, FheUint, FheUintId};
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::high_level_api::traits::{FheEq, IfThenElse, ScalarIfThenElse, Tagged};
|
||||
@@ -383,8 +385,7 @@ impl ScalarIfThenElse<&Self, &Self> for FheBool {
|
||||
(InnerBoolean::Cpu(new_ct), key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner = cuda_key.key.key.if_then_else(
|
||||
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
|
||||
&*ct_then.ciphertext.on_gpu(streams),
|
||||
@@ -393,7 +394,7 @@ impl ScalarIfThenElse<&Self, &Self> for FheBool {
|
||||
);
|
||||
let boolean_inner = CudaBooleanBlock(inner);
|
||||
(InnerBoolean::Cuda(boolean_inner), cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
});
|
||||
Self::new(ciphertext, tag)
|
||||
}
|
||||
@@ -421,8 +422,7 @@ where
|
||||
FheUint::new(inner, cpu_sks.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner = cuda_key.key.key.if_then_else(
|
||||
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
|
||||
&*ct_then.ciphertext.on_gpu(streams),
|
||||
@@ -431,7 +431,7 @@ where
|
||||
);
|
||||
|
||||
FheUint::new(inner, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -455,8 +455,7 @@ impl<Id: FheIntId> IfThenElse<FheInt<Id>> for FheBool {
|
||||
FheInt::new(new_ct, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner = cuda_key.key.key.if_then_else(
|
||||
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
|
||||
&*ct_then.ciphertext.on_gpu(streams),
|
||||
@@ -465,7 +464,7 @@ impl<Id: FheIntId> IfThenElse<FheInt<Id>> for FheBool {
|
||||
);
|
||||
|
||||
FheInt::new(inner, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -483,8 +482,7 @@ impl IfThenElse<Self> for FheBool {
|
||||
(InnerBoolean::Cpu(new_ct), key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner = cuda_key.key.key.if_then_else(
|
||||
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
|
||||
&*ct_then.ciphertext.on_gpu(streams),
|
||||
@@ -493,7 +491,7 @@ impl IfThenElse<Self> for FheBool {
|
||||
);
|
||||
let boolean_inner = CudaBooleanBlock(inner);
|
||||
(InnerBoolean::Cuda(boolean_inner), cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
});
|
||||
Self::new(ciphertext, tag)
|
||||
}
|
||||
@@ -543,8 +541,7 @@ where
|
||||
Self::new(ciphertext, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner = cuda_key.key.key.eq(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&other.borrow().ciphertext.on_gpu(streams),
|
||||
@@ -552,7 +549,7 @@ where
|
||||
);
|
||||
let ciphertext = InnerBoolean::Cuda(inner);
|
||||
Self::new(ciphertext, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -586,8 +583,7 @@ where
|
||||
Self::new(ciphertext, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner = cuda_key.key.key.ne(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&other.borrow().ciphertext.on_gpu(streams),
|
||||
@@ -595,7 +591,7 @@ where
|
||||
);
|
||||
let ciphertext = InnerBoolean::Cuda(inner);
|
||||
Self::new(ciphertext, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -632,15 +628,14 @@ impl FheEq<bool> for FheBool {
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner = cuda_key.key.key.scalar_eq(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
u8::from(other),
|
||||
streams,
|
||||
);
|
||||
(InnerBoolean::Cuda(inner), cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
});
|
||||
Self::new(ciphertext, tag)
|
||||
}
|
||||
@@ -676,15 +671,14 @@ impl FheEq<bool> for FheBool {
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner = cuda_key.key.key.scalar_ne(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
u8::from(other),
|
||||
streams,
|
||||
);
|
||||
(InnerBoolean::Cuda(inner), cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
});
|
||||
Self::new(ciphertext, tag)
|
||||
}
|
||||
@@ -751,8 +745,7 @@ where
|
||||
(InnerBoolean::Cpu(inner_ct), key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_ct = cuda_key.key.key.bitand(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.borrow().ciphertext.on_gpu(streams),
|
||||
@@ -765,7 +758,7 @@ where
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}
|
||||
}),
|
||||
});
|
||||
FheBool::new(ciphertext, tag)
|
||||
}
|
||||
@@ -837,8 +830,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_ct = cuda_key.key.key.bitor(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.borrow().ciphertext.on_gpu(streams),
|
||||
@@ -850,7 +842,7 @@ where
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}
|
||||
}),
|
||||
});
|
||||
FheBool::new(ciphertext, tag)
|
||||
}
|
||||
@@ -922,8 +914,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_ct = cuda_key.key.key.bitxor(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.borrow().ciphertext.on_gpu(streams),
|
||||
@@ -935,7 +926,7 @@ where
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}
|
||||
}),
|
||||
});
|
||||
FheBool::new(ciphertext, tag)
|
||||
}
|
||||
@@ -999,8 +990,7 @@ impl BitAnd<bool> for &FheBool {
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_ct = cuda_key.key.key.scalar_bitand(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
u8::from(rhs),
|
||||
@@ -1012,7 +1002,7 @@ impl BitAnd<bool> for &FheBool {
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}
|
||||
}),
|
||||
});
|
||||
FheBool::new(ciphertext, tag)
|
||||
}
|
||||
@@ -1076,8 +1066,7 @@ impl BitOr<bool> for &FheBool {
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_ct = cuda_key.key.key.scalar_bitor(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
u8::from(rhs),
|
||||
@@ -1089,7 +1078,7 @@ impl BitOr<bool> for &FheBool {
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}
|
||||
}),
|
||||
});
|
||||
FheBool::new(ciphertext, tag)
|
||||
}
|
||||
@@ -1153,8 +1142,7 @@ impl BitXor<bool> for &FheBool {
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_ct = cuda_key.key.key.scalar_bitxor(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
u8::from(rhs),
|
||||
@@ -1166,7 +1154,7 @@ impl BitXor<bool> for &FheBool {
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}
|
||||
}),
|
||||
});
|
||||
FheBool::new(ciphertext, tag)
|
||||
}
|
||||
@@ -1358,14 +1346,13 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.bitand_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1402,14 +1389,13 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.bitor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1446,14 +1432,13 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.bitxor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1484,14 +1469,13 @@ impl BitAndAssign<bool> for FheBool {
|
||||
.scalar_bitand_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs));
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_bitand_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
u8::from(rhs),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1522,14 +1506,13 @@ impl BitOrAssign<bool> for FheBool {
|
||||
.scalar_bitor_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs));
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_bitor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
u8::from(rhs),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1560,14 +1543,13 @@ impl BitXorAssign<bool> for FheBool {
|
||||
.scalar_bitxor_assign(&mut self.ciphertext.as_cpu_mut().0, u8::from(rhs));
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_bitxor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
u8::from(rhs),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1624,8 +1606,7 @@ impl std::ops::Not for &FheBool {
|
||||
(InnerBoolean::Cpu(inner), key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner =
|
||||
cuda_key
|
||||
.key
|
||||
@@ -1637,7 +1618,7 @@ impl std::ops::Not for &FheBool {
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}
|
||||
}),
|
||||
});
|
||||
FheBool::new(ciphertext, tag)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use super::base::FheBool;
|
||||
use crate::high_level_api::booleans::inner::InnerBoolean;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
@@ -88,8 +90,7 @@ impl FheTryTrivialEncrypt<bool> for FheBool {
|
||||
(ct, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner: CudaUnsignedRadixCiphertext =
|
||||
cuda_key
|
||||
.key
|
||||
@@ -99,7 +100,7 @@ impl FheTryTrivialEncrypt<bool> for FheBool {
|
||||
inner.into_inner(),
|
||||
));
|
||||
(ct, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
});
|
||||
Ok(Self::new(ciphertext, tag))
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use super::{FheBool, InnerBoolean};
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
@@ -39,8 +41,7 @@ impl FheBool {
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let d_ct: CudaUnsignedRadixCiphertext = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -51,7 +52,7 @@ impl FheBool {
|
||||
)),
|
||||
cuda_key.tag.clone(),
|
||||
)
|
||||
}
|
||||
}),
|
||||
});
|
||||
Self::new(ciphertext, tag)
|
||||
}
|
||||
|
||||
@@ -172,16 +172,17 @@ impl CompressedCiphertextListBuilder {
|
||||
for (element, _) in &self.inner {
|
||||
match element {
|
||||
ToBeCompressed::Cpu(cpu_blocks) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_radixes
|
||||
.push(CudaRadixCiphertext::from_cpu_blocks(cpu_blocks, streams));
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_radixes.push(CudaRadixCiphertext::from_cpu_blocks(
|
||||
cpu_blocks, streams,
|
||||
));
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
ToBeCompressed::Cuda(cuda_radix) => {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_radixes.push(cuda_radix.duplicate(streams));
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -194,11 +195,10 @@ impl CompressedCiphertextListBuilder {
|
||||
crate::Error::new("Compression key not set in server key".to_owned())
|
||||
})
|
||||
.map(|compression_key| {
|
||||
let packed_list = {
|
||||
let streams = &cuda_key.streams;
|
||||
let packed_list = with_thread_local_cuda_streams(|streams| {
|
||||
compression_key
|
||||
.compress_ciphertexts_into_list(cuda_radixes.as_slice(), streams)
|
||||
};
|
||||
});
|
||||
let info = self.inner.iter().map(|(_, kind)| *kind).collect();
|
||||
|
||||
let compressed_list = CudaCompressedCiphertextList { packed_list, info };
|
||||
@@ -458,12 +458,11 @@ impl CiphertextList for CompressedCiphertextList {
|
||||
crate::Error::new("Compression key not set in server key".to_owned())
|
||||
})
|
||||
.and_then(|decompression_key| {
|
||||
let mut ct = {
|
||||
let streams = &cuda_key.streams;
|
||||
let mut ct = with_thread_local_cuda_streams(|streams| {
|
||||
self.inner
|
||||
.on_gpu(streams)
|
||||
.get::<T>(index, decompression_key, streams)
|
||||
};
|
||||
});
|
||||
if let Ok(Some(ct_ref)) = &mut ct {
|
||||
ct_ref.tag_mut().set_data(cuda_key.tag.data())
|
||||
}
|
||||
|
||||
@@ -201,15 +201,12 @@ pub(in crate::high_level_api) use gpu::{
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
pub use gpu::CudaGpuChoice;
|
||||
#[cfg(feature = "gpu")]
|
||||
pub struct CustomMultiGpuIndexes(Vec<GpuIndex>);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
mod gpu {
|
||||
use crate::core_crypto::gpu::get_number_of_gpus;
|
||||
|
||||
use super::*;
|
||||
use itertools::Itertools;
|
||||
use std::cell::LazyCell;
|
||||
|
||||
thread_local! {
|
||||
@@ -226,14 +223,14 @@ mod gpu {
|
||||
}
|
||||
|
||||
struct CudaStreamPool {
|
||||
custom: Option<CudaStreams>,
|
||||
multi: LazyCell<CudaStreams>,
|
||||
single: Vec<LazyCell<CudaStreams, Box<dyn Fn() -> CudaStreams>>>,
|
||||
}
|
||||
|
||||
impl CudaStreamPool {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
custom: None,
|
||||
multi: LazyCell::new(CudaStreams::new_multi_gpu),
|
||||
single: (0..get_number_of_gpus())
|
||||
.map(|index| {
|
||||
let ctor =
|
||||
@@ -245,6 +242,29 @@ mod gpu {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> std::ops::Index<&'a [GpuIndex]> for CudaStreamPool {
|
||||
type Output = CudaStreams;
|
||||
|
||||
fn index(&self, indexes: &'a [GpuIndex]) -> &Self::Output {
|
||||
match indexes.len() {
|
||||
0 => panic!("Internal error: Gpu indexes must not be empty"),
|
||||
1 => &self.single[indexes[0].get() as usize],
|
||||
_ => &self.multi,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Index<CudaGpuChoice> for CudaStreamPool {
|
||||
type Output = CudaStreams;
|
||||
|
||||
fn index(&self, choice: CudaGpuChoice) -> &Self::Output {
|
||||
match choice {
|
||||
CudaGpuChoice::Multi => &self.multi,
|
||||
CudaGpuChoice::Single(index) => &self.single[index.get() as usize],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(in crate::high_level_api) fn with_thread_local_cuda_streams_for_gpu_indexes<
|
||||
R,
|
||||
F: for<'a> FnOnce(&'a CudaStreams) -> R,
|
||||
@@ -255,44 +275,15 @@ mod gpu {
|
||||
thread_local! {
|
||||
static POOL: RefCell<CudaStreamPool> = RefCell::new(CudaStreamPool::new());
|
||||
}
|
||||
|
||||
if gpu_indexes.len() == 1 {
|
||||
POOL.with_borrow(|pool| func(&pool.single[gpu_indexes[0].get() as usize]))
|
||||
} else {
|
||||
POOL.with_borrow_mut(|pool| match &pool.custom {
|
||||
Some(streams) if streams.gpu_indexes != gpu_indexes => {
|
||||
pool.custom = Some(CudaStreams::new_multi_gpu_with_indexes(gpu_indexes));
|
||||
}
|
||||
None => {
|
||||
pool.custom = Some(CudaStreams::new_multi_gpu_with_indexes(gpu_indexes));
|
||||
}
|
||||
_ => {}
|
||||
});
|
||||
|
||||
POOL.with_borrow(|pool| func(pool.custom.as_ref().unwrap()))
|
||||
}
|
||||
POOL.with_borrow(|stream_pool| {
|
||||
let stream = &stream_pool[gpu_indexes];
|
||||
func(stream)
|
||||
})
|
||||
}
|
||||
|
||||
impl Clone for CustomMultiGpuIndexes {
|
||||
fn clone(&self) -> Self {
|
||||
self.0.iter().copied().collect_vec().into()
|
||||
}
|
||||
}
|
||||
|
||||
impl CustomMultiGpuIndexes {
|
||||
pub fn new(indexes: Vec<GpuIndex>) -> Self {
|
||||
Self(indexes)
|
||||
}
|
||||
pub fn gpu_indexes(&self) -> &[GpuIndex] {
|
||||
self.0.as_slice()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum CudaGpuChoice {
|
||||
Single(GpuIndex),
|
||||
Multi,
|
||||
Custom(CustomMultiGpuIndexes),
|
||||
}
|
||||
|
||||
impl From<GpuIndex> for CudaGpuChoice {
|
||||
@@ -301,24 +292,11 @@ mod gpu {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<GpuIndex>> for CustomMultiGpuIndexes {
|
||||
fn from(value: Vec<GpuIndex>) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CustomMultiGpuIndexes> for CudaGpuChoice {
|
||||
fn from(values: CustomMultiGpuIndexes) -> Self {
|
||||
Self::Custom(values)
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaGpuChoice {
|
||||
pub(in crate::high_level_api) fn build_streams(self) -> CudaStreams {
|
||||
match self {
|
||||
Self::Single(idx) => CudaStreams::new_single_gpu(idx),
|
||||
Self::Multi => CudaStreams::new_multi_gpu(),
|
||||
Self::Custom(idxs) => CudaStreams::new_multi_gpu_with_indexes(idxs.gpu_indexes()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use super::{FheIntId, FheUint, FheUintId};
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
@@ -36,8 +38,7 @@ impl<Id: FheUintId> FheUint<Id> {
|
||||
Self::new(ct, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let d_ct: CudaUnsignedRadixCiphertext = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -48,7 +49,7 @@ impl<Id: FheUintId> FheUint<Id> {
|
||||
);
|
||||
|
||||
Self::new(d_ct, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
/// Generates an encrypted `num_block` blocks unsigned integer
|
||||
@@ -86,8 +87,7 @@ impl<Id: FheUintId> FheUint<Id> {
|
||||
Self::new(ct, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let d_ct: CudaUnsignedRadixCiphertext = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -98,7 +98,7 @@ impl<Id: FheUintId> FheUint<Id> {
|
||||
streams,
|
||||
);
|
||||
Self::new(d_ct, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -136,8 +136,7 @@ impl<Id: FheIntId> FheInt<Id> {
|
||||
Self::new(ct, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let d_ct: CudaSignedRadixCiphertext = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -148,7 +147,7 @@ impl<Id: FheIntId> FheInt<Id> {
|
||||
);
|
||||
|
||||
Self::new(d_ct, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -188,8 +187,7 @@ impl<Id: FheIntId> FheInt<Id> {
|
||||
Self::new(ct, key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let d_ct: CudaSignedRadixCiphertext = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -200,7 +198,7 @@ impl<Id: FheIntId> FheInt<Id> {
|
||||
streams,
|
||||
);
|
||||
Self::new(d_ct, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ use crate::shortint::PBSParameters;
|
||||
use crate::{Device, FheBool, ServerKey, Tag};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
pub trait FheIntId: IntegerId {}
|
||||
|
||||
/// A Generic FHE signed integer
|
||||
@@ -195,14 +197,13 @@ where
|
||||
Self::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.abs(&*self.ciphertext.on_gpu(streams), streams);
|
||||
Self::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -232,14 +233,13 @@ where
|
||||
FheBool::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.is_even(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheBool::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -269,14 +269,13 @@ where
|
||||
FheBool::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.is_odd(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheBool::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -310,8 +309,7 @@ where
|
||||
crate::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -322,7 +320,7 @@ where
|
||||
streams,
|
||||
);
|
||||
crate::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -356,8 +354,7 @@ where
|
||||
crate::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -368,7 +365,7 @@ where
|
||||
streams,
|
||||
);
|
||||
crate::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -402,8 +399,7 @@ where
|
||||
crate::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -414,7 +410,7 @@ where
|
||||
streams,
|
||||
);
|
||||
crate::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -448,8 +444,7 @@ where
|
||||
crate::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -460,7 +455,7 @@ where
|
||||
streams,
|
||||
);
|
||||
crate::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -570,8 +565,7 @@ where
|
||||
crate::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -582,7 +576,7 @@ where
|
||||
streams,
|
||||
);
|
||||
crate::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -625,8 +619,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let (result, is_ok) = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -640,7 +633,7 @@ where
|
||||
crate::FheUint32::new(result, cuda_key.tag.clone()),
|
||||
FheBool::new(is_ok, cuda_key.tag.clone()),
|
||||
)
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -814,8 +807,7 @@ where
|
||||
Self::new(new_ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let target_num_blocks = IntoId::num_blocks(cuda_key.message_modulus());
|
||||
let new_ciphertext = cuda_key.key.key.cast_to_signed(
|
||||
input.ciphertext.into_gpu(streams),
|
||||
@@ -823,7 +815,7 @@ where
|
||||
streams,
|
||||
);
|
||||
Self::new(new_ciphertext, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -860,15 +852,14 @@ where
|
||||
Self::new(new_ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let new_ciphertext = cuda_key.key.key.cast_to_signed(
|
||||
input.ciphertext.into_gpu(streams),
|
||||
IntoId::num_blocks(cuda_key.message_modulus()),
|
||||
streams,
|
||||
);
|
||||
Self::new(new_ciphertext, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -908,15 +899,14 @@ where
|
||||
Self::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner = cuda_key.key.key.cast_to_signed(
|
||||
input.ciphertext.into_gpu(streams).0,
|
||||
Id::num_blocks(cuda_key.message_modulus()),
|
||||
streams,
|
||||
);
|
||||
Self::new(inner, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use crate::core_crypto::prelude::SignedNumeric;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::FheIntId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::integer::block_decomposition::{DecomposableInto, RecomposableSignedInteger};
|
||||
@@ -111,15 +113,14 @@ where
|
||||
Ok(Self::new(ciphertext, key.tag.clone()))
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner: CudaSignedRadixCiphertext = cuda_key.key.key.create_trivial_radix(
|
||||
value,
|
||||
Id::num_blocks(cuda_key.key.key.message_modulus),
|
||||
streams,
|
||||
);
|
||||
Ok(Self::new(inner, cuda_key.tag.clone()))
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,9 @@ use std::ops::{
|
||||
Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
|
||||
};
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
|
||||
impl<'a, Id> std::iter::Sum<&'a Self> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
@@ -71,8 +74,7 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let cts = iter
|
||||
.map(|fhe_uint| {
|
||||
match fhe_uint.ciphertext.on_gpu(streams) {
|
||||
@@ -102,7 +104,7 @@ where
|
||||
)
|
||||
});
|
||||
Self::new(inner, cuda_key.tag.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -142,15 +144,14 @@ where
|
||||
Self::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.max(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -189,15 +190,14 @@ where
|
||||
Self::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.min(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -247,15 +247,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.eq(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -287,15 +286,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.ne(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -353,15 +351,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.lt(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -393,15 +390,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.le(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -433,15 +429,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.gt(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -473,15 +468,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.ge(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -556,8 +550,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let (q, r) = cuda_key.key.key.div_rem(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
@@ -567,7 +560,7 @@ where
|
||||
FheInt::<Id>::new(q, cuda_key.tag.clone()),
|
||||
FheInt::<Id>::new(r, cuda_key.tag.clone()),
|
||||
)
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -641,11 +634,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.add(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -684,11 +677,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.sub(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -727,11 +720,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.mul(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -768,11 +761,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.bitand(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -809,11 +802,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.bitor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -850,11 +843,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.bitxor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -898,14 +891,14 @@ generic_integer_impl_operation!(
|
||||
FheInt::new(inner_result, cpu_key.tag.clone())
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.div(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
},
|
||||
}),
|
||||
})
|
||||
}
|
||||
},
|
||||
@@ -949,14 +942,14 @@ generic_integer_impl_operation!(
|
||||
FheInt::new(inner_result, cpu_key.tag.clone())
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.rem(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
},
|
||||
}),
|
||||
})
|
||||
}
|
||||
},
|
||||
@@ -1064,11 +1057,11 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.left_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1108,11 +1101,11 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.right_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1152,11 +1145,11 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.rotate_left(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1196,11 +1189,11 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.rotate_right(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1246,12 +1239,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.add_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.add_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1292,12 +1286,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.sub_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.sub_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1338,12 +1333,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.mul_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.mul_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1382,12 +1378,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.bitand_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.bitand_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1426,12 +1423,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.bitor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.bitor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1470,12 +1468,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.bitxor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.bitxor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1519,8 +1518,7 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let cuda_lhs = self.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result = cuda_key.pbs_key().div(
|
||||
&*cuda_lhs,
|
||||
@@ -1528,7 +1526,7 @@ where
|
||||
streams,
|
||||
);
|
||||
*cuda_lhs = cuda_result;
|
||||
};
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1572,13 +1570,15 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let cuda_lhs = self.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result =
|
||||
cuda_key
|
||||
.pbs_key()
|
||||
.rem(&*cuda_lhs, &rhs.ciphertext.on_gpu(streams), streams);
|
||||
*cuda_lhs = cuda_result;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let cuda_lhs = self.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result = cuda_key.pbs_key().rem(
|
||||
&*cuda_lhs,
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
*cuda_lhs = cuda_result;
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1627,12 +1627,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.left_shift_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.left_shift_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1680,12 +1681,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.right_shift_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.right_shift_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1734,12 +1736,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.rotate_left_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.rotate_left_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1788,12 +1791,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.rotate_right_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.rotate_right_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1859,14 +1863,13 @@ where
|
||||
FheInt::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.neg(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1929,14 +1932,13 @@ where
|
||||
FheInt::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.bitnot(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheInt::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1954,14 +1956,13 @@ where
|
||||
InternalServerKey::Cpu(_) => {
|
||||
tmp_buffer_size = 0;
|
||||
}
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
tmp_buffer_size = cuda_key.key.key.get_add_assign_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
});
|
||||
tmp_buffer_size
|
||||
}
|
||||
@@ -1974,13 +1975,12 @@ where
|
||||
{
|
||||
fn get_size_on_gpu(&self) -> u64 {
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams))
|
||||
}
|
||||
}),
|
||||
InternalServerKey::Cpu(_) => 0,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use crate::core_crypto::prelude::SignedNumeric;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::FheIntId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::integer::block_decomposition::DecomposableInto;
|
||||
@@ -51,8 +53,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let (result, overflow) = cuda_key.key.key.signed_overflowing_add(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
&other.ciphertext.on_gpu(streams),
|
||||
@@ -62,7 +63,7 @@ where
|
||||
FheInt::new(result, cuda_key.tag.clone()),
|
||||
FheBool::new(overflow, cuda_key.tag.clone()),
|
||||
)
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -148,8 +149,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let (result, overflow) = cuda_key.key.key.signed_overflowing_scalar_add(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
other,
|
||||
@@ -159,7 +159,7 @@ where
|
||||
FheInt::new(result, cuda_key.tag.clone()),
|
||||
FheBool::new(overflow, cuda_key.tag.clone()),
|
||||
)
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -283,8 +283,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let (result, overflow) = cuda_key.key.key.signed_overflowing_sub(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
&other.ciphertext.on_gpu(streams),
|
||||
@@ -294,7 +293,7 @@ where
|
||||
FheInt::new(result, cuda_key.tag.clone()),
|
||||
FheBool::new(overflow, cuda_key.tag.clone()),
|
||||
)
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -379,8 +378,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let (result, overflow) = cuda_key.key.key.signed_overflowing_scalar_sub(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
other,
|
||||
@@ -390,7 +388,7 @@ where
|
||||
FheInt::new(result, cuda_key.tag.clone()),
|
||||
FheBool::new(overflow, cuda_key.tag.clone()),
|
||||
)
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
use crate::core_crypto::commons::numeric::CastFrom;
|
||||
use crate::high_level_api::errors::UnwrapResultExt;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::signed::inner::SignedRadixCiphertext;
|
||||
use crate::high_level_api::integers::FheIntId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
@@ -53,13 +55,14 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_max(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.scalar_max(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
rhs,
|
||||
streams,
|
||||
);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -100,13 +103,14 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_min(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.scalar_min(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
rhs,
|
||||
streams,
|
||||
);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -146,13 +150,14 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -186,13 +191,14 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -231,13 +237,14 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -270,13 +277,14 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -309,13 +317,14 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -348,13 +357,14 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -398,11 +408,11 @@ macro_rules! generic_integer_impl_scalar_div_rem {
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let (inner_q, inner_r) = {let streams = &cuda_key.streams;
|
||||
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.signed_scalar_div_rem(
|
||||
&*self.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
let (q, r) = (
|
||||
SignedRadixCiphertext::Cuda(inner_q),
|
||||
SignedRadixCiphertext::Cuda(inner_r),
|
||||
@@ -446,11 +456,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_left_shift(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
};
|
||||
});
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -475,11 +485,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_right_shift(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
};
|
||||
});
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -504,11 +514,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_rotate_left(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
};
|
||||
});
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -533,11 +543,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_rotate_right(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
};
|
||||
});
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -560,10 +570,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
.scalar_left_shift_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) =>
|
||||
{let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -586,9 +597,10 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -610,10 +622,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
.scalar_rotate_left_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) =>
|
||||
{let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -636,10 +649,10 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -749,11 +762,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_add(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -778,11 +791,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_sub(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -807,11 +820,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_mul(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -837,11 +850,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_bitand(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -866,11 +879,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_bitor(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -896,11 +909,11 @@ macro_rules! define_scalar_ops {
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_bitxor(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -925,11 +938,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.signed_scalar_div(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -955,11 +968,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.signed_scalar_rem(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
SignedRadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -1001,11 +1014,12 @@ macro_rules! define_scalar_ops {
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
use crate::integer::gpu::ciphertext::CudaSignedRadixCiphertext;
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let mut result: CudaSignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix(
|
||||
lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams);
|
||||
cuda_key.pbs_key().sub_assign(&mut result, &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
SignedRadixCiphertext::Cuda(result)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1096,9 +1110,10 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_add_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1126,9 +1141,10 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_sub_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1152,9 +1168,10 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_mul_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1179,9 +1196,10 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1205,9 +1223,10 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1230,9 +1249,10 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1254,11 +1274,11 @@ macro_rules! define_scalar_ops {
|
||||
.signed_scalar_div_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result = cuda_key.pbs_key().signed_scalar_div(&cuda_lhs, rhs, streams);
|
||||
*cuda_lhs = cuda_result;
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
},
|
||||
@@ -1279,11 +1299,11 @@ macro_rules! define_scalar_ops {
|
||||
.signed_scalar_rem_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result = cuda_key.pbs_key().signed_scalar_rem(&cuda_lhs, rhs, streams);
|
||||
*cuda_lhs = cuda_result;
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
},
|
||||
|
||||
@@ -4,6 +4,8 @@ use super::inner::RadixCiphertext;
|
||||
use crate::backward_compatibility::integers::FheUintVersions;
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
use crate::core_crypto::prelude::{CastFrom, UnsignedInteger, UnsignedNumeric};
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::signed::{FheInt, FheIntId};
|
||||
use crate::high_level_api::integers::IntegerId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
@@ -247,14 +249,13 @@ where
|
||||
FheBool::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.is_even(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheBool::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -284,14 +285,13 @@ where
|
||||
FheBool::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.is_odd(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheBool::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -418,8 +418,7 @@ where
|
||||
super::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -430,7 +429,7 @@ where
|
||||
streams,
|
||||
);
|
||||
super::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -464,8 +463,7 @@ where
|
||||
super::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -476,7 +474,7 @@ where
|
||||
streams,
|
||||
);
|
||||
super::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -510,8 +508,7 @@ where
|
||||
super::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -522,7 +519,7 @@ where
|
||||
streams,
|
||||
);
|
||||
super::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -556,8 +553,7 @@ where
|
||||
super::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -568,7 +564,7 @@ where
|
||||
streams,
|
||||
);
|
||||
super::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -678,8 +674,7 @@ where
|
||||
super::FheUint32::new(result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -690,7 +685,7 @@ where
|
||||
streams,
|
||||
);
|
||||
super::FheUint32::new(result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -733,8 +728,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let (result, is_ok) = cuda_key
|
||||
.key
|
||||
.key
|
||||
@@ -748,7 +742,7 @@ where
|
||||
super::FheUint32::new(result, cuda_key.tag.clone()),
|
||||
FheBool::new(is_ok, cuda_key.tag.clone()),
|
||||
)
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -816,8 +810,7 @@ where
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let (result, matched) = cuda_key.key.key.match_value(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
matches,
|
||||
@@ -832,7 +825,7 @@ where
|
||||
} else {
|
||||
Err(crate::Error::new("Output type does not have enough bits to represent all possible output values".to_string()))
|
||||
}
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -895,8 +888,7 @@ where
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let result = cuda_key.key.key.match_value_or(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
matches,
|
||||
@@ -909,7 +901,7 @@ where
|
||||
} else {
|
||||
Err(crate::Error::new("Output type does not have enough bits to represent all possible output values".to_string()))
|
||||
}
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1109,15 +1101,14 @@ where
|
||||
Self::new(casted, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let casted = cuda_key.key.key.cast_to_unsigned(
|
||||
input.ciphertext.into_gpu(streams),
|
||||
IntoId::num_blocks(cuda_key.message_modulus()),
|
||||
streams,
|
||||
);
|
||||
Self::new(casted, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1154,15 +1145,14 @@ where
|
||||
Self::new(casted, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let casted = cuda_key.key.key.cast_to_unsigned(
|
||||
input.ciphertext.into_gpu(streams),
|
||||
IntoId::num_blocks(cuda_key.message_modulus()),
|
||||
streams,
|
||||
);
|
||||
Self::new(casted, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1199,15 +1189,14 @@ where
|
||||
Self::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner = cuda_key.key.key.cast_to_unsigned(
|
||||
input.ciphertext.into_gpu(streams).0,
|
||||
Id::num_blocks(cuda_key.message_modulus()),
|
||||
streams,
|
||||
);
|
||||
Self::new(inner, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use crate::core_crypto::prelude::UnsignedNumeric;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::FheUintId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom};
|
||||
@@ -113,15 +115,14 @@ where
|
||||
Ok(Self::new(ciphertext, key.tag.clone()))
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner: CudaUnsignedRadixCiphertext = cuda_key.key.key.create_trivial_radix(
|
||||
value,
|
||||
Id::num_blocks(cuda_key.key.key.message_modulus),
|
||||
streams,
|
||||
);
|
||||
Ok(Self::new(inner, cuda_key.tag.clone()))
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ use super::inner::RadixCiphertext;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::details::MaybeCloned;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::FheUintId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
@@ -73,8 +75,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let cts = iter
|
||||
.map(|fhe_uint| fhe_uint.ciphertext.into_gpu(streams))
|
||||
.collect::<Vec<_>>();
|
||||
@@ -91,7 +92,7 @@ where
|
||||
)
|
||||
});
|
||||
Self::new(inner, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -153,8 +154,7 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let cts = iter
|
||||
.map(|fhe_uint| {
|
||||
match fhe_uint.ciphertext.on_gpu(streams) {
|
||||
@@ -184,7 +184,7 @@ where
|
||||
)
|
||||
});
|
||||
Self::new(inner, cuda_key.tag.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -224,15 +224,14 @@ where
|
||||
Self::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.max(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -271,15 +270,14 @@ where
|
||||
Self::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.min(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -329,15 +327,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.eq(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -369,15 +366,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.ne(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -435,15 +431,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.lt(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -475,15 +470,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.le(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -515,15 +509,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.gt(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -555,15 +548,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.ge(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -639,8 +631,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.div_rem(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&*rhs.ciphertext.on_gpu(streams),
|
||||
@@ -650,7 +641,7 @@ where
|
||||
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
|
||||
FheUint::<Id>::new(inner_result.1, cuda_key.tag.clone()),
|
||||
)
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -726,10 +717,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.add(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -768,10 +760,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.sub(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -810,10 +803,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.mul(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -850,10 +844,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.bitand(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -890,10 +885,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.bitor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -930,10 +926,11 @@ generic_integer_impl_operation!(
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.bitxor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -977,15 +974,14 @@ generic_integer_impl_operation!(
|
||||
FheUint::new(inner_result, cpu_key.tag.clone())
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.div(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
},
|
||||
}),
|
||||
})
|
||||
}
|
||||
},
|
||||
@@ -1029,16 +1025,14 @@ generic_integer_impl_operation!(
|
||||
FheUint::new(inner_result, cpu_key.tag.clone())
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) =>
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.rem(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
},
|
||||
}),
|
||||
})
|
||||
}
|
||||
},
|
||||
@@ -1146,10 +1140,11 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.left_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1189,10 +1184,11 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.right_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1232,10 +1228,11 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.rotate_left(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1275,10 +1272,11 @@ generic_integer_impl_shift_rotate!(
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key
|
||||
.rotate_right(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1323,14 +1321,13 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.add_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1369,14 +1366,13 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.sub_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1415,14 +1411,13 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.mul_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1459,14 +1454,13 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.bitand_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1503,14 +1497,13 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.bitor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1547,14 +1540,13 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.bitxor_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1596,14 +1588,13 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.div_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1645,14 +1636,13 @@ where
|
||||
);
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.rem_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1700,14 +1690,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.left_shift_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
};
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1755,14 +1744,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.right_shift_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
};
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1811,14 +1799,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.rotate_left_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
};
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1867,12 +1854,13 @@ where
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key.key.key.rotate_right_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.rotate_right_assign(
|
||||
self.ciphertext.as_gpu_mut(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1946,14 +1934,13 @@ where
|
||||
FheUint::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.neg(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2016,14 +2003,13 @@ where
|
||||
FheUint::new(ciphertext, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key
|
||||
.key
|
||||
.key
|
||||
.bitnot(&*self.ciphertext.on_gpu(streams), streams);
|
||||
FheUint::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2040,14 +2026,13 @@ where
|
||||
InternalServerKey::Cpu(_) => {
|
||||
tmp_buffer_size = 0;
|
||||
}
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
tmp_buffer_size = cuda_key.key.key.get_add_assign_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}
|
||||
}),
|
||||
});
|
||||
tmp_buffer_size
|
||||
}
|
||||
@@ -2059,13 +2044,12 @@ where
|
||||
{
|
||||
fn get_size_on_gpu(&self) -> u64 {
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_ciphertext_size_on_gpu(&*self.ciphertext.on_gpu(streams))
|
||||
}
|
||||
}),
|
||||
InternalServerKey::Cpu(_) => 0,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use crate::core_crypto::prelude::UnsignedNumeric;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::FheUintId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::integer::block_decomposition::DecomposableInto;
|
||||
@@ -51,8 +53,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.unsigned_overflowing_add(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
&other.ciphertext.on_gpu(streams),
|
||||
@@ -62,7 +63,7 @@ where
|
||||
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
|
||||
FheBool::new(inner_result.1, cuda_key.tag.clone()),
|
||||
)
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -148,8 +149,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.unsigned_overflowing_scalar_add(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
other,
|
||||
@@ -159,7 +159,7 @@ where
|
||||
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
|
||||
FheBool::new(inner_result.1, cuda_key.tag.clone()),
|
||||
)
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -285,8 +285,7 @@ where
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result = cuda_key.key.key.unsigned_overflowing_sub(
|
||||
&self.ciphertext.on_gpu(streams),
|
||||
&other.ciphertext.on_gpu(streams),
|
||||
@@ -296,7 +295,7 @@ where
|
||||
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
|
||||
FheBool::new(inner_result.1, cuda_key.tag.clone()),
|
||||
)
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ use super::inner::RadixCiphertext;
|
||||
use crate::error::InvalidRangeError;
|
||||
use crate::high_level_api::errors::UnwrapResultExt;
|
||||
use crate::high_level_api::global_state;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::FheUintId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::high_level_api::traits::{
|
||||
@@ -58,15 +60,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -98,15 +99,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -144,15 +144,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -184,15 +183,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -224,15 +222,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -264,15 +261,14 @@ where
|
||||
FheBool::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
FheBool::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -312,15 +308,14 @@ where
|
||||
Self::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_max(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -360,15 +355,14 @@ where
|
||||
Self::new(inner_result, cpu_key.tag.clone())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let inner_result =
|
||||
cuda_key
|
||||
.key
|
||||
.key
|
||||
.scalar_min(&*self.ciphertext.on_gpu(streams), rhs, streams);
|
||||
Self::new(inner_result, cuda_key.tag.clone())
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -499,11 +493,11 @@ macro_rules! generic_integer_impl_scalar_div_rem {
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let (inner_q, inner_r) = {let streams = &cuda_key.streams;
|
||||
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_div_rem(
|
||||
&*self.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
|
||||
(
|
||||
<$concrete_type>::new(q, cuda_key.tag.clone()),
|
||||
@@ -679,11 +673,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_left_shift(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
};
|
||||
});
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -708,11 +702,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_right_shift(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
};
|
||||
});
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -737,11 +731,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_rotate_left(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
};
|
||||
});
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -766,11 +760,11 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_rotate_right(
|
||||
&*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams
|
||||
)
|
||||
};
|
||||
});
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -794,10 +788,10 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -820,10 +814,10 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
{let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -846,9 +840,10 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -871,9 +866,10 @@ macro_rules! define_scalar_rotate_shifts {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -984,11 +980,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_add(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -1013,11 +1009,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_sub(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -1042,11 +1038,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_mul(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -1072,11 +1068,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_bitand(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -1101,11 +1097,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_bitor(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -1131,11 +1127,11 @@ macro_rules! define_scalar_ops {
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_bitxor(
|
||||
&*lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -1160,11 +1156,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_div(
|
||||
&lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -1190,11 +1186,11 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = {let streams = &cuda_key.streams;
|
||||
let inner_result = with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key.scalar_rem(
|
||||
&lhs.ciphertext.on_gpu(streams), rhs, streams
|
||||
)
|
||||
};
|
||||
});
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
}
|
||||
})
|
||||
@@ -1235,11 +1231,12 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let mut result: CudaUnsignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix(
|
||||
lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams);
|
||||
cuda_key.pbs_key().sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams);
|
||||
RadixCiphertext::Cuda(result)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1330,9 +1327,10 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_add_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1360,9 +1358,10 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_sub_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1386,9 +1385,10 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_mul_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1413,9 +1413,10 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1439,9 +1440,10 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1464,9 +1466,10 @@ macro_rules! define_scalar_ops {
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let streams = &cuda_key.streams;
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
cuda_key.key.key
|
||||
.scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams);
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1488,11 +1491,11 @@ macro_rules! define_scalar_ops {
|
||||
.scalar_div_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result = cuda_key.pbs_key().scalar_div(&cuda_lhs, rhs, streams);
|
||||
*cuda_lhs = cuda_result;
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
},
|
||||
@@ -1513,11 +1516,11 @@ macro_rules! define_scalar_ops {
|
||||
.scalar_rem_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {let streams = &cuda_key.streams;
|
||||
InternalServerKey::Cuda(cuda_key) => global_state::with_thread_local_cuda_streams(|streams| {
|
||||
let cuda_lhs = lhs.ciphertext.as_gpu_mut(streams);
|
||||
let cuda_result = cuda_key.pbs_key().scalar_rem(&cuda_lhs, rhs, streams);
|
||||
*cuda_lhs = cuda_result;
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
},
|
||||
|
||||
@@ -328,7 +328,6 @@ impl CompressedServerKey {
|
||||
decompression_key,
|
||||
}),
|
||||
tag: self.tag.clone(),
|
||||
streams,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -352,7 +351,6 @@ impl Named for CompressedServerKey {
|
||||
pub struct CudaServerKey {
|
||||
pub(crate) key: Arc<IntegerCudaServerKey>,
|
||||
pub(crate) tag: Tag,
|
||||
pub(crate) streams: CudaStreams,
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use rand::Rng;
|
||||
|
||||
use crate::core_crypto::gpu::get_number_of_gpus;
|
||||
use crate::high_level_api::global_state::CustomMultiGpuIndexes;
|
||||
use crate::prelude::*;
|
||||
use crate::{
|
||||
set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device, FheUint32, GpuIndex,
|
||||
@@ -118,68 +117,3 @@ fn test_gpu_selection_2() {
|
||||
assert_eq!(c.gpu_indexes(), &[first_gpu]);
|
||||
assert_eq!(decrypted, clear_a.wrapping_add(clear_b));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_specific_gpu_selection() {
|
||||
let config = ConfigBuilder::default().build();
|
||||
let keys = ClientKey::generate(config);
|
||||
let compressed_server_keys = CompressedServerKey::new(&keys);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let total_gpus = get_number_of_gpus() as usize;
|
||||
let num_gpus_to_use = rng.gen_range(1..=get_number_of_gpus()) as usize;
|
||||
|
||||
// Randomly sample num_gpus_to_use indices
|
||||
let selected_indices = rand::seq::index::sample(&mut rng, total_gpus, num_gpus_to_use);
|
||||
|
||||
// Convert the selected indices to GpuIndex objects
|
||||
let gpus_to_be_used = CustomMultiGpuIndexes::new(
|
||||
selected_indices
|
||||
.iter()
|
||||
.map(|idx| GpuIndex::new(idx as u32))
|
||||
.collect(),
|
||||
);
|
||||
|
||||
let clear_a: u32 = rng.gen();
|
||||
let clear_b: u32 = rng.gen();
|
||||
|
||||
let mut a = FheUint32::try_encrypt(clear_a, &keys).unwrap();
|
||||
let mut b = FheUint32::try_encrypt(clear_b, &keys).unwrap();
|
||||
|
||||
assert_eq!(a.current_device(), Device::Cpu);
|
||||
assert_eq!(b.current_device(), Device::Cpu);
|
||||
assert_eq!(a.gpu_indexes(), &[]);
|
||||
assert_eq!(b.gpu_indexes(), &[]);
|
||||
|
||||
let cuda_key = compressed_server_keys.decompress_to_specific_gpu(gpus_to_be_used);
|
||||
|
||||
let first_gpu = GpuIndex::new(0);
|
||||
|
||||
set_server_key(cuda_key);
|
||||
let c = &a + &b;
|
||||
let decrypted: u32 = c.decrypt(&keys);
|
||||
assert_eq!(c.current_device(), Device::CudaGpu);
|
||||
assert_eq!(c.gpu_indexes(), &[first_gpu]);
|
||||
assert_eq!(decrypted, clear_a.wrapping_add(clear_b));
|
||||
|
||||
// Check explicit move, but first make sure input are on Cpu still
|
||||
assert_eq!(a.current_device(), Device::Cpu);
|
||||
assert_eq!(b.current_device(), Device::Cpu);
|
||||
assert_eq!(a.gpu_indexes(), &[]);
|
||||
assert_eq!(b.gpu_indexes(), &[]);
|
||||
|
||||
a.move_to_current_device();
|
||||
b.move_to_current_device();
|
||||
|
||||
assert_eq!(a.current_device(), Device::CudaGpu);
|
||||
assert_eq!(b.current_device(), Device::CudaGpu);
|
||||
assert_eq!(a.gpu_indexes(), &[first_gpu]);
|
||||
assert_eq!(b.gpu_indexes(), &[first_gpu]);
|
||||
|
||||
let c = &a + &b;
|
||||
let decrypted: u32 = c.decrypt(&keys);
|
||||
assert_eq!(c.current_device(), Device::CudaGpu);
|
||||
assert_eq!(c.gpu_indexes(), &[first_gpu]);
|
||||
assert_eq!(decrypted, clear_a.wrapping_add(clear_b));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user