mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
fix(gpu): ensure single carry propagation returns carry
This commit is contained in:
committed by
Agnès Leroy
parent
dc0d72436d
commit
3e37759f5f
@@ -242,9 +242,12 @@ void scratch_cuda_propagate_single_carry_kb_64_inplace(
|
||||
uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
|
||||
PBS_TYPE pbs_type, bool allocate_gpu_memory);
|
||||
|
||||
void cuda_propagate_single_carry_kb_64_inplace(
|
||||
void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, void *lwe_array,
|
||||
int8_t *mem_ptr, void *bsk, void *ksk, uint32_t num_blocks);
|
||||
void cuda_propagate_single_carry_kb_64_inplace(void **streams,
|
||||
uint32_t *gpu_indexes,
|
||||
uint32_t gpu_count,
|
||||
void *lwe_array, void *carry_out,
|
||||
int8_t *mem_ptr, void *bsk,
|
||||
void *ksk, uint32_t num_blocks);
|
||||
|
||||
void cleanup_cuda_propagate_single_carry(void *stream, uint32_t gpu_index,
|
||||
int8_t **mem_ptr_void);
|
||||
|
||||
@@ -134,12 +134,15 @@ void scratch_cuda_propagate_single_carry_kb_64_inplace(
|
||||
allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void cuda_propagate_single_carry_kb_64_inplace(
|
||||
void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, void *lwe_array,
|
||||
int8_t *mem_ptr, void *bsk, void *ksk, uint32_t num_blocks) {
|
||||
void cuda_propagate_single_carry_kb_64_inplace(void **streams,
|
||||
uint32_t *gpu_indexes,
|
||||
uint32_t gpu_count,
|
||||
void *lwe_array, void *carry_out,
|
||||
int8_t *mem_ptr, void *bsk,
|
||||
void *ksk, uint32_t num_blocks) {
|
||||
host_propagate_single_carry<uint64_t>(
|
||||
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
|
||||
static_cast<uint64_t *>(lwe_array),
|
||||
static_cast<uint64_t *>(lwe_array), static_cast<uint64_t *>(carry_out),
|
||||
(int_sc_prop_memory<uint64_t> *)mem_ptr, bsk,
|
||||
static_cast<uint64_t *>(ksk), num_blocks);
|
||||
}
|
||||
|
||||
@@ -410,6 +410,7 @@ void scratch_cuda_propagate_single_carry_kb_inplace(
|
||||
template <typename Torus>
|
||||
void host_propagate_single_carry(cudaStream_t *streams, uint32_t *gpu_indexes,
|
||||
uint32_t gpu_count, Torus *lwe_array,
|
||||
Torus *carry_out,
|
||||
int_sc_prop_memory<Torus> *mem, void *bsk,
|
||||
Torus *ksk, uint32_t num_blocks) {
|
||||
auto params = mem->params;
|
||||
@@ -459,6 +460,10 @@ void host_propagate_single_carry(cudaStream_t *streams, uint32_t *gpu_indexes,
|
||||
host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count, step_output,
|
||||
generates_or_propagates, 1, num_blocks,
|
||||
big_lwe_size);
|
||||
if (carry_out != nullptr) {
|
||||
cuda_memcpy_async_gpu_to_gpu(carry_out, step_output, big_lwe_size_bytes,
|
||||
streams[0], gpu_indexes[0]);
|
||||
}
|
||||
cuda_memset_async(step_output, 0, big_lwe_size_bytes, streams[0],
|
||||
gpu_indexes[0]);
|
||||
|
||||
|
||||
@@ -360,8 +360,8 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
|
||||
num_blocks);
|
||||
|
||||
host_propagate_single_carry<Torus>(streams, gpu_indexes, gpu_count,
|
||||
radix_lwe_out, mem_ptr->scp_mem, bsk, ksk,
|
||||
num_blocks);
|
||||
radix_lwe_out, nullptr, mem_ptr->scp_mem,
|
||||
bsk, ksk, num_blocks);
|
||||
}
|
||||
|
||||
template <typename Torus, typename STorus, class params>
|
||||
|
||||
@@ -1002,6 +1002,7 @@ extern "C" {
|
||||
gpu_indexes: *const u32,
|
||||
gpu_count: u32,
|
||||
radix_lwe: *mut c_void,
|
||||
carry_out: *mut c_void,
|
||||
mem_ptr: *mut i8,
|
||||
bsk: *const c_void,
|
||||
ksk: *const c_void,
|
||||
|
||||
@@ -925,6 +925,7 @@ pub unsafe fn full_propagate_assign_async<T: UnsignedInteger, B: Numeric>(
|
||||
pub unsafe fn propagate_single_carry_assign_async<T: UnsignedInteger, B: Numeric>(
|
||||
streams: &CudaStreams,
|
||||
radix_lwe_input: &mut CudaVec<T>,
|
||||
carry_out: &mut CudaVec<T>,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
lwe_dimension: LweDimension,
|
||||
@@ -981,6 +982,7 @@ pub unsafe fn propagate_single_carry_assign_async<T: UnsignedInteger, B: Numeric
|
||||
streams.gpu_indexes.as_ptr(),
|
||||
streams.len() as u32,
|
||||
radix_lwe_input.as_mut_c_ptr(),
|
||||
carry_out.as_mut_c_ptr(),
|
||||
mem_ptr,
|
||||
bootstrapping_key.as_c_ptr(),
|
||||
keyswitch_key.as_c_ptr(),
|
||||
|
||||
@@ -104,7 +104,7 @@ impl CudaServerKey {
|
||||
}
|
||||
};
|
||||
self.unchecked_add_assign_async(lhs, rhs, streams);
|
||||
self.propagate_single_carry_assign_async(lhs, streams);
|
||||
let _carry = self.propagate_single_carry_assign_async(lhs, streams);
|
||||
}
|
||||
|
||||
pub fn add_assign<T: CudaIntegerRadixCiphertext>(
|
||||
|
||||
@@ -164,9 +164,11 @@ impl CudaServerKey {
|
||||
&self,
|
||||
ct: &mut T,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut carry_out: T = self.create_trivial_zero_radix(1, streams);
|
||||
let ciphertext = ct.as_mut();
|
||||
let num_blocks = ciphertext.d_blocks.lwe_ciphertext_count().0 as u32;
|
||||
match &self.bootstrapping_key {
|
||||
@@ -174,6 +176,7 @@ impl CudaServerKey {
|
||||
propagate_single_carry_assign_async(
|
||||
streams,
|
||||
&mut ciphertext.d_blocks.0.d_vec,
|
||||
&mut carry_out.as_mut().d_blocks.0.d_vec,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
@@ -194,6 +197,7 @@ impl CudaServerKey {
|
||||
propagate_single_carry_assign_async(
|
||||
streams,
|
||||
&mut ciphertext.d_blocks.0.d_vec,
|
||||
&mut carry_out.as_mut().d_blocks.0.d_vec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
@@ -215,6 +219,11 @@ impl CudaServerKey {
|
||||
b.degree = Degree::new(b.message_modulus.0 - 1);
|
||||
b.noise_level = NoiseLevel::NOMINAL;
|
||||
});
|
||||
carry_out.as_mut().info.blocks.iter_mut().for_each(|b| {
|
||||
b.degree = Degree::new(1);
|
||||
b.noise_level = NoiseLevel::NOMINAL;
|
||||
});
|
||||
carry_out
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
|
||||
@@ -167,7 +167,7 @@ impl CudaServerKey {
|
||||
};
|
||||
|
||||
self.unchecked_neg_assign_async(ct, stream);
|
||||
self.propagate_single_carry_assign_async(ct, stream);
|
||||
let _carry = self.propagate_single_carry_assign_async(ct, stream);
|
||||
}
|
||||
|
||||
pub fn neg_assign<T: CudaIntegerRadixCiphertext>(&self, ctxt: &mut T, stream: &CudaStreams) {
|
||||
|
||||
@@ -178,7 +178,7 @@ impl CudaServerKey {
|
||||
};
|
||||
|
||||
self.unchecked_scalar_add_assign_async(ct, scalar, streams);
|
||||
self.propagate_single_carry_assign_async(ct, streams);
|
||||
let _carry = self.propagate_single_carry_assign_async(ct, streams);
|
||||
}
|
||||
|
||||
pub fn scalar_add_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)
|
||||
|
||||
@@ -150,7 +150,7 @@ impl CudaServerKey {
|
||||
};
|
||||
|
||||
self.unchecked_scalar_sub_assign_async(ct, scalar, stream);
|
||||
self.propagate_single_carry_assign_async(ct, stream);
|
||||
let _carry = self.propagate_single_carry_assign_async(ct, stream);
|
||||
}
|
||||
|
||||
pub fn scalar_sub_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, stream: &CudaStreams)
|
||||
|
||||
@@ -270,7 +270,7 @@ impl CudaServerKey {
|
||||
};
|
||||
|
||||
self.unchecked_sub_assign_async(lhs, rhs, streams);
|
||||
self.propagate_single_carry_assign_async(lhs, streams);
|
||||
let _carry = self.propagate_single_carry_assign_async(lhs, streams);
|
||||
}
|
||||
|
||||
pub fn unsigned_overflowing_sub(
|
||||
|
||||
Reference in New Issue
Block a user