mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
feat(gpu): add a function to set a CudaLweList to 0
This commit is contained in:
@@ -220,4 +220,19 @@ impl<T: UnsignedInteger> CudaLweCiphertextList<T> {
|
||||
pub(crate) fn ciphertext_modulus(&self) -> CiphertextModulus<T> {
|
||||
self.0.ciphertext_modulus
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until `streams` is synchronized
|
||||
pub unsafe fn set_to_zero_async(&mut self, streams: &CudaStreams) {
|
||||
self.0.d_vec.memset_async(0u64, streams, 0);
|
||||
}
|
||||
|
||||
pub fn set_to_zero(&mut self, streams: &CudaStreams) {
|
||||
unsafe {
|
||||
self.set_to_zero_async(streams);
|
||||
streams.synchronize_one(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,16 +175,13 @@ impl<T: Numeric> CudaVec<T> {
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn memset_async(&mut self, value: T, streams: &CudaStreams, stream_index: u32)
|
||||
where
|
||||
T: Into<u64>,
|
||||
{
|
||||
pub unsafe fn memset_async(&mut self, value: u64, streams: &CudaStreams, stream_index: u32) {
|
||||
let size = self.len() * std::mem::size_of::<T>();
|
||||
// We check that self is not empty to avoid invalid pointers
|
||||
if size > 0 {
|
||||
cuda_memset_async(
|
||||
self.as_mut_c_ptr(stream_index),
|
||||
value.into(),
|
||||
value,
|
||||
size as u64,
|
||||
streams.ptr[stream_index as usize],
|
||||
streams.gpu_indexes[stream_index as usize].0,
|
||||
|
||||
Reference in New Issue
Block a user