mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
feat(gpu): add memory tracking functions for bitops
This commit is contained in:
@@ -4,7 +4,10 @@ use crate::high_level_api::global_state;
|
||||
use crate::high_level_api::integers::{FheIntId, FheUintId};
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::traits::{AddSizeOnGpu, SizeOnGpu, SubSizeOnGpu};
|
||||
use crate::high_level_api::traits::{
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SizeOnGpu,
|
||||
SubSizeOnGpu,
|
||||
};
|
||||
use crate::high_level_api::traits::{
|
||||
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
|
||||
RotateRightAssign,
|
||||
@@ -2154,3 +2157,100 @@ where
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id, I> BitAndSizeOnGpu<I> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
I: Borrow<Self>,
|
||||
{
|
||||
fn get_bitand_size_on_gpu(&self, rhs: I) -> u64 {
|
||||
let rhs = rhs.borrow();
|
||||
let mut tmp_buffer_size = 0;
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(_) => {
|
||||
tmp_buffer_size = 0;
|
||||
}
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
tmp_buffer_size = cuda_key.key.key.get_bitand_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
});
|
||||
tmp_buffer_size
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id, I> BitOrSizeOnGpu<I> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
I: Borrow<Self>,
|
||||
{
|
||||
fn get_bitor_size_on_gpu(&self, rhs: I) -> u64 {
|
||||
let rhs = rhs.borrow();
|
||||
let mut tmp_buffer_size = 0;
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(_) => {
|
||||
tmp_buffer_size = 0;
|
||||
}
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
tmp_buffer_size = cuda_key.key.key.get_bitor_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
});
|
||||
tmp_buffer_size
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id, I> BitXorSizeOnGpu<I> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
I: Borrow<Self>,
|
||||
{
|
||||
fn get_bitxor_size_on_gpu(&self, rhs: I) -> u64 {
|
||||
let rhs = rhs.borrow();
|
||||
let mut tmp_buffer_size = 0;
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(_) => {
|
||||
tmp_buffer_size = 0;
|
||||
}
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
tmp_buffer_size = cuda_key.key.key.get_bitxor_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
});
|
||||
tmp_buffer_size
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id> BitNotSizeOnGpu for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
{
|
||||
fn get_bitnot_size_on_gpu(&self) -> u64 {
|
||||
let mut tmp_buffer_size = 0;
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(_) => {
|
||||
tmp_buffer_size = 0;
|
||||
}
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
tmp_buffer_size = cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_bitnot_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams);
|
||||
}),
|
||||
});
|
||||
tmp_buffer_size
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,9 @@ use crate::high_level_api::integers::signed::inner::SignedRadixCiphertext;
|
||||
use crate::high_level_api::integers::FheIntId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::traits::{AddSizeOnGpu, SubSizeOnGpu};
|
||||
use crate::high_level_api::traits::{
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SubSizeOnGpu,
|
||||
};
|
||||
use crate::high_level_api::traits::{
|
||||
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
|
||||
RotateRightAssign,
|
||||
@@ -1009,6 +1011,32 @@ macro_rules! define_scalar_ops {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
    rust_trait: BitAndSizeOnGpu(get_bitand_size_on_gpu),
    implem: {
        // The clear right-hand operand is ignored: only the ciphertext is handed
        // to the CUDA key's scalar bitand size query. A CPU key yields 0.
        |lhs: &FheInt<_>, _rhs| {
            global_state::with_internal_keys(|key| match key {
                InternalServerKey::Cpu(_) => 0,
                InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
                    let ct_on_gpu = lhs.ciphertext.on_gpu(streams);
                    cuda_key
                        .key
                        .key
                        .get_scalar_bitand_size_on_gpu(&*ct_on_gpu, streams)
                }),
            })
        }
    },
    fhe_and_scalar_type:
        $(
            ($concrete_type, $($scalar_type)*),
        )*
);
|
||||
|
||||
generic_integer_impl_scalar_operation!(
|
||||
rust_trait: BitOr(bitor),
|
||||
implem: {
|
||||
@@ -1042,6 +1070,32 @@ macro_rules! define_scalar_ops {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
    rust_trait: BitOrSizeOnGpu(get_bitor_size_on_gpu),
    implem: {
        // The clear right-hand operand is ignored: only the ciphertext is handed
        // to the CUDA key's scalar bitor size query. A CPU key yields 0.
        |lhs: &FheInt<_>, _rhs| {
            global_state::with_internal_keys(|key| match key {
                InternalServerKey::Cpu(_) => 0,
                InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
                    let ct_on_gpu = lhs.ciphertext.on_gpu(streams);
                    cuda_key
                        .key
                        .key
                        .get_scalar_bitor_size_on_gpu(&*ct_on_gpu, streams)
                }),
            })
        }
    },
    fhe_and_scalar_type:
        $(
            ($concrete_type, $($scalar_type)*),
        )*
);
|
||||
|
||||
generic_integer_impl_scalar_operation!(
|
||||
rust_trait: BitXor(bitxor),
|
||||
implem: {
|
||||
@@ -1076,6 +1130,32 @@ macro_rules! define_scalar_ops {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
    rust_trait: BitXorSizeOnGpu(get_bitxor_size_on_gpu),
    implem: {
        // The clear right-hand operand is ignored: only the ciphertext is handed
        // to the CUDA key's scalar bitxor size query. A CPU key yields 0.
        |lhs: &FheInt<_>, _rhs| {
            global_state::with_internal_keys(|key| match key {
                InternalServerKey::Cpu(_) => 0,
                InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
                    let ct_on_gpu = lhs.ciphertext.on_gpu(streams);
                    cuda_key
                        .key
                        .key
                        .get_scalar_bitxor_size_on_gpu(&*ct_on_gpu, streams)
                }),
            })
        }
    },
    fhe_and_scalar_type:
        $(
            ($concrete_type, $($scalar_type)*),
        )*
);
|
||||
|
||||
generic_integer_impl_scalar_operation!(
|
||||
rust_trait: Div(div),
|
||||
implem: {
|
||||
@@ -1277,6 +1357,31 @@ macro_rules! define_scalar_ops {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature="gpu")]
|
||||
generic_integer_impl_get_scalar_left_operation_size_on_gpu!(
    rust_trait: BitAndSizeOnGpu(get_bitand_size_on_gpu),
    implem: {
        // Scalar-on-the-left form: the clear left operand is ignored and the
        // ciphertext on the right is handed to the CUDA key's scalar bitand
        // size query. A CPU key yields 0.
        |_lhs, rhs: &FheInt<_>| {
            global_state::with_internal_keys(|key| match key {
                InternalServerKey::Cpu(_) => 0,
                InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
                    let ct_on_gpu = rhs.ciphertext.on_gpu(streams);
                    cuda_key
                        .key
                        .key
                        .get_scalar_bitand_size_on_gpu(&*ct_on_gpu, streams)
                }),
            })
        }
    },
    fhe_and_scalar_type:
        $(
            ($concrete_type, $($scalar_type)*),
        )*
);
|
||||
|
||||
generic_integer_impl_scalar_left_operation!(
|
||||
rust_trait: BitOr(bitor),
|
||||
@@ -1293,6 +1398,31 @@ macro_rules! define_scalar_ops {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature="gpu")]
|
||||
generic_integer_impl_get_scalar_left_operation_size_on_gpu!(
    rust_trait: BitOrSizeOnGpu(get_bitor_size_on_gpu),
    implem: {
        // Scalar-on-the-left form: the clear left operand is ignored and the
        // ciphertext on the right is handed to the CUDA key's scalar bitor
        // size query. A CPU key yields 0.
        |_lhs, rhs: &FheInt<_>| {
            global_state::with_internal_keys(|key| match key {
                InternalServerKey::Cpu(_) => 0,
                InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
                    let ct_on_gpu = rhs.ciphertext.on_gpu(streams);
                    cuda_key
                        .key
                        .key
                        .get_scalar_bitor_size_on_gpu(&*ct_on_gpu, streams)
                }),
            })
        }
    },
    fhe_and_scalar_type:
        $(
            ($concrete_type, $($scalar_type)*),
        )*
);
|
||||
|
||||
generic_integer_impl_scalar_left_operation!(
|
||||
rust_trait: BitXor(bitxor),
|
||||
@@ -1309,6 +1439,31 @@ macro_rules! define_scalar_ops {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature="gpu")]
|
||||
generic_integer_impl_get_scalar_left_operation_size_on_gpu!(
    rust_trait: BitXorSizeOnGpu(get_bitxor_size_on_gpu),
    implem: {
        // Scalar-on-the-left form: the clear left operand is ignored and the
        // ciphertext on the right is handed to the CUDA key's scalar bitxor
        // size query. A CPU key yields 0.
        |_lhs, rhs: &FheInt<_>| {
            global_state::with_internal_keys(|key| match key {
                InternalServerKey::Cpu(_) => 0,
                InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
                    let ct_on_gpu = rhs.ciphertext.on_gpu(streams);
                    cuda_key
                        .key
                        .key
                        .get_scalar_bitxor_size_on_gpu(&*ct_on_gpu, streams)
                }),
            })
        }
    },
    fhe_and_scalar_type:
        $(
            ($concrete_type, $($scalar_type)*),
        )*
);
|
||||
|
||||
// Scalar Assign Ops
|
||||
|
||||
|
||||
@@ -3,7 +3,10 @@ use crate::high_level_api::integers::signed::tests::{
|
||||
};
|
||||
use crate::high_level_api::integers::unsigned::tests::gpu::setup_gpu;
|
||||
use crate::high_level_api::traits::AddSizeOnGpu;
|
||||
use crate::prelude::{check_valid_cuda_malloc, FheTryEncrypt, SubSizeOnGpu};
|
||||
use crate::prelude::{
|
||||
check_valid_cuda_malloc, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
|
||||
FheTryEncrypt, SubSizeOnGpu,
|
||||
};
|
||||
use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS;
|
||||
use crate::{FheInt32, GpuIndex};
|
||||
use rand::Rng;
|
||||
@@ -109,3 +112,53 @@ fn test_gpu_get_add_sub_size_on_gpu() {
|
||||
assert_eq!(add_tmp_buffer_size, scalar_add_tmp_buffer_size);
|
||||
assert_eq!(add_tmp_buffer_size, scalar_sub_tmp_buffer_size);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gpu_get_bitops_size_on_gpu() {
|
||||
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
|
||||
let mut rng = rand::thread_rng();
|
||||
let clear_a = rng.gen_range(1..=i32::MAX);
|
||||
let clear_b = rng.gen_range(1..=i32::MAX);
|
||||
let mut a = FheInt32::try_encrypt(clear_a, &cks).unwrap();
|
||||
let mut b = FheInt32::try_encrypt(clear_b, &cks).unwrap();
|
||||
a.move_to_current_device();
|
||||
b.move_to_current_device();
|
||||
let a = &a;
|
||||
let b = &b;
|
||||
|
||||
let bitand_tmp_buffer_size = a.get_bitand_size_on_gpu(b);
|
||||
let scalar_bitand_tmp_buffer_size = clear_a.get_bitand_size_on_gpu(b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
bitand_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_bitand_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let bitor_tmp_buffer_size = a.get_bitor_size_on_gpu(b);
|
||||
let scalar_bitor_tmp_buffer_size = clear_a.get_bitor_size_on_gpu(b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
bitor_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_bitor_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let bitxor_tmp_buffer_size = a.get_bitxor_size_on_gpu(b);
|
||||
let scalar_bitxor_tmp_buffer_size = clear_a.get_bitxor_size_on_gpu(b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
bitxor_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_bitxor_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let bitnot_tmp_buffer_size = a.get_bitnot_size_on_gpu();
|
||||
assert!(check_valid_cuda_malloc(
|
||||
bitnot_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
}
|
||||
|
||||
@@ -11,7 +11,10 @@ use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::FheUintId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::traits::{AddSizeOnGpu, SizeOnGpu, SubSizeOnGpu};
|
||||
use crate::high_level_api::traits::{
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SizeOnGpu,
|
||||
SubSizeOnGpu,
|
||||
};
|
||||
use crate::high_level_api::traits::{
|
||||
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
|
||||
RotateRightAssign,
|
||||
@@ -2380,3 +2383,99 @@ where
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id, I> BitAndSizeOnGpu<I> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
I: Borrow<Self>,
|
||||
{
|
||||
fn get_bitand_size_on_gpu(&self, rhs: I) -> u64 {
|
||||
let rhs = rhs.borrow();
|
||||
let mut tmp_buffer_size = 0;
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(_) => {
|
||||
tmp_buffer_size = 0;
|
||||
}
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
tmp_buffer_size = cuda_key.key.key.get_bitand_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
});
|
||||
tmp_buffer_size
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id, I> BitOrSizeOnGpu<I> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
I: Borrow<Self>,
|
||||
{
|
||||
fn get_bitor_size_on_gpu(&self, rhs: I) -> u64 {
|
||||
let rhs = rhs.borrow();
|
||||
let mut tmp_buffer_size = 0;
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(_) => {
|
||||
tmp_buffer_size = 0;
|
||||
}
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
tmp_buffer_size = cuda_key.key.key.get_bitor_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
});
|
||||
tmp_buffer_size
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id, I> BitXorSizeOnGpu<I> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
I: Borrow<Self>,
|
||||
{
|
||||
fn get_bitxor_size_on_gpu(&self, rhs: I) -> u64 {
|
||||
let rhs = rhs.borrow();
|
||||
let mut tmp_buffer_size = 0;
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(_) => {
|
||||
tmp_buffer_size = 0;
|
||||
}
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
tmp_buffer_size = cuda_key.key.key.get_bitxor_size_on_gpu(
|
||||
&*self.ciphertext.on_gpu(streams),
|
||||
&rhs.ciphertext.on_gpu(streams),
|
||||
streams,
|
||||
);
|
||||
}),
|
||||
});
|
||||
tmp_buffer_size
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Id> BitNotSizeOnGpu for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
{
|
||||
fn get_bitnot_size_on_gpu(&self) -> u64 {
|
||||
let mut tmp_buffer_size = 0;
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(_) => {
|
||||
tmp_buffer_size = 0;
|
||||
}
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
tmp_buffer_size = cuda_key
|
||||
.key
|
||||
.key
|
||||
.get_bitnot_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams);
|
||||
}),
|
||||
});
|
||||
tmp_buffer_size
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,9 @@ use crate::high_level_api::global_state::with_thread_local_cuda_streams;
|
||||
use crate::high_level_api::integers::FheUintId;
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::traits::{AddSizeOnGpu, SubSizeOnGpu};
|
||||
use crate::high_level_api::traits::{
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SubSizeOnGpu,
|
||||
};
|
||||
use crate::high_level_api::traits::{
|
||||
BitSlice, DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
|
||||
RotateRightAssign,
|
||||
@@ -1303,6 +1305,32 @@ macro_rules! define_scalar_ops {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
    rust_trait: BitAndSizeOnGpu(get_bitand_size_on_gpu),
    implem: {
        // The clear right-hand operand is ignored: only the ciphertext is handed
        // to the CUDA key's scalar bitand size query. A CPU key yields 0.
        |lhs: &FheUint<_>, _rhs| {
            global_state::with_internal_keys(|key| match key {
                InternalServerKey::Cpu(_) => 0,
                InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
                    let ct_on_gpu = lhs.ciphertext.on_gpu(streams);
                    cuda_key
                        .key
                        .key
                        .get_scalar_bitand_size_on_gpu(&*ct_on_gpu, streams)
                }),
            })
        }
    },
    fhe_and_scalar_type:
        $(
            ($concrete_type, $($scalar_type)*),
        )*
);
|
||||
|
||||
generic_integer_impl_scalar_operation!(
|
||||
rust_trait: BitOr(bitor),
|
||||
implem: {
|
||||
@@ -1336,6 +1364,32 @@ macro_rules! define_scalar_ops {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
    rust_trait: BitOrSizeOnGpu(get_bitor_size_on_gpu),
    implem: {
        // The clear right-hand operand is ignored: only the ciphertext is handed
        // to the CUDA key's scalar bitor size query. A CPU key yields 0.
        |lhs: &FheUint<_>, _rhs| {
            global_state::with_internal_keys(|key| match key {
                InternalServerKey::Cpu(_) => 0,
                InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
                    let ct_on_gpu = lhs.ciphertext.on_gpu(streams);
                    cuda_key
                        .key
                        .key
                        .get_scalar_bitor_size_on_gpu(&*ct_on_gpu, streams)
                }),
            })
        }
    },
    fhe_and_scalar_type:
        $(
            ($concrete_type, $($scalar_type)*),
        )*
);
|
||||
|
||||
generic_integer_impl_scalar_operation!(
|
||||
rust_trait: BitXor(bitxor),
|
||||
implem: {
|
||||
@@ -1370,6 +1424,32 @@ macro_rules! define_scalar_ops {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_operation_size_on_gpu!(
    rust_trait: BitXorSizeOnGpu(get_bitxor_size_on_gpu),
    implem: {
        // The clear right-hand operand is ignored: only the ciphertext is handed
        // to the CUDA key's scalar bitxor size query. A CPU key yields 0.
        |lhs: &FheUint<_>, _rhs| {
            global_state::with_internal_keys(|key| match key {
                InternalServerKey::Cpu(_) => 0,
                InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
                    let ct_on_gpu = lhs.ciphertext.on_gpu(streams);
                    cuda_key
                        .key
                        .key
                        .get_scalar_bitxor_size_on_gpu(&*ct_on_gpu, streams)
                }),
            })
        }
    },
    fhe_and_scalar_type:
        $(
            ($concrete_type, $($scalar_type)*),
        )*
);
|
||||
|
||||
generic_integer_impl_scalar_operation!(
|
||||
rust_trait: Div(div),
|
||||
implem: {
|
||||
@@ -1570,6 +1650,32 @@ macro_rules! define_scalar_ops {
|
||||
);
|
||||
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_left_operation_size_on_gpu!(
    rust_trait: BitAndSizeOnGpu(get_bitand_size_on_gpu),
    implem: {
        // Scalar-on-the-left form: the clear left operand is ignored and the
        // ciphertext on the right is handed to the CUDA key's scalar bitand
        // size query. A CPU key yields 0.
        |_lhs, rhs: &FheUint<_>| {
            global_state::with_internal_keys(|key| match key {
                InternalServerKey::Cpu(_) => 0,
                InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
                    let ct_on_gpu = rhs.ciphertext.on_gpu(streams);
                    cuda_key
                        .key
                        .key
                        .get_scalar_bitand_size_on_gpu(&*ct_on_gpu, streams)
                }),
            })
        }
    },
    fhe_and_scalar_type:
        $(
            ($concrete_type, $($scalar_type)*),
        )*
);
|
||||
|
||||
generic_integer_impl_scalar_left_operation!(
|
||||
rust_trait: BitOr(bitor),
|
||||
implem: {
|
||||
@@ -1585,6 +1691,31 @@ macro_rules! define_scalar_ops {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_left_operation_size_on_gpu!(
    rust_trait: BitOrSizeOnGpu(get_bitor_size_on_gpu),
    implem: {
        // Scalar-on-the-left form: the clear left operand is ignored and the
        // ciphertext on the right is handed to the CUDA key's scalar bitor
        // size query. A CPU key yields 0.
        |_lhs, rhs: &FheUint<_>| {
            global_state::with_internal_keys(|key| match key {
                InternalServerKey::Cpu(_) => 0,
                InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
                    let ct_on_gpu = rhs.ciphertext.on_gpu(streams);
                    cuda_key
                        .key
                        .key
                        .get_scalar_bitor_size_on_gpu(&*ct_on_gpu, streams)
                }),
            })
        }
    },
    fhe_and_scalar_type:
        $(
            ($concrete_type, $($scalar_type)*),
        )*
);
|
||||
|
||||
generic_integer_impl_scalar_left_operation!(
|
||||
rust_trait: BitXor(bitxor),
|
||||
@@ -1601,6 +1732,31 @@ macro_rules! define_scalar_ops {
|
||||
)*
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
generic_integer_impl_get_scalar_left_operation_size_on_gpu!(
    rust_trait: BitXorSizeOnGpu(get_bitxor_size_on_gpu),
    implem: {
        // Scalar-on-the-left form: the clear left operand is ignored and the
        // ciphertext on the right is handed to the CUDA key's scalar bitxor
        // size query. A CPU key yields 0.
        |_lhs, rhs: &FheUint<_>| {
            global_state::with_internal_keys(|key| match key {
                InternalServerKey::Cpu(_) => 0,
                InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
                    let ct_on_gpu = rhs.ciphertext.on_gpu(streams);
                    cuda_key
                        .key
                        .key
                        .get_scalar_bitxor_size_on_gpu(&*ct_on_gpu, streams)
                }),
            })
        }
    },
    fhe_and_scalar_type:
        $(
            ($concrete_type, $($scalar_type)*),
        )*
);
|
||||
|
||||
// Scalar Assign Ops
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use crate::high_level_api::traits::AddSizeOnGpu;
|
||||
use crate::prelude::{check_valid_cuda_malloc, FheTryEncrypt, SubSizeOnGpu};
|
||||
use crate::prelude::{
|
||||
check_valid_cuda_malloc, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
|
||||
FheTryEncrypt, SubSizeOnGpu,
|
||||
};
|
||||
use crate::shortint::parameters::{
|
||||
TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS,
|
||||
};
|
||||
@@ -201,3 +204,53 @@ fn test_gpu_get_add_and_sub_size_on_gpu() {
|
||||
assert_eq!(add_tmp_buffer_size, scalar_add_tmp_buffer_size);
|
||||
assert_eq!(add_tmp_buffer_size, scalar_sub_tmp_buffer_size);
|
||||
}
|
||||
#[test]
|
||||
fn test_gpu_get_bitops_size_on_gpu() {
|
||||
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
|
||||
let mut rng = rand::thread_rng();
|
||||
let clear_a = rng.gen_range(1..=u32::MAX);
|
||||
let clear_b = rng.gen_range(1..=u32::MAX);
|
||||
let mut a = FheUint32::try_encrypt(clear_a, &cks).unwrap();
|
||||
let mut b = FheUint32::try_encrypt(clear_b, &cks).unwrap();
|
||||
a.move_to_current_device();
|
||||
b.move_to_current_device();
|
||||
|
||||
let a = &a;
|
||||
let b = &b;
|
||||
|
||||
let bitand_tmp_buffer_size = a.get_bitand_size_on_gpu(b);
|
||||
let scalar_bitand_tmp_buffer_size = clear_a.get_bitand_size_on_gpu(b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
bitand_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_bitand_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let bitor_tmp_buffer_size = a.get_bitor_size_on_gpu(b);
|
||||
let scalar_bitor_tmp_buffer_size = clear_a.get_bitor_size_on_gpu(b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
bitor_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_bitor_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let bitxor_tmp_buffer_size = a.get_bitxor_size_on_gpu(b);
|
||||
let scalar_bitxor_tmp_buffer_size = clear_a.get_bitxor_size_on_gpu(b);
|
||||
assert!(check_valid_cuda_malloc(
|
||||
bitxor_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
assert!(check_valid_cuda_malloc(
|
||||
scalar_bitxor_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
let bitnot_tmp_buffer_size = a.get_bitnot_size_on_gpu();
|
||||
assert!(check_valid_cuda_malloc(
|
||||
bitnot_tmp_buffer_size,
|
||||
GpuIndex::new(0)
|
||||
));
|
||||
}
|
||||
|
||||
@@ -25,4 +25,7 @@ pub use crate::high_level_api::gpu_utils::*;
|
||||
#[cfg(feature = "strings")]
|
||||
pub use crate::high_level_api::strings::traits::*;
|
||||
#[cfg(feature = "gpu")]
|
||||
pub use crate::high_level_api::traits::{AddSizeOnGpu, SizeOnGpu, SubSizeOnGpu};
|
||||
pub use crate::high_level_api::traits::{
|
||||
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SizeOnGpu,
|
||||
SubSizeOnGpu,
|
||||
};
|
||||
|
||||
@@ -211,20 +211,6 @@ pub trait SquashNoise {
|
||||
fn squash_noise(&self) -> crate::Result<Self::Output>;
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Exposes `get_size_on_gpu`, a `u64` size query on the implementing value.
///
/// NOTE(review): the `Rhs` type parameter is not used by the method; the unit
/// of the returned size is presumably bytes — confirm with the implementors.
pub trait SizeOnGpu<Rhs = Self> {
    /// Returns the size as a `u64`.
    fn get_size_on_gpu(&self) -> u64;
}
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Exposes `get_add_size_on_gpu`, a `u64` size query taking one operand of
/// type `Rhs` (defaults to `Self`).
///
/// The unit of the returned size is presumably bytes of GPU memory needed by
/// an addition — confirm with the implementors.
pub trait AddSizeOnGpu<Rhs = Self> {
    /// Returns the size as a `u64` for `self` and `amount`.
    fn get_add_size_on_gpu(&self, amount: Rhs) -> u64;
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Exposes `get_sub_size_on_gpu`, a `u64` size query taking one operand of
/// type `Rhs` (defaults to `Self`).
///
/// The unit of the returned size is presumably bytes of GPU memory needed by
/// a subtraction — confirm with the implementors.
pub trait SubSizeOnGpu<Rhs = Self> {
    /// Returns the size as a `u64` for `self` and `amount`.
    fn get_sub_size_on_gpu(&self, amount: Rhs) -> u64;
}
|
||||
|
||||
/// Trait used to have a generic way of waiting Hw accelerator result
|
||||
pub trait FheWait {
|
||||
fn wait(&self);
|
||||
@@ -248,3 +234,34 @@ where
|
||||
src: HpuHandle<&Self>,
|
||||
) -> HpuHandle<Self>;
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Exposes `get_size_on_gpu`, a `u64` size query on the implementing value.
///
/// NOTE(review): the `Rhs` type parameter is not used by the method; the unit
/// of the returned size is presumably bytes — confirm with the implementors.
pub trait SizeOnGpu<Rhs = Self> {
    /// Returns the size as a `u64`.
    fn get_size_on_gpu(&self) -> u64;
}
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Exposes `get_add_size_on_gpu`, a `u64` size query taking one operand of
/// type `Rhs` (defaults to `Self`).
///
/// The unit of the returned size is presumably bytes of GPU memory needed by
/// an addition — confirm with the implementors.
pub trait AddSizeOnGpu<Rhs = Self> {
    /// Returns the size as a `u64` for `self` and `amount`.
    fn get_add_size_on_gpu(&self, amount: Rhs) -> u64;
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Exposes `get_sub_size_on_gpu`, a `u64` size query taking one operand of
/// type `Rhs` (defaults to `Self`).
///
/// The unit of the returned size is presumably bytes of GPU memory needed by
/// a subtraction — confirm with the implementors.
pub trait SubSizeOnGpu<Rhs = Self> {
    /// Returns the size as a `u64` for `self` and `amount`.
    fn get_sub_size_on_gpu(&self, amount: Rhs) -> u64;
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Exposes `get_bitand_size_on_gpu`, a `u64` size query for a bitwise AND
/// between `self` and an operand of type `Rhs` (defaults to `Self`).
///
/// The unit of the returned size is presumably bytes of GPU memory —
/// confirm with the implementors.
pub trait BitAndSizeOnGpu<Rhs = Self> {
    /// Returns the size as a `u64` for `self` and the right-hand operand `rhs`.
    fn get_bitand_size_on_gpu(&self, rhs: Rhs) -> u64;
}
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Exposes `get_bitor_size_on_gpu`, a `u64` size query for a bitwise OR
/// between `self` and an operand of type `Rhs` (defaults to `Self`).
///
/// The unit of the returned size is presumably bytes of GPU memory —
/// confirm with the implementors.
pub trait BitOrSizeOnGpu<Rhs = Self> {
    /// Returns the size as a `u64` for `self` and the right-hand operand `rhs`.
    fn get_bitor_size_on_gpu(&self, rhs: Rhs) -> u64;
}
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Exposes `get_bitxor_size_on_gpu`, a `u64` size query for a bitwise XOR
/// between `self` and an operand of type `Rhs` (defaults to `Self`).
///
/// The unit of the returned size is presumably bytes of GPU memory —
/// confirm with the implementors.
pub trait BitXorSizeOnGpu<Rhs = Self> {
    /// Returns the size as a `u64` for `self` and the right-hand operand `rhs`.
    fn get_bitxor_size_on_gpu(&self, rhs: Rhs) -> u64;
}
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Exposes `get_bitnot_size_on_gpu`, a `u64` size query for a bitwise NOT of
/// `self`.
///
/// The unit of the returned size is presumably bytes of GPU memory —
/// confirm with the implementors.
pub trait BitNotSizeOnGpu {
    /// Returns the size as a `u64`.
    fn get_bitnot_size_on_gpu(&self) -> u64;
}
|
||||
|
||||
@@ -934,6 +934,62 @@ pub unsafe fn unchecked_bitop_integer_radix_kb_assign_async<T: UnsignedInteger,
|
||||
update_noise_degree(radix_lwe_left, &cuda_ffi_radix_lwe_left);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn get_bitop_integer_radix_kb_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
op: BitOpType,
|
||||
num_blocks: u32,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
|
||||
) -> u64 {
|
||||
let allocate_ms_noise_array = noise_reduction_key.is_some();
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let size_tracker = unsafe {
|
||||
scratch_cuda_integer_radix_bitop_kb_64(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
op as u32,
|
||||
false,
|
||||
allocate_ms_noise_array,
|
||||
)
|
||||
};
|
||||
unsafe {
|
||||
cleanup_cuda_integer_bitop(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
);
|
||||
}
|
||||
size_tracker
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
@@ -1055,6 +1111,62 @@ pub unsafe fn unchecked_scalar_bitop_integer_radix_kb_assign_async<
|
||||
update_noise_degree(radix_lwe, &cuda_ffi_radix_lwe);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn get_scalar_bitop_integer_radix_kb_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
op: BitOpType,
|
||||
num_blocks: u32,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
noise_reduction_key: Option<&CudaModulusSwitchNoiseReductionKey>,
|
||||
) -> u64 {
|
||||
let allocate_ms_noise_array = noise_reduction_key.is_some();
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let size_tracker = unsafe {
|
||||
scratch_cuda_integer_radix_bitop_kb_64(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
op as u32,
|
||||
true,
|
||||
allocate_ms_noise_array,
|
||||
)
|
||||
};
|
||||
unsafe {
|
||||
cleanup_cuda_integer_bitop(
|
||||
streams.ptr.as_ptr(),
|
||||
streams.gpu_indexes_ptr(),
|
||||
streams.len() as u32,
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
);
|
||||
}
|
||||
size_tracker
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
|
||||
@@ -8,6 +8,7 @@ use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::{
|
||||
get_bitop_integer_radix_kb_size_on_gpu, get_full_propagate_assign_size_on_gpu,
|
||||
unchecked_bitop_integer_radix_kb_assign_async, BitOpType, CudaServerKey, PBSType,
|
||||
};
|
||||
|
||||
@@ -243,6 +244,118 @@ impl CudaServerKey {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_bitop_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
op: BitOpType,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
assert_eq!(
|
||||
ct_left.as_ref().d_blocks.lwe_dimension(),
|
||||
ct_right.as_ref().d_blocks.lwe_dimension()
|
||||
);
|
||||
assert_eq!(
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count(),
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count()
|
||||
);
|
||||
let full_prop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
let actual_full_prop_mem = match (
|
||||
ct_left.block_carries_are_empty(),
|
||||
ct_right.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => 0,
|
||||
(true, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
|
||||
(false, true) => full_prop_mem,
|
||||
(false, false) => self.get_ciphertext_size_on_gpu(ct_right) + full_prop_mem,
|
||||
};
|
||||
|
||||
let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let bitop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_bitop_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_bitop_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
actual_full_prop_mem.max(bitop_mem)
|
||||
}
|
||||
|
||||
pub fn unchecked_bitand_assign<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &mut T,
|
||||
@@ -775,4 +888,77 @@ impl CudaServerKey {
|
||||
}
|
||||
streams.synchronize();
|
||||
}
|
||||
pub fn get_bitand_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_bitop_size_on_gpu(ct_left, ct_right, BitOpType::And, streams)
|
||||
}
|
||||
pub fn get_bitor_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_bitop_size_on_gpu(ct_left, ct_right, BitOpType::Or, streams)
|
||||
}
|
||||
pub fn get_bitxor_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
self.get_bitop_size_on_gpu(ct_left, ct_right, BitOpType::Or, streams)
|
||||
}
|
||||
|
||||
pub fn get_bitnot_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
} else {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let bitnot_mem = (lwe_ciphertext_count.0 * size_of::<u64>()) as u64;
|
||||
full_prop_mem.max(bitnot_mem)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::{
|
||||
get_full_propagate_assign_size_on_gpu, get_scalar_bitop_integer_radix_kb_size_on_gpu,
|
||||
unchecked_scalar_bitop_integer_radix_kb_assign_async, BitOpType, CudaServerKey, PBSType,
|
||||
};
|
||||
|
||||
@@ -299,4 +300,124 @@ impl CudaServerKey {
|
||||
self.scalar_bitxor_assign(&mut result, rhs, streams);
|
||||
result
|
||||
}
|
||||
pub fn get_scalar_bitop_size_on_gpu<T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
op: BitOpType,
|
||||
streams: &CudaStreams,
|
||||
) -> u64
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
let full_prop_mem = if ct.block_carries_are_empty() {
|
||||
0
|
||||
} else {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_full_propagate_assign_size_on_gpu(
|
||||
streams,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
let clear_blocks_mem = (lwe_ciphertext_count.0 * size_of::<u64>()) as u64;
|
||||
|
||||
let scalar_bitop_mem = match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => get_scalar_bitop_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.d_ms_noise_reduction_key.as_ref(),
|
||||
),
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
get_scalar_bitop_integer_radix_kb_size_on_gpu(
|
||||
streams,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
full_prop_mem.max(scalar_bitop_mem + clear_blocks_mem)
|
||||
}
|
||||
|
||||
pub fn get_scalar_bitand_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
self.get_scalar_bitop_size_on_gpu(ct, BitOpType::ScalarAnd, streams)
|
||||
}
|
||||
pub fn get_scalar_bitor_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
self.get_scalar_bitop_size_on_gpu(ct, BitOpType::ScalarOr, streams)
|
||||
}
|
||||
pub fn get_scalar_bitxor_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
self.get_scalar_bitop_size_on_gpu(ct, BitOpType::ScalarXor, streams)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user