Compare commits

...

1 Commit

Author SHA1 Message Date
Beka Barbakadze
1b2eacf0ce feat(gpu): Implement signed division in cuda backend 2024-11-19 13:28:37 +04:00
10 changed files with 430 additions and 104 deletions

View File

@@ -357,16 +357,17 @@ void cleanup_cuda_integer_radix_scalar_mul(void *const *streams,
void scratch_cuda_integer_div_rem_radix_ciphertext_kb_64(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t big_lwe_dimension, uint32_t small_lwe_dimension, uint32_t ks_level,
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t grouping_factor, uint32_t num_blocks, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory);
bool is_signed, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, bool allocate_gpu_memory);
void cuda_integer_div_rem_radix_ciphertext_kb_64(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
void *quotient, void *remainder, void const *numerator, void const *divisor,
int8_t *mem_ptr, void *const *bsks, void *const *ksks,
bool is_signed, int8_t *mem_ptr, void *const *bsks, void *const *ksks,
uint32_t num_blocks_in_radix);
void cleanup_cuda_integer_div_rem(void *const *streams,

View File

@@ -2264,7 +2264,7 @@ template <typename Torus> struct int_comparison_buffer {
}
};
template <typename Torus> struct int_div_rem_memory {
template <typename Torus> struct unsigned_int_div_rem_memory {
int_radix_params params;
uint32_t active_gpu_count;
@@ -2501,9 +2501,10 @@ template <typename Torus> struct int_div_rem_memory {
}
}
int_div_rem_memory(cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, int_radix_params params,
uint32_t num_blocks, bool allocate_gpu_memory) {
unsigned_int_div_rem_memory(cudaStream_t const *streams,
uint32_t const *gpu_indexes, uint32_t gpu_count,
int_radix_params params, uint32_t num_blocks,
bool allocate_gpu_memory) {
active_gpu_count = get_active_gpu_count(2 * num_blocks, gpu_count);
this->params = params;
@@ -3060,4 +3061,174 @@ template <typename Torus> struct int_abs_buffer {
}
};
template <typename Torus> struct int_div_rem_memory {
int_radix_params params;
uint32_t active_gpu_count;
bool is_signed;
// memory objects for other operations
unsigned_int_div_rem_memory<Torus> *unsigned_mem;
int_abs_buffer<Torus> *abs_mem_1;
int_abs_buffer<Torus> *abs_mem_2;
int_sc_prop_memory<Torus> *scp_mem_1;
int_sc_prop_memory<Torus> *scp_mem_2;
int_cmux_buffer<Torus> *cmux_quotient_mem;
int_cmux_buffer<Torus> *cmux_remainder_mem;
// lookup tables
int_radix_lut<Torus> *compare_signed_bits_lut;
// sub streams
cudaStream_t *sub_streams_1;
cudaStream_t *sub_streams_2;
cudaStream_t *sub_streams_3;
// temporary device buffers
Torus *positive_numerator;
Torus *positive_divisor;
Torus *sign_bits_are_different;
Torus *negated_quotient;
Torus *negated_remainder;
int_div_rem_memory(cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, int_radix_params params,
bool is_signed, uint32_t num_blocks,
bool allocate_gpu_memory) {
this->active_gpu_count = get_active_gpu_count(2 * num_blocks, gpu_count);
this->params = params;
this->is_signed = is_signed;
unsigned_mem = new unsigned_int_div_rem_memory<Torus>(
streams, gpu_indexes, gpu_count, params, num_blocks,
allocate_gpu_memory);
if (is_signed) {
uint32_t big_lwe_size = params.big_lwe_dimension + 1;
Torus sign_bit_pos = 31 - __builtin_clz(params.message_modulus) - 1;
// init memory objects for other integer operations
abs_mem_1 =
new int_abs_buffer<Torus>(streams, gpu_indexes, gpu_count, params,
num_blocks, allocate_gpu_memory);
abs_mem_2 =
new int_abs_buffer<Torus>(streams, gpu_indexes, gpu_count, params,
num_blocks, allocate_gpu_memory);
scp_mem_1 =
new int_sc_prop_memory<Torus>(streams, gpu_indexes, gpu_count, params,
num_blocks, allocate_gpu_memory);
scp_mem_2 =
new int_sc_prop_memory<Torus>(streams, gpu_indexes, gpu_count, params,
num_blocks, allocate_gpu_memory);
std::function<uint64_t(uint64_t)> quotient_predicate_lut_f =
[](uint64_t x) -> uint64_t { return x == 1; };
std::function<uint64_t(uint64_t)> remainder_predicate_lut_f =
[sign_bit_pos](uint64_t x) -> uint64_t {
return (x >> sign_bit_pos) == 1;
};
cmux_quotient_mem = new int_cmux_buffer<Torus>(
streams, gpu_indexes, gpu_count, quotient_predicate_lut_f, params,
num_blocks, allocate_gpu_memory);
cmux_remainder_mem = new int_cmux_buffer<Torus>(
streams, gpu_indexes, gpu_count, remainder_predicate_lut_f, params,
num_blocks, allocate_gpu_memory);
// init temporary memory buffers
positive_numerator =
(Torus *)cuda_malloc_async(big_lwe_size * num_blocks * sizeof(Torus),
streams[0], gpu_indexes[0]);
positive_divisor =
(Torus *)cuda_malloc_async(big_lwe_size * num_blocks * sizeof(Torus),
streams[0], gpu_indexes[0]);
negated_quotient =
(Torus *)cuda_malloc_async(big_lwe_size * num_blocks * sizeof(Torus),
streams[0], gpu_indexes[0]);
negated_remainder =
(Torus *)cuda_malloc_async(big_lwe_size * num_blocks * sizeof(Torus),
streams[0], gpu_indexes[0]);
// init boolean temporary buffers
sign_bits_are_different = (Torus *)cuda_malloc_async(
big_lwe_size * sizeof(Torus), streams[0], gpu_indexes[0]);
// init sub streams
sub_streams_1 =
(cudaStream_t *)malloc(active_gpu_count * sizeof(cudaStream_t));
sub_streams_2 =
(cudaStream_t *)malloc(active_gpu_count * sizeof(cudaStream_t));
sub_streams_3 =
(cudaStream_t *)malloc(active_gpu_count * sizeof(cudaStream_t));
for (uint j = 0; j < active_gpu_count; j++) {
sub_streams_1[j] = cuda_create_stream(gpu_indexes[j]);
sub_streams_2[j] = cuda_create_stream(gpu_indexes[j]);
sub_streams_3[j] = cuda_create_stream(gpu_indexes[j]);
}
// init lookup tables
// to extract and compare signed bits
auto f_compare_extracted_signed_bits = [sign_bit_pos](Torus x,
Torus y) -> Torus {
Torus x_sign_bit = (x >> sign_bit_pos) & 1;
Torus y_sign_bit = (y >> sign_bit_pos) & 1;
return (Torus)(x_sign_bit != y_sign_bit);
};
compare_signed_bits_lut = new int_radix_lut<Torus>(
streams, gpu_indexes, gpu_count, params, 1, 1, true);
generate_device_accumulator_bivariate<Torus>(
streams[0], gpu_indexes[0],
compare_signed_bits_lut->get_lut(gpu_indexes[0], 0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, f_compare_extracted_signed_bits);
compare_signed_bits_lut->broadcast_lut(streams, gpu_indexes,
gpu_indexes[0]);
}
}
void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count) {
unsigned_mem->release(streams, gpu_indexes, gpu_count);
delete unsigned_mem;
if (is_signed) {
// release objects for other integer operations
abs_mem_1->release(streams, gpu_indexes, gpu_count);
abs_mem_2->release(streams, gpu_indexes, gpu_count);
scp_mem_1->release(streams, gpu_indexes, gpu_count);
scp_mem_2->release(streams, gpu_indexes, gpu_count);
cmux_quotient_mem->release(streams, gpu_indexes, gpu_count);
cmux_remainder_mem->release(streams, gpu_indexes, gpu_count);
delete abs_mem_1;
delete abs_mem_2;
delete scp_mem_1;
delete scp_mem_2;
delete cmux_quotient_mem;
delete cmux_remainder_mem;
// release lookup tables
compare_signed_bits_lut->release(streams, gpu_indexes, gpu_count);
delete compare_signed_bits_lut;
// release sub streams
for (uint i = 0; i < active_gpu_count; i++) {
cuda_destroy_stream(sub_streams_1[i], gpu_indexes[i]);
cuda_destroy_stream(sub_streams_2[i], gpu_indexes[i]);
cuda_destroy_stream(sub_streams_3[i], gpu_indexes[i]);
}
free(sub_streams_1);
free(sub_streams_2);
free(sub_streams_3);
// drop temporary buffers
cuda_drop_async(positive_numerator, streams[0], gpu_indexes[0]);
cuda_drop_async(positive_divisor, streams[0], gpu_indexes[0]);
cuda_drop_async(sign_bits_are_different, streams[0], gpu_indexes[0]);
cuda_drop_async(negated_quotient, streams[0], gpu_indexes[0]);
cuda_drop_async(negated_remainder, streams[0], gpu_indexes[0]);
}
}
};
#endif // CUDA_INTEGER_UTILITIES_H

View File

@@ -2,11 +2,12 @@
void scratch_cuda_integer_div_rem_radix_ciphertext_kb_64(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t big_lwe_dimension, uint32_t small_lwe_dimension, uint32_t ks_level,
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t grouping_factor, uint32_t num_blocks, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory) {
bool is_signed, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, bool allocate_gpu_memory) {
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
big_lwe_dimension, small_lwe_dimension, ks_level,
@@ -14,7 +15,7 @@ void scratch_cuda_integer_div_rem_radix_ciphertext_kb_64(
message_modulus, carry_modulus);
scratch_cuda_integer_div_rem_kb<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
(cudaStream_t *)(streams), gpu_indexes, gpu_count, is_signed,
(int_div_rem_memory<uint64_t> **)mem_ptr, num_blocks, params,
allocate_gpu_memory);
}
@@ -22,7 +23,7 @@ void scratch_cuda_integer_div_rem_radix_ciphertext_kb_64(
void cuda_integer_div_rem_radix_ciphertext_kb_64(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
void *quotient, void *remainder, void const *numerator, void const *divisor,
int8_t *mem_ptr, void *const *bsks, void *const *ksks,
bool is_signed, int8_t *mem_ptr, void *const *bsks, void *const *ksks,
uint32_t num_blocks) {
auto mem = (int_div_rem_memory<uint64_t> *)mem_ptr;
@@ -31,8 +32,8 @@ void cuda_integer_div_rem_radix_ciphertext_kb_64(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(quotient), static_cast<uint64_t *>(remainder),
static_cast<const uint64_t *>(numerator),
static_cast<const uint64_t *>(divisor), bsks, (uint64_t **)(ksks), mem,
num_blocks);
static_cast<const uint64_t *>(divisor), is_signed, bsks,
(uint64_t **)(ksks), mem, num_blocks);
}
void cleanup_cuda_integer_div_rem(void *const *streams,

View File

@@ -3,6 +3,7 @@
#include "crypto/keyswitch.cuh"
#include "device.h"
#include "integer/abs.cuh"
#include "integer/comparison.cuh"
#include "integer/integer.cuh"
#include "integer/integer_utilities.h"
@@ -161,22 +162,21 @@ template <typename Torus> struct lwe_ciphertext_list {
template <typename Torus>
__host__ void scratch_cuda_integer_div_rem_kb(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, int_div_rem_memory<Torus> **mem_ptr,
uint32_t gpu_count, bool is_signed, int_div_rem_memory<Torus> **mem_ptr,
uint32_t num_blocks, int_radix_params params, bool allocate_gpu_memory) {
*mem_ptr = new int_div_rem_memory<Torus>(
streams, gpu_indexes, gpu_count, params, num_blocks, allocate_gpu_memory);
*mem_ptr =
new int_div_rem_memory<Torus>(streams, gpu_indexes, gpu_count, params,
is_signed, num_blocks, allocate_gpu_memory);
}
template <typename Torus>
__host__ void host_integer_div_rem_kb(cudaStream_t const *streams,
uint32_t const *gpu_indexes,
uint32_t gpu_count, Torus *quotient,
Torus *remainder, Torus const *numerator,
Torus const *divisor, void *const *bsks,
uint64_t *const *ksks,
int_div_rem_memory<uint64_t> *mem_ptr,
uint32_t num_blocks) {
__host__ void host_unsigned_integer_div_rem_kb(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, Torus *quotient, Torus *remainder,
Torus const *numerator, Torus const *divisor, void *const *bsks,
uint64_t *const *ksks, unsigned_int_div_rem_memory<uint64_t> *mem_ptr,
uint32_t num_blocks) {
auto radix_params = mem_ptr->params;
@@ -594,4 +594,105 @@ __host__ void host_integer_div_rem_kb(cudaStream_t const *streams,
}
}
template <typename Torus>
__host__ void host_integer_div_rem_kb(cudaStream_t const *streams,
uint32_t const *gpu_indexes,
uint32_t gpu_count, Torus *quotient,
Torus *remainder, Torus const *numerator,
Torus const *divisor, bool is_signed,
void *const *bsks, uint64_t *const *ksks,
int_div_rem_memory<uint64_t> *int_mem_ptr,
uint32_t num_blocks) {
if (is_signed) {
auto radix_params = int_mem_ptr->params;
uint32_t big_lwe_size = radix_params.big_lwe_dimension + 1;
// temporary memory
lwe_ciphertext_list<Torus> positive_numerator(
int_mem_ptr->positive_numerator, radix_params, num_blocks);
lwe_ciphertext_list<Torus> positive_divisor(int_mem_ptr->positive_divisor,
radix_params, num_blocks);
positive_numerator.clone_from((Torus *)numerator, 0, num_blocks - 1,
streams[0], gpu_indexes[0]);
positive_divisor.clone_from((Torus *)divisor, 0, num_blocks - 1, streams[0],
gpu_indexes[0]);
for (uint j = 0; j < gpu_count; j++) {
cuda_synchronize_stream(streams[j], gpu_indexes[j]);
}
host_integer_abs_kb<Torus>(int_mem_ptr->sub_streams_1, gpu_indexes,
gpu_count, positive_numerator.data, bsks, ksks,
int_mem_ptr->abs_mem_1, true, num_blocks);
host_integer_abs_kb<Torus>(int_mem_ptr->sub_streams_2, gpu_indexes,
gpu_count, positive_divisor.data, bsks, ksks,
int_mem_ptr->abs_mem_2, true, num_blocks);
for (uint j = 0; j < int_mem_ptr->active_gpu_count; j++) {
cuda_synchronize_stream(int_mem_ptr->sub_streams_1[j], gpu_indexes[j]);
cuda_synchronize_stream(int_mem_ptr->sub_streams_2[j], gpu_indexes[j]);
}
host_unsigned_integer_div_rem_kb<Torus>(
int_mem_ptr->sub_streams_1, gpu_indexes, gpu_count, quotient, remainder,
positive_numerator.data, positive_divisor.data, bsks, ksks,
int_mem_ptr->unsigned_mem, num_blocks);
integer_radix_apply_bivariate_lookup_table_kb<Torus>(
int_mem_ptr->sub_streams_2, gpu_indexes, gpu_count,
int_mem_ptr->sign_bits_are_different,
&numerator[big_lwe_size * (num_blocks - 1)],
&divisor[big_lwe_size * (num_blocks - 1)], bsks, ksks, 1,
int_mem_ptr->compare_signed_bits_lut,
int_mem_ptr->compare_signed_bits_lut->params.message_modulus);
for (uint j = 0; j < int_mem_ptr->active_gpu_count; j++) {
cuda_synchronize_stream(int_mem_ptr->sub_streams_1[j], gpu_indexes[j]);
cuda_synchronize_stream(int_mem_ptr->sub_streams_2[j], gpu_indexes[j]);
}
host_integer_radix_negation(
int_mem_ptr->sub_streams_1, gpu_indexes, gpu_count,
int_mem_ptr->negated_quotient, quotient, radix_params.big_lwe_dimension,
num_blocks, radix_params.message_modulus, radix_params.carry_modulus);
host_propagate_single_carry<Torus>(int_mem_ptr->sub_streams_1, gpu_indexes,
gpu_count, int_mem_ptr->negated_quotient,
nullptr, nullptr, int_mem_ptr->scp_mem_1,
bsks, ksks, num_blocks);
host_integer_radix_negation(int_mem_ptr->sub_streams_2, gpu_indexes,
gpu_count, int_mem_ptr->negated_remainder,
remainder, radix_params.big_lwe_dimension,
num_blocks, radix_params.message_modulus,
radix_params.carry_modulus);
host_propagate_single_carry<Torus>(
int_mem_ptr->sub_streams_2, gpu_indexes, gpu_count,
int_mem_ptr->negated_remainder, nullptr, nullptr,
int_mem_ptr->scp_mem_2, bsks, ksks, num_blocks);
host_integer_radix_cmux_kb<Torus>(
int_mem_ptr->sub_streams_1, gpu_indexes, gpu_count, quotient,
int_mem_ptr->sign_bits_are_different, int_mem_ptr->negated_quotient,
quotient, int_mem_ptr->cmux_quotient_mem, bsks, ksks, num_blocks);
host_integer_radix_cmux_kb<Torus>(
int_mem_ptr->sub_streams_2, gpu_indexes, gpu_count, remainder,
&numerator[big_lwe_size * (num_blocks - 1)],
int_mem_ptr->negated_remainder, remainder,
int_mem_ptr->cmux_remainder_mem, bsks, ksks, num_blocks);
for (uint j = 0; j < int_mem_ptr->active_gpu_count; j++) {
cuda_synchronize_stream(int_mem_ptr->sub_streams_1[j], gpu_indexes[j]);
cuda_synchronize_stream(int_mem_ptr->sub_streams_2[j], gpu_indexes[j]);
}
} else {
host_unsigned_integer_div_rem_kb<Torus>(
streams, gpu_indexes, gpu_count, quotient, remainder, numerator,
divisor, bsks, ksks, int_mem_ptr->unsigned_mem, num_blocks);
}
}
#endif // TFHE_RS_DIV_REM_CUH

View File

@@ -896,6 +896,7 @@ extern "C" {
streams: *const *mut ffi::c_void,
gpu_indexes: *const u32,
gpu_count: u32,
is_signed: bool,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
@@ -922,6 +923,7 @@ extern "C" {
remainder: *mut ffi::c_void,
numerator: *const ffi::c_void,
divisor: *const ffi::c_void,
is_signed: bool,
mem_ptr: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,

View File

@@ -631,8 +631,8 @@ where
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner_result = cuda_key.key.key.div_rem(
&self.ciphertext.on_gpu(),
&rhs.ciphertext.on_gpu(),
&*self.ciphertext.on_gpu(),
&*rhs.ciphertext.on_gpu(),
streams,
);
(
@@ -977,7 +977,7 @@ generic_integer_impl_operation!(
cuda_key
.key
.key
.div(&lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams);
.div(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
}),
})
@@ -1028,7 +1028,7 @@ generic_integer_impl_operation!(
cuda_key
.key
.key
.rem(&lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams);
.rem(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams);
FheUint::new(inner_result, cuda_key.tag.clone())
}),
})

View File

@@ -2514,15 +2514,13 @@ pub unsafe fn apply_bivariate_lut_kb_async<T: UnsignedInteger, B: Numeric>(
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization
/// is required
pub unsafe fn unchecked_unsigned_div_rem_integer_radix_kb_assign_async<
T: UnsignedInteger,
B: Numeric,
>(
pub unsafe fn unchecked_div_rem_integer_radix_kb_assign_async<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
quotient: &mut CudaVec<T>,
remainder: &mut CudaVec<T>,
numerator: &CudaVec<T>,
divisor: &CudaVec<T>,
is_signed: bool,
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
message_modulus: MessageModulus,
@@ -2544,6 +2542,7 @@ pub unsafe fn unchecked_unsigned_div_rem_integer_radix_kb_assign_async<
streams.ptr.as_ptr(),
streams.gpu_indexes.as_ptr(),
streams.len() as u32,
is_signed,
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
@@ -2568,6 +2567,7 @@ pub unsafe fn unchecked_unsigned_div_rem_integer_radix_kb_assign_async<
remainder.as_mut_c_ptr(0),
numerator.as_c_ptr(0),
divisor.as_c_ptr(0),
is_signed,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),

View File

@@ -1,32 +1,35 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::{unchecked_unsigned_div_rem_integer_radix_kb_assign_async, PBSType};
use crate::integer::gpu::{unchecked_div_rem_integer_radix_kb_assign_async, PBSType};
impl CudaServerKey {
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unsigned_unchecked_div_rem_assign_async(
pub unsafe fn unchecked_div_rem_assign_async<T>(
&self,
quotient: &mut CudaUnsignedRadixCiphertext,
remainder: &mut CudaUnsignedRadixCiphertext,
numerator: &CudaUnsignedRadixCiphertext,
divisor: &CudaUnsignedRadixCiphertext,
quotient: &mut T,
remainder: &mut T,
numerator: &T,
divisor: &T,
streams: &CudaStreams,
) {
// TODO add asserts from `unsigned_unchecked_div_rem_parallelized`
) where
T: CudaIntegerRadixCiphertext,
{
// TODO add asserts from `unchecked_div_rem_parallelized`
let num_blocks = divisor.as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
unchecked_unsigned_div_rem_integer_radix_kb_assign_async(
unchecked_div_rem_integer_radix_kb_assign_async(
streams,
&mut quotient.as_mut().d_blocks.0.d_vec,
&mut remainder.as_mut().d_blocks.0.d_vec,
&numerator.as_ref().d_blocks.0.d_vec,
&divisor.as_ref().d_blocks.0.d_vec,
T::IS_SIGNED,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
@@ -49,12 +52,13 @@ impl CudaServerKey {
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
unchecked_unsigned_div_rem_integer_radix_kb_assign_async(
unchecked_div_rem_integer_radix_kb_assign_async(
streams,
&mut quotient.as_mut().d_blocks.0.d_vec,
&mut remainder.as_mut().d_blocks.0.d_vec,
&numerator.as_ref().d_blocks.0.d_vec,
&divisor.as_ref().d_blocks.0.d_vec,
T::IS_SIGNED,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
@@ -82,33 +86,31 @@ impl CudaServerKey {
remainder.as_mut().info = remainder.as_ref().info.after_div_rem();
}
pub fn unsigned_unchecked_div_rem_assign(
pub fn unchecked_div_rem_assign<T>(
&self,
quotient: &mut CudaUnsignedRadixCiphertext,
remainder: &mut CudaUnsignedRadixCiphertext,
numerator: &CudaUnsignedRadixCiphertext,
divisor: &CudaUnsignedRadixCiphertext,
quotient: &mut T,
remainder: &mut T,
numerator: &T,
divisor: &T,
streams: &CudaStreams,
) {
) where
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.unsigned_unchecked_div_rem_assign_async(
quotient, remainder, numerator, divisor, streams,
);
self.unchecked_div_rem_assign_async(quotient, remainder, numerator, divisor, streams);
}
streams.synchronize();
}
pub fn unchecked_div_rem(
&self,
numerator: &CudaUnsignedRadixCiphertext,
divisor: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext) {
pub fn unchecked_div_rem<T>(&self, numerator: &T, divisor: &T, streams: &CudaStreams) -> (T, T)
where
T: CudaIntegerRadixCiphertext,
{
let mut quotient = unsafe { numerator.duplicate_async(streams) };
let mut remainder = unsafe { numerator.duplicate_async(streams) };
unsafe {
self.unsigned_unchecked_div_rem_assign_async(
self.unchecked_div_rem_assign_async(
&mut quotient,
&mut remainder,
numerator,
@@ -120,12 +122,10 @@ impl CudaServerKey {
(quotient, remainder)
}
pub fn div_rem(
&self,
numerator: &CudaUnsignedRadixCiphertext,
divisor: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext) {
pub fn div_rem<T>(&self, numerator: &T, divisor: &T, streams: &CudaStreams) -> (T, T)
where
T: CudaIntegerRadixCiphertext,
{
let mut tmp_numerator;
let mut tmp_divisor;
@@ -158,14 +158,16 @@ impl CudaServerKey {
self.unchecked_div_rem(numerator, divisor, streams)
}
pub fn div_rem_assign(
pub fn div_rem_assign<T>(
&self,
quotient: &mut CudaUnsignedRadixCiphertext,
remainder: &mut CudaUnsignedRadixCiphertext,
numerator: &CudaUnsignedRadixCiphertext,
divisor: &CudaUnsignedRadixCiphertext,
quotient: &mut T,
remainder: &mut T,
numerator: &T,
divisor: &T,
streams: &CudaStreams,
) {
) where
T: CudaIntegerRadixCiphertext,
{
let mut tmp_numerator;
let mut tmp_divisor;
@@ -196,38 +198,30 @@ impl CudaServerKey {
};
unsafe {
self.unsigned_unchecked_div_rem_assign_async(
quotient, remainder, numerator, divisor, streams,
);
self.unchecked_div_rem_assign_async(quotient, remainder, numerator, divisor, streams);
}
streams.synchronize();
}
pub fn div(
&self,
numerator: &CudaUnsignedRadixCiphertext,
divisor: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext {
/// Homomorphic division: returns the quotient of `numerator / divisor`,
/// discarding the remainder computed alongside it.
pub fn div<T>(&self, numerator: &T, divisor: &T, streams: &CudaStreams) -> T
where
    T: CudaIntegerRadixCiphertext,
{
    self.div_rem(numerator, divisor, streams).0
}
pub fn rem(
&self,
numerator: &CudaUnsignedRadixCiphertext,
divisor: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext {
/// Homomorphic remainder: returns `numerator % divisor`, discarding the
/// quotient computed alongside it.
pub fn rem<T>(&self, numerator: &T, divisor: &T, streams: &CudaStreams) -> T
where
    T: CudaIntegerRadixCiphertext,
{
    self.div_rem(numerator, divisor, streams).1
}
pub fn div_assign(
&self,
numerator: &mut CudaUnsignedRadixCiphertext,
divisor: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) {
pub fn div_assign<T>(&self, numerator: &mut T, divisor: &T, streams: &CudaStreams)
where
T: CudaIntegerRadixCiphertext,
{
let mut remainder = numerator.duplicate(streams);
self.div_rem_assign(
numerator,
@@ -238,12 +232,10 @@ impl CudaServerKey {
);
}
pub fn rem_assign(
&self,
numerator: &mut CudaUnsignedRadixCiphertext,
divisor: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) {
pub fn rem_assign<T>(&self, numerator: &mut T, divisor: &T, streams: &CudaStreams)
where
T: CudaIntegerRadixCiphertext,
{
let mut quotient = numerator.duplicate(streams);
self.div_rem_assign(
&mut quotient,

View File

@@ -3,6 +3,7 @@ pub(crate) mod test_add;
pub(crate) mod test_bitwise_op;
pub(crate) mod test_cmux;
pub(crate) mod test_comparison;
pub(crate) mod test_div_mod;
pub(crate) mod test_ilog2;
pub(crate) mod test_mul;
pub(crate) mod test_neg;
@@ -523,3 +524,44 @@ where
)
}
}
// for signed div_rem
// Executor adapter for signed div_rem: uploads CPU signed radix ciphertexts
// to the GPU, runs the wrapped server-key function, and downloads both
// results back to CPU form.
impl<'a, F>
    FunctionExecutor<
        (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
        (SignedRadixCiphertext, SignedRadixCiphertext),
    > for GpuFunctionExecutor<F>
where
    F: Fn(
        &CudaServerKey,
        &CudaSignedRadixCiphertext,
        &CudaSignedRadixCiphertext,
        &CudaStreams,
    ) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext),
{
    fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
        self.setup_from_keys(cks, &sks);
    }

    fn execute(
        &mut self,
        input: (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
    ) -> (SignedRadixCiphertext, SignedRadixCiphertext) {
        let ctx = self
            .context
            .as_ref()
            .expect("setup was not properly called");
        let lhs =
            CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &ctx.streams);
        let rhs =
            CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &ctx.streams);
        let (q, r) = (self.func)(&ctx.sks, &lhs, &rhs, &ctx.streams);
        (
            q.to_signed_radix_ciphertext(&ctx.streams),
            r.to_signed_radix_ciphertext(&ctx.streams),
        )
    }
}

View File

@@ -0,0 +1,16 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parametrized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_signed::test_div_rem::signed_unchecked_div_rem_test;
use crate::shortint::parameters::*;
create_gpu_parametrized_test!(integer_signed_unchecked_div_rem);
// Runs the shared signed unchecked div_rem test suite against the GPU
// server key's `div_rem` implementation.
fn integer_signed_unchecked_div_rem<P>(param: P)
where
    P: Into<PBSParameters>,
{
    signed_unchecked_div_rem_test(param, GpuFunctionExecutor::new(&CudaServerKey::div_rem));
}