feat(gpu): add a function to set a CudaLweList to 0

This commit is contained in:
Agnes Leroy
2025-02-28 11:41:27 +01:00
committed by Agnès Leroy
parent 95863e1e36
commit c1bf43eac1
2 changed files with 17 additions and 5 deletions

View File

@@ -220,4 +220,19 @@ impl<T: UnsignedInteger> CudaLweCiphertextList<T> {
pub(crate) fn ciphertext_modulus(&self) -> CiphertextModulus<T> {
self.0.ciphertext_modulus
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn set_to_zero_async(&mut self, streams: &CudaStreams) {
self.0.d_vec.memset_async(0u64, streams, 0);
}
pub fn set_to_zero(&mut self, streams: &CudaStreams) {
unsafe {
self.set_to_zero_async(streams);
streams.synchronize_one(0);
}
}
}

View File

@@ -175,16 +175,13 @@ impl<T: Numeric> CudaVec<T> {
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn memset_async(&mut self, value: T, streams: &CudaStreams, stream_index: u32)
where
T: Into<u64>,
{
pub unsafe fn memset_async(&mut self, value: u64, streams: &CudaStreams, stream_index: u32) {
let size = self.len() * std::mem::size_of::<T>();
// We check that self is not empty to avoid invalid pointers
if size > 0 {
cuda_memset_async(
self.as_mut_c_ptr(stream_index),
value.into(),
value,
size as u64,
streams.ptr[stream_index as usize],
streams.gpu_indexes[stream_index as usize].0,