mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 06:13:58 -05:00
refactor(gpu): moving cast_to_signed to the backend
This commit is contained in:
committed by
Agnès Leroy
parent
d5a3275a5a
commit
cc33161b97
@@ -127,3 +127,38 @@ template <typename Torus> struct int_cast_to_unsigned_buffer {
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Torus> struct int_cast_to_signed_buffer {
|
||||
int_radix_params params;
|
||||
bool allocate_gpu_memory;
|
||||
uint32_t num_input_blocks;
|
||||
uint32_t target_num_blocks;
|
||||
|
||||
int_extend_radix_with_sign_msb_buffer<Torus> *extend_buffer;
|
||||
|
||||
int_cast_to_signed_buffer(CudaStreams streams, int_radix_params params,
|
||||
uint32_t num_input_blocks,
|
||||
uint32_t target_num_blocks, bool input_is_signed,
|
||||
bool allocate_gpu_memory, uint64_t &size_tracker) {
|
||||
this->params = params;
|
||||
this->allocate_gpu_memory = allocate_gpu_memory;
|
||||
this->num_input_blocks = num_input_blocks;
|
||||
this->target_num_blocks = target_num_blocks;
|
||||
this->extend_buffer = nullptr;
|
||||
|
||||
if (input_is_signed && target_num_blocks > num_input_blocks) {
|
||||
uint32_t num_additional_blocks = target_num_blocks - num_input_blocks;
|
||||
this->extend_buffer = new int_extend_radix_with_sign_msb_buffer<Torus>(
|
||||
streams, params, num_input_blocks, num_additional_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
}
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
if (this->extend_buffer) {
|
||||
this->extend_buffer->release(streams);
|
||||
delete this->extend_buffer;
|
||||
}
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -610,25 +610,6 @@ void cuda_integer_unsigned_scalar_div_radix_64(
|
||||
void cleanup_cuda_integer_unsigned_scalar_div_radix_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void);
|
||||
|
||||
uint64_t scratch_cuda_extend_radix_with_sign_msb_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
|
||||
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
|
||||
uint32_t grouping_factor, uint32_t num_blocks,
|
||||
uint32_t num_additional_blocks, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
void cuda_extend_radix_with_sign_msb_64(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI const *input,
|
||||
int8_t *mem_ptr,
|
||||
uint32_t num_additional_blocks,
|
||||
void *const *bsks, void *const *ksks);
|
||||
|
||||
void cleanup_cuda_extend_radix_with_sign_msb_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void);
|
||||
|
||||
uint64_t scratch_cuda_integer_signed_scalar_div_radix_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
|
||||
@@ -1004,6 +985,24 @@ void cuda_unchecked_index_of_clear_64(
|
||||
|
||||
void cleanup_cuda_unchecked_index_of_clear_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void);
|
||||
|
||||
uint64_t scratch_cuda_cast_to_signed_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
|
||||
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
|
||||
uint32_t grouping_factor, uint32_t num_input_blocks,
|
||||
uint32_t target_num_blocks, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, bool input_is_signed,
|
||||
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
void cuda_cast_to_signed_64(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI const *input, int8_t *mem,
|
||||
bool input_is_signed, void *const *bsks,
|
||||
void *const *ksks);
|
||||
|
||||
void cleanup_cuda_cast_to_signed_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void);
|
||||
} // extern C
|
||||
|
||||
#endif // CUDA_INTEGER_H
|
||||
|
||||
@@ -27,53 +27,6 @@ void trim_radix_blocks_msb_64(CudaRadixCiphertextFFI *output,
|
||||
cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_extend_radix_with_sign_msb_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
|
||||
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
|
||||
uint32_t grouping_factor, uint32_t num_blocks,
|
||||
uint32_t num_additional_blocks, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type) {
|
||||
|
||||
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
|
||||
glwe_dimension * polynomial_size, lwe_dimension,
|
||||
ks_level, ks_base_log, pbs_level, pbs_base_log,
|
||||
grouping_factor, message_modulus, carry_modulus,
|
||||
noise_reduction_type);
|
||||
|
||||
return scratch_extend_radix_with_sign_msb<uint64_t>(
|
||||
CudaStreams(streams),
|
||||
(int_extend_radix_with_sign_msb_buffer<uint64_t> **)mem_ptr, params,
|
||||
num_blocks, num_additional_blocks, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void cuda_extend_radix_with_sign_msb_64(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI const *input,
|
||||
int8_t *mem_ptr,
|
||||
uint32_t num_additional_blocks,
|
||||
void *const *bsks, void *const *ksks) {
|
||||
PUSH_RANGE("cast")
|
||||
host_extend_radix_with_sign_msb<uint64_t>(
|
||||
CudaStreams(streams), output, input,
|
||||
(int_extend_radix_with_sign_msb_buffer<uint64_t> *)mem_ptr,
|
||||
num_additional_blocks, bsks, (uint64_t **)ksks);
|
||||
POP_RANGE()
|
||||
}
|
||||
|
||||
void cleanup_cuda_extend_radix_with_sign_msb_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void) {
|
||||
PUSH_RANGE("clean cast")
|
||||
int_extend_radix_with_sign_msb_buffer<uint64_t> *mem_ptr =
|
||||
(int_extend_radix_with_sign_msb_buffer<uint64_t> *)(*mem_ptr_void);
|
||||
|
||||
mem_ptr->release(CudaStreams(streams));
|
||||
POP_RANGE()
|
||||
delete mem_ptr;
|
||||
*mem_ptr_void = nullptr;
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_cast_to_unsigned_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
@@ -116,3 +69,46 @@ void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
|
||||
delete mem_ptr;
|
||||
*mem_ptr_void = nullptr;
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_cast_to_signed_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
|
||||
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
|
||||
uint32_t grouping_factor, uint32_t num_input_blocks,
|
||||
uint32_t target_num_blocks, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, bool input_is_signed,
|
||||
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) {
|
||||
|
||||
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
|
||||
glwe_dimension * polynomial_size, lwe_dimension,
|
||||
ks_level, ks_base_log, pbs_level, pbs_base_log,
|
||||
grouping_factor, message_modulus, carry_modulus,
|
||||
noise_reduction_type);
|
||||
|
||||
return scratch_cuda_cast_to_signed<uint64_t>(
|
||||
CudaStreams(streams), (int_cast_to_signed_buffer<uint64_t> **)mem_ptr,
|
||||
params, num_input_blocks, target_num_blocks, input_is_signed,
|
||||
allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void cuda_cast_to_signed_64(CudaStreamsFFI streams,
|
||||
CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI const *input, int8_t *mem,
|
||||
bool input_is_signed, void *const *bsks,
|
||||
void *const *ksks) {
|
||||
|
||||
host_cast_to_signed<uint64_t>(CudaStreams(streams), output, input,
|
||||
(int_cast_to_signed_buffer<uint64_t> *)mem,
|
||||
input_is_signed, bsks, (uint64_t **)ksks);
|
||||
}
|
||||
|
||||
void cleanup_cuda_cast_to_signed_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void) {
|
||||
int_cast_to_signed_buffer<uint64_t> *mem_ptr =
|
||||
(int_cast_to_signed_buffer<uint64_t> *)(*mem_ptr_void);
|
||||
|
||||
mem_ptr->release(CudaStreams(streams));
|
||||
|
||||
delete mem_ptr;
|
||||
*mem_ptr_void = nullptr;
|
||||
}
|
||||
|
||||
@@ -160,4 +160,49 @@ host_cast_to_unsigned(CudaStreams streams, CudaRadixCiphertextFFI *output,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
uint64_t
|
||||
scratch_cuda_cast_to_signed(CudaStreams streams,
|
||||
int_cast_to_signed_buffer<Torus> **mem_ptr,
|
||||
int_radix_params params, uint32_t num_input_blocks,
|
||||
uint32_t target_num_blocks, bool input_is_signed,
|
||||
bool allocate_gpu_memory) {
|
||||
|
||||
uint64_t size_tracker = 0;
|
||||
*mem_ptr = new int_cast_to_signed_buffer<Torus>(
|
||||
streams, params, num_input_blocks, target_num_blocks, input_is_signed,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__host__ void
|
||||
host_cast_to_signed(CudaStreams streams, CudaRadixCiphertextFFI *output,
|
||||
CudaRadixCiphertextFFI const *input,
|
||||
int_cast_to_signed_buffer<Torus> *mem_ptr,
|
||||
bool input_is_signed, void *const *bsks, Torus **ksks) {
|
||||
|
||||
uint32_t current_num_blocks = input->num_radix_blocks;
|
||||
uint32_t target_num_blocks = mem_ptr->target_num_blocks;
|
||||
|
||||
if (input_is_signed) {
|
||||
if (target_num_blocks > current_num_blocks) {
|
||||
uint32_t num_blocks_to_add = target_num_blocks - current_num_blocks;
|
||||
host_extend_radix_with_sign_msb<Torus>(streams, output, input,
|
||||
mem_ptr->extend_buffer,
|
||||
num_blocks_to_add, bsks, ksks);
|
||||
} else {
|
||||
host_trim_radix_blocks_msb<Torus>(output, input, streams);
|
||||
}
|
||||
} else {
|
||||
if (target_num_blocks > current_num_blocks) {
|
||||
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(output, input,
|
||||
streams);
|
||||
} else {
|
||||
host_trim_radix_blocks_msb<Torus>(output, input, streams);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1353,44 +1353,6 @@ unsafe extern "C" {
|
||||
mem_ptr_void: *mut *mut i8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_extend_radix_with_sign_msb_64(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr: *mut *mut i8,
|
||||
glwe_dimension: u32,
|
||||
polynomial_size: u32,
|
||||
lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
pbs_level: u32,
|
||||
pbs_base_log: u32,
|
||||
grouping_factor: u32,
|
||||
num_blocks: u32,
|
||||
num_additional_blocks: u32,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
pbs_type: PBS_TYPE,
|
||||
allocate_gpu_memory: bool,
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_extend_radix_with_sign_msb_64(
|
||||
streams: CudaStreamsFFI,
|
||||
output: *mut CudaRadixCiphertextFFI,
|
||||
input: *const CudaRadixCiphertextFFI,
|
||||
mem_ptr: *mut i8,
|
||||
num_additional_blocks: u32,
|
||||
bsks: *const *mut ffi::c_void,
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_extend_radix_with_sign_msb_64(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr_void: *mut *mut i8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_integer_signed_scalar_div_radix_64(
|
||||
streams: CudaStreamsFFI,
|
||||
@@ -2186,6 +2148,42 @@ unsafe extern "C" {
|
||||
mem_ptr_void: *mut *mut i8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_cast_to_signed_64(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr: *mut *mut i8,
|
||||
glwe_dimension: u32,
|
||||
polynomial_size: u32,
|
||||
lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
pbs_level: u32,
|
||||
pbs_base_log: u32,
|
||||
grouping_factor: u32,
|
||||
num_input_blocks: u32,
|
||||
target_num_blocks: u32,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
pbs_type: PBS_TYPE,
|
||||
input_is_signed: bool,
|
||||
allocate_gpu_memory: bool,
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_cast_to_signed_64(
|
||||
streams: CudaStreamsFFI,
|
||||
output: *mut CudaRadixCiphertextFFI,
|
||||
input: *const CudaRadixCiphertextFFI,
|
||||
mem: *mut i8,
|
||||
input_is_signed: bool,
|
||||
bsks: *const *mut ffi::c_void,
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_cast_to_signed_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_integer_compress_radix_ciphertext_64(
|
||||
streams: CudaStreamsFFI,
|
||||
|
||||
@@ -5910,82 +5910,6 @@ pub(crate) unsafe fn cuda_backend_unchecked_partial_sum_ciphertexts_assign<
|
||||
update_noise_degree(result, &cuda_ffi_result);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_extend_radix_with_sign_msb<T: UnsignedInteger, B: Numeric>(
|
||||
streams: &CudaStreams,
|
||||
output: &mut CudaRadixCiphertext,
|
||||
ct: &CudaRadixCiphertext,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
num_additional_blocks: u32,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
|
||||
) {
|
||||
let message_modulus = ct.info.blocks.first().unwrap().message_modulus;
|
||||
let carry_modulus = ct.info.blocks.first().unwrap().carry_modulus;
|
||||
|
||||
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
|
||||
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
|
||||
let mut input_degrees = ct.info.blocks.iter().map(|b| b.degree.0).collect();
|
||||
let mut input_noise_levels = ct.info.blocks.iter().map(|b| b.noise_level.0).collect();
|
||||
let cuda_ffi_radix_input =
|
||||
prepare_cuda_radix_ffi(ct, &mut input_degrees, &mut input_noise_levels);
|
||||
|
||||
let mut output_degrees = output.info.blocks.iter().map(|b| b.degree.0).collect();
|
||||
let mut output_noise_levels = output.info.blocks.iter().map(|b| b.noise_level.0).collect();
|
||||
let mut cuda_ffi_radix_output =
|
||||
prepare_cuda_radix_ffi(output, &mut output_degrees, &mut output_noise_levels);
|
||||
|
||||
scratch_cuda_extend_radix_with_sign_msb_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
1u32,
|
||||
num_additional_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
|
||||
cuda_extend_radix_with_sign_msb_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_radix_output,
|
||||
&raw const cuda_ffi_radix_input,
|
||||
mem_ptr,
|
||||
num_additional_blocks,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
|
||||
cleanup_cuda_extend_radix_with_sign_msb_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
|
||||
update_noise_degree(output, &cuda_ffi_radix_output);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
@@ -10184,3 +10108,84 @@ pub(crate) unsafe fn cuda_backend_unchecked_index_of_clear<
|
||||
update_noise_degree(index_ct, &ffi_index);
|
||||
update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_cast_to_signed<T: UnsignedInteger, B: Numeric>(
|
||||
streams: &CudaStreams,
|
||||
output: &mut CudaRadixCiphertext,
|
||||
input: &CudaRadixCiphertext,
|
||||
input_is_signed: bool,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
|
||||
) {
|
||||
assert_eq!(streams.gpu_indexes[0], bootstrapping_key.gpu_index(0));
|
||||
assert_eq!(streams.gpu_indexes[0], keyswitch_key.gpu_index(0));
|
||||
|
||||
let num_input_blocks = input.d_blocks.lwe_ciphertext_count().0 as u32;
|
||||
let target_num_blocks = output.d_blocks.lwe_ciphertext_count().0 as u32;
|
||||
|
||||
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
|
||||
|
||||
let mut input_degrees = input.info.blocks.iter().map(|b| b.degree.0).collect();
|
||||
let mut input_noise_levels = input.info.blocks.iter().map(|b| b.noise_level.0).collect();
|
||||
let cuda_ffi_input = prepare_cuda_radix_ffi(input, &mut input_degrees, &mut input_noise_levels);
|
||||
|
||||
let mut output_degrees = output.info.blocks.iter().map(|b| b.degree.0).collect();
|
||||
let mut output_noise_levels = output.info.blocks.iter().map(|b| b.noise_level.0).collect();
|
||||
let mut cuda_ffi_output =
|
||||
prepare_cuda_radix_ffi(output, &mut output_degrees, &mut output_noise_levels);
|
||||
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
|
||||
scratch_cuda_cast_to_signed_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_input_blocks,
|
||||
target_num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
input_is_signed,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
|
||||
cuda_cast_to_signed_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_output,
|
||||
&raw const cuda_ffi_input,
|
||||
mem_ptr,
|
||||
input_is_signed,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
|
||||
cleanup_cuda_cast_to_signed_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
|
||||
update_noise_degree(output, &cuda_ffi_output);
|
||||
}
|
||||
|
||||
@@ -17,8 +17,7 @@ use crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey;
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_apply_bivariate_lut, cuda_backend_apply_many_univariate_lut,
|
||||
cuda_backend_apply_univariate_lut, cuda_backend_cast_to_unsigned,
|
||||
cuda_backend_extend_radix_with_sign_msb,
|
||||
cuda_backend_apply_univariate_lut, cuda_backend_cast_to_signed, cuda_backend_cast_to_unsigned,
|
||||
cuda_backend_extend_radix_with_trivial_zero_blocks_msb, cuda_backend_full_propagate_assign,
|
||||
cuda_backend_noise_squashing, cuda_backend_propagate_single_carry_assign,
|
||||
cuda_backend_trim_radix_blocks_lsb, cuda_backend_trim_radix_blocks_msb, CudaServerKey, PBSType,
|
||||
@@ -1094,68 +1093,6 @@ impl CudaServerKey {
|
||||
ciphertexts
|
||||
}
|
||||
|
||||
pub(crate) fn extend_radix_with_sign_msb<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct: &T,
|
||||
num_additional_blocks: usize,
|
||||
streams: &CudaStreams,
|
||||
) -> T {
|
||||
let num_ct_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0;
|
||||
let new_num_ct_blocks = num_ct_blocks + num_additional_blocks;
|
||||
|
||||
let mut output: T = self.create_trivial_zero_radix(new_num_ct_blocks, streams);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_extend_radix_with_sign_msb(
|
||||
streams,
|
||||
output.as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
num_additional_blocks as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_extend_radix_with_sign_msb(
|
||||
streams,
|
||||
output.as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
num_additional_blocks as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Cast a [`CudaUnsignedRadixCiphertext`] or a [`CudaSignedRadixCiphertext`]
|
||||
/// to a [`CudaUnsignedRadixCiphertext`] with a possibly different number of blocks
|
||||
///
|
||||
@@ -1313,45 +1250,63 @@ impl CudaServerKey {
|
||||
self.full_propagate_assign(&mut source, streams);
|
||||
}
|
||||
|
||||
let current_num_blocks = source.as_ref().info.blocks.len();
|
||||
let mut output_ct: CudaSignedRadixCiphertext =
|
||||
self.create_trivial_zero_radix(target_num_blocks, streams);
|
||||
|
||||
if T::IS_SIGNED {
|
||||
// Casting from signed to signed
|
||||
if target_num_blocks > current_num_blocks {
|
||||
let num_blocks_to_add = target_num_blocks - current_num_blocks;
|
||||
let unsigned_res: T =
|
||||
self.extend_radix_with_sign_msb(&source, num_blocks_to_add, streams);
|
||||
<CudaSignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
|
||||
unsigned_res.into_inner(),
|
||||
)
|
||||
} else {
|
||||
let num_blocks_to_remove = current_num_blocks - target_num_blocks;
|
||||
let unsigned_res =
|
||||
self.trim_radix_blocks_msb(&source, num_blocks_to_remove, streams);
|
||||
<CudaSignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
|
||||
unsigned_res.into_inner(),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// casting from unsigned to signed
|
||||
if target_num_blocks > current_num_blocks {
|
||||
let num_blocks_to_add = target_num_blocks - current_num_blocks;
|
||||
let signed_res = self.extend_radix_with_trivial_zero_blocks_msb(
|
||||
&source,
|
||||
num_blocks_to_add,
|
||||
streams,
|
||||
);
|
||||
<CudaSignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
|
||||
signed_res.into_inner(),
|
||||
)
|
||||
} else {
|
||||
let num_blocks_to_remove = current_num_blocks - target_num_blocks;
|
||||
let signed_res = self.trim_radix_blocks_msb(&source, num_blocks_to_remove, streams);
|
||||
<CudaSignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
|
||||
signed_res.into_inner(),
|
||||
)
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_cast_to_signed(
|
||||
streams,
|
||||
output_ct.as_mut(),
|
||||
source.as_ref(),
|
||||
T::IS_SIGNED,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_cast_to_signed(
|
||||
streams,
|
||||
output_ct.as_mut(),
|
||||
source.as_ref(),
|
||||
T::IS_SIGNED,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output_ct
|
||||
}
|
||||
/// Returns the memory space occupied by a radix ciphertext on GPU
|
||||
pub fn get_ciphertext_size_on_gpu<T: CudaIntegerRadixCiphertext>(&self, ct: &T) -> u64 {
|
||||
|
||||
Reference in New Issue
Block a user