fix(gpu): ensure single carry propagation returns carry

Beka Barbakadze authored on 2024-06-03 10:23:28 +04:00, committed by Agnès Leroy
parent dc0d72436d, commit 3e37759f5f
12 changed files with 38 additions and 15 deletions

@@ -242,9 +242,12 @@ void scratch_cuda_propagate_single_carry_kb_64_inplace(
     uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
     PBS_TYPE pbs_type, bool allocate_gpu_memory);
 
-void cuda_propagate_single_carry_kb_64_inplace(
-    void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, void *lwe_array,
-    int8_t *mem_ptr, void *bsk, void *ksk, uint32_t num_blocks);
+void cuda_propagate_single_carry_kb_64_inplace(void **streams,
+                                               uint32_t *gpu_indexes,
+                                               uint32_t gpu_count,
+                                               void *lwe_array, void *carry_out,
+                                               int8_t *mem_ptr, void *bsk,
+                                               void *ksk, uint32_t num_blocks);
 
 void cleanup_cuda_propagate_single_carry(void *stream, uint32_t gpu_index,
                                          int8_t **mem_ptr_void);

@@ -134,12 +134,15 @@ void scratch_cuda_propagate_single_carry_kb_64_inplace(
       allocate_gpu_memory);
 }
 
-void cuda_propagate_single_carry_kb_64_inplace(
-    void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, void *lwe_array,
-    int8_t *mem_ptr, void *bsk, void *ksk, uint32_t num_blocks) {
+void cuda_propagate_single_carry_kb_64_inplace(void **streams,
+                                               uint32_t *gpu_indexes,
+                                               uint32_t gpu_count,
+                                               void *lwe_array, void *carry_out,
+                                               int8_t *mem_ptr, void *bsk,
+                                               void *ksk, uint32_t num_blocks) {
   host_propagate_single_carry<uint64_t>(
       (cudaStream_t *)(streams), gpu_indexes, gpu_count,
-      static_cast<uint64_t *>(lwe_array),
+      static_cast<uint64_t *>(lwe_array), static_cast<uint64_t *>(carry_out),
      (int_sc_prop_memory<uint64_t> *)mem_ptr, bsk,
      static_cast<uint64_t *>(ksk), num_blocks);
 }

@@ -410,6 +410,7 @@ void scratch_cuda_propagate_single_carry_kb_inplace(
 template <typename Torus>
 void host_propagate_single_carry(cudaStream_t *streams, uint32_t *gpu_indexes,
                                  uint32_t gpu_count, Torus *lwe_array,
+                                 Torus *carry_out,
                                  int_sc_prop_memory<Torus> *mem, void *bsk,
                                  Torus *ksk, uint32_t num_blocks) {
   auto params = mem->params;
@@ -459,6 +460,10 @@ void host_propagate_single_carry(cudaStream_t *streams, uint32_t *gpu_indexes,
   host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count, step_output,
                                  generates_or_propagates, 1, num_blocks,
                                  big_lwe_size);
+  if (carry_out != nullptr) {
+    cuda_memcpy_async_gpu_to_gpu(carry_out, step_output, big_lwe_size_bytes,
+                                 streams[0], gpu_indexes[0]);
+  }
   cuda_memset_async(step_output, 0, big_lwe_size_bytes, streams[0],
                     gpu_indexes[0]);

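The new branch copies the outgoing carry block into carry_out before step_output is cleared; a null pointer keeps the old discard-the-carry behaviour. For intuition, a plain-integer sketch of what one propagation pass computes (the function name and shapes are illustrative, not the CUDA code, which does this homomorphically on LWE blocks):

// Plain-integer model of single carry propagation over radix blocks.
fn propagate_single_carry(blocks: &mut [u64], message_modulus: u64) -> u64 {
    let mut carry = 0;
    for block in blocks.iter_mut() {
        let v = *block + carry;
        *block = v % message_modulus; // message part stays in the block
        carry = v / message_modulus;  // at most one carry ripples onward
    }
    carry // previously dropped; now surfaced through carry_out
}

fn main() {
    // Two radix-4 blocks holding the un-propagated sum [3+3, 2+1]
    let mut blocks = [6, 3];
    let carry_out = propagate_single_carry(&mut blocks, 4);
    assert_eq!(blocks, [2, 0]);
    assert_eq!(carry_out, 1); // carry out of the most significant block
}
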
@@ -360,8 +360,8 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
       num_blocks);
 
   host_propagate_single_carry<Torus>(streams, gpu_indexes, gpu_count,
-                                     radix_lwe_out, mem_ptr->scp_mem, bsk, ksk,
-                                     num_blocks);
+                                     radix_lwe_out, nullptr, mem_ptr->scp_mem,
+                                     bsk, ksk, num_blocks);
 }
 
 template <typename Torus, typename STorus, class params>

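host_integer_sum_ciphertexts_vec_kb only needs the propagated sum, so it passes nullptr and the carry copy is skipped. A Rust-flavoured sketch of that optional-output convention (DeviceBuffer and carry_out_ptr are hypothetical stand-ins, not the crate's types):

use std::ffi::c_void;
use std::ptr;

// Hypothetical stand-in for the crate's device buffer type.
struct DeviceBuffer(Vec<u64>);

// Map an optional output buffer to the raw pointer the C API expects,
// where null means "do not write the carry out".
fn carry_out_ptr(carry_out: Option<&mut DeviceBuffer>) -> *mut c_void {
    match carry_out {
        Some(buf) => buf.0.as_mut_ptr() as *mut c_void,
        None => ptr::null_mut(), // mirrors the C++ nullptr in the sum kernel
    }
}

fn main() {
    let mut carry = DeviceBuffer(vec![0; 8]);
    assert!(!carry_out_ptr(Some(&mut carry)).is_null());
    assert!(carry_out_ptr(None).is_null());
}
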
@@ -1002,6 +1002,7 @@ extern "C" {
         gpu_indexes: *const u32,
         gpu_count: u32,
         radix_lwe: *mut c_void,
+        carry_out: *mut c_void,
         mem_ptr: *mut i8,
         bsk: *const c_void,
         ksk: *const c_void,

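Assembling the C declaration with this fragment, the updated binding has the shape below (a sketch; the Rust spelling of void **streams is assumed, everything else follows the hunks):

use std::ffi::c_void;

extern "C" {
    // Assembled from the C header change above; carry_out is the new output.
    fn cuda_propagate_single_carry_kb_64_inplace(
        streams: *const *mut c_void, // assumed mapping of `void **streams`
        gpu_indexes: *const u32,
        gpu_count: u32,
        radix_lwe: *mut c_void,
        carry_out: *mut c_void,
        mem_ptr: *mut i8,
        bsk: *const c_void,
        ksk: *const c_void,
        num_blocks: u32,
    );
}
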
@@ -925,6 +925,7 @@ pub unsafe fn full_propagate_assign_async<T: UnsignedInteger, B: Numeric>(
 pub unsafe fn propagate_single_carry_assign_async<T: UnsignedInteger, B: Numeric>(
     streams: &CudaStreams,
     radix_lwe_input: &mut CudaVec<T>,
+    carry_out: &mut CudaVec<T>,
     bootstrapping_key: &CudaVec<B>,
     keyswitch_key: &CudaVec<T>,
     lwe_dimension: LweDimension,
@@ -981,6 +982,7 @@ pub unsafe fn propagate_single_carry_assign_async<T: UnsignedInteger, B: Numeric
         streams.gpu_indexes.as_ptr(),
         streams.len() as u32,
         radix_lwe_input.as_mut_c_ptr(),
+        carry_out.as_mut_c_ptr(),
         mem_ptr,
         bootstrapping_key.as_c_ptr(),
         keyswitch_key.as_c_ptr(),

@@ -104,7 +104,7 @@ impl CudaServerKey {
             }
         };
         self.unchecked_add_assign_async(lhs, rhs, streams);
-        self.propagate_single_carry_assign_async(lhs, streams);
+        let _carry = self.propagate_single_carry_assign_async(lhs, streams);
     }
 
     pub fn add_assign<T: CudaIntegerRadixCiphertext>(

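add_assign_async only wants the sum, so it binds the returned carry to _carry. A caller that keeps the carry gets unsigned overflow detection essentially for free; continuing the plain-integer sketch (helper repeated so the example stands alone, still illustrative only):

fn propagate_single_carry(blocks: &mut [u64], message_modulus: u64) -> u64 {
    let mut carry = 0;
    for block in blocks.iter_mut() {
        let v = *block + carry;
        *block = v % message_modulus;
        carry = v / message_modulus;
    }
    carry
}

// The carry returned by the final propagation doubles as the overflow flag.
fn overflowing_add(lhs: &mut [u64], rhs: &[u64], message_modulus: u64) -> bool {
    for (l, r) in lhs.iter_mut().zip(rhs) {
        *l += r; // unchecked add: blocks may temporarily exceed the modulus
    }
    propagate_single_carry(lhs, message_modulus) != 0
}

fn main() {
    let mut a = [3, 3]; // 15 as two radix-4 blocks, least significant first
    let overflowed = overflowing_add(&mut a, &[1, 0], 4); // 15 + 1 = 16
    assert_eq!(a, [0, 0]);
    assert!(overflowed);
}
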
@@ -164,9 +164,11 @@ impl CudaServerKey {
         &self,
         ct: &mut T,
         streams: &CudaStreams,
-    ) where
+    ) -> T
+    where
         T: CudaIntegerRadixCiphertext,
     {
+        let mut carry_out: T = self.create_trivial_zero_radix(1, streams);
         let ciphertext = ct.as_mut();
         let num_blocks = ciphertext.d_blocks.lwe_ciphertext_count().0 as u32;
         match &self.bootstrapping_key {
@@ -174,6 +176,7 @@ impl CudaServerKey {
                 propagate_single_carry_assign_async(
                     streams,
                     &mut ciphertext.d_blocks.0.d_vec,
+                    &mut carry_out.as_mut().d_blocks.0.d_vec,
                     &d_bsk.d_vec,
                     &self.key_switching_key.d_vec,
                     d_bsk.input_lwe_dimension(),
@@ -194,6 +197,7 @@ impl CudaServerKey {
                 propagate_single_carry_assign_async(
                     streams,
                     &mut ciphertext.d_blocks.0.d_vec,
+                    &mut carry_out.as_mut().d_blocks.0.d_vec,
                     &d_multibit_bsk.d_vec,
                     &self.key_switching_key.d_vec,
                     d_multibit_bsk.input_lwe_dimension(),
@@ -215,6 +219,11 @@ impl CudaServerKey {
             b.degree = Degree::new(b.message_modulus.0 - 1);
             b.noise_level = NoiseLevel::NOMINAL;
         });
+        carry_out.as_mut().info.blocks.iter_mut().for_each(|b| {
+            b.degree = Degree::new(1);
+            b.noise_level = NoiseLevel::NOMINAL;
+        });
+        carry_out
     }
 
     /// # Safety

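The returned carry block is tagged Degree::new(1): even straight after an unchecked add, one propagation pass pushes at most a single carry out of the top block. A quick exhaustive check of that bound in the clear, assuming message_modulus = 4 (the usual 2_2 parameters):

fn main() {
    let message_modulus = 4u64;
    // After an unchecked add a block holds at most two messages plus one
    // incoming carry: (m - 1) + (m - 1) + 1 = 7 for m = 4.
    for lhs in 0..message_modulus {
        for rhs in 0..message_modulus {
            for incoming in 0..=1 {
                let block = lhs + rhs + incoming;
                assert!(block / message_modulus <= 1); // carry out is 0 or 1
            }
        }
    }
    println!("carry out of a block is always 0 or 1; Degree::new(1) is tight");
}
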
@@ -167,7 +167,7 @@ impl CudaServerKey {
         };
 
         self.unchecked_neg_assign_async(ct, stream);
-        self.propagate_single_carry_assign_async(ct, stream);
+        let _carry = self.propagate_single_carry_assign_async(ct, stream);
     }
 
     pub fn neg_assign<T: CudaIntegerRadixCiphertext>(&self, ctxt: &mut T, stream: &CudaStreams) {

@@ -178,7 +178,7 @@ impl CudaServerKey {
         };
 
         self.unchecked_scalar_add_assign_async(ct, scalar, streams);
-        self.propagate_single_carry_assign_async(ct, streams);
+        let _carry = self.propagate_single_carry_assign_async(ct, streams);
     }
 
     pub fn scalar_add_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)

@@ -150,7 +150,7 @@ impl CudaServerKey {
         };
 
         self.unchecked_scalar_sub_assign_async(ct, scalar, stream);
-        self.propagate_single_carry_assign_async(ct, stream);
+        let _carry = self.propagate_single_carry_assign_async(ct, stream);
     }
 
     pub fn scalar_sub_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, stream: &CudaStreams)

@@ -270,7 +270,7 @@ impl CudaServerKey {
         };
 
         self.unchecked_sub_assign_async(lhs, rhs, streams);
-        self.propagate_single_carry_assign_async(lhs, streams);
+        let _carry = self.propagate_single_carry_assign_async(lhs, streams);
     }
 
     pub fn unsigned_overflowing_sub(