feat(gpu): add necessary entry points for 128-bit compression

Agnes Leroy
2025-03-19 11:15:21 +01:00
committed by Agnès Leroy
parent d9a3bd438f
commit 7e3a5fd55b
15 changed files with 316 additions and 50 deletions
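
The gist: 128-bit GLWE sample extraction and 128-bit LWE-to-GLWE packing keyswitch are wired end to end (CUDA entry points, C headers, bindgen declarations, Rust wrappers), and the existing 64-bit Rust entry points gain an explicit _64 suffix so the _128 variants can sit alongside them. As a rough sketch of the resulting user-facing call, assuming the usual tfhe-rs GPU setup helpers and eliding key/ciphertext construction (all buffers use u128 scalars):

use tfhe::core_crypto::gpu::{
    cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_128, CudaStreams,
};
use tfhe::core_crypto::gpu::vec::GpuIndex;

// Assumed already built and uploaded: pksk (CudaLwePackingKeyswitchKey<u128>),
// d_input_lwe_list (CudaLweCiphertextList<u128>),
// d_output_glwe (CudaGlweCiphertextList<u128>).
let streams = CudaStreams::new_single_gpu(GpuIndex::new(0));
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_128(
    &pksk,
    &d_input_lwe_list,
    &mut d_output_glwe,
    &streams,
);
// The safe wrapper synchronizes internally; the _128_async variant requires
// an explicit streams.synchronize() before the output is read.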

View File

@@ -78,7 +78,7 @@ fn main() {
"cuda/include/integer/compression/compression.h",
"cuda/include/integer/integer.h",
"cuda/include/zk/zk.h",
"cuda/include/keyswitch.h",
"cuda/include/keyswitch/keyswitch.h",
"cuda/include/keyswitch/ks_enums.h",
"cuda/include/linear_algebra.h",
"cuda/include/fft/fft128.h",

View File

@@ -31,6 +31,11 @@ void cuda_improve_noise_modulus_switch_64(
void const *lwe_array_in, void const *encrypted_zeros, uint32_t lwe_size,
uint32_t num_lwes, uint32_t num_zeros, double input_variance,
double r_sigma, double bound, uint32_t log_modulus);
void cuda_glwe_sample_extract_128(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *glwe_array_in, uint32_t const *nth_array, uint32_t num_nths,
uint32_t lwe_per_glwe, uint32_t glwe_dimension, uint32_t polynomial_size);
}
#endif

View File

@@ -4,7 +4,7 @@
#include "integer.h"
#include "integer/radix_ciphertext.cuh"
#include "integer/radix_ciphertext.h"
#include "keyswitch.h"
#include "keyswitch/keyswitch.h"
#include "pbs/programmable_bootstrap.cuh"
#include <cmath>
#include <functional>

View File

@@ -31,6 +31,18 @@ void cuda_packing_keyswitch_lwe_list_to_glwe_64(
uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_lwes);
void scratch_packing_keyswitch_lwe_list_to_glwe_128(
void *stream, uint32_t gpu_index, int8_t **fp_ks_buffer,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t num_lwes, bool allocate_gpu_memory);
void cuda_packing_keyswitch_lwe_list_to_glwe_128(
void *stream, uint32_t gpu_index, void *glwe_array_out,
void const *lwe_array_in, void const *fp_ksk_array, int8_t *fp_ks_buffer,
uint32_t input_lwe_dimension, uint32_t output_glwe_dimension,
uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_lwes);
void cleanup_packing_keyswitch_lwe_list_to_glwe(void *stream,
uint32_t gpu_index,
int8_t **fp_ks_buffer,

View File

@@ -96,3 +96,45 @@ void cuda_improve_noise_modulus_switch_64(
static_cast<const uint64_t *>(encrypted_zeros), lwe_size, num_lwes,
num_zeros, input_variance, r_sigma, bound, log_modulus);
}
void cuda_glwe_sample_extract_128(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *glwe_array_in, uint32_t const *nth_array, uint32_t num_nths,
uint32_t lwe_per_glwe, uint32_t glwe_dimension, uint32_t polynomial_size) {
switch (polynomial_size) {
case 256:
host_sample_extract<__uint128_t, AmortizedDegree<256>>(
static_cast<cudaStream_t>(stream), gpu_index,
(__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
(uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
break;
case 512:
host_sample_extract<__uint128_t, AmortizedDegree<512>>(
static_cast<cudaStream_t>(stream), gpu_index,
(__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
(uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
break;
case 1024:
host_sample_extract<__uint128_t, AmortizedDegree<1024>>(
static_cast<cudaStream_t>(stream), gpu_index,
(__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
(uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
break;
case 2048:
host_sample_extract<__uint128_t, AmortizedDegree<2048>>(
static_cast<cudaStream_t>(stream), gpu_index,
(__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
(uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
break;
case 4096:
host_sample_extract<__uint128_t, AmortizedDegree<4096>>(
static_cast<cudaStream_t>(stream), gpu_index,
(__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
(uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
break;
default:
PANIC("Cuda error: unsupported polynomial size. Supported "
"N's are powers of two in the interval [256..4096].")
}
}
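
For orientation: sample extraction turns the nth coefficient of a GLWE into a standalone LWE. The LWE body is the nth coefficient of the GLWE body, and the LWE mask is a negacyclic rearrangement of the GLWE mask polynomials. A minimal scalar model, assuming the standard Z[X]/(X^N + 1) convention (the kernel's actual data layout lives in host_sample_extract):

// Scalar sketch only; sign and layout conventions may differ from the kernel.
fn sample_extract_nth(
    glwe_mask: &[Vec<u128>], // glwe_dimension polynomials of N coefficients each
    glwe_body: &[u128],      // N coefficients
    n: usize,
) -> (Vec<u128>, u128) {
    let big_n = glwe_body.len();
    let mut lwe_mask = Vec::with_capacity(glwe_mask.len() * big_n);
    for a in glwe_mask {
        // Coefficient n of a(X) * s(X): indices wrap with a sign flip
        // because X^N = -1 in the negacyclic ring.
        for i in 0..big_n {
            lwe_mask.push(if i <= n {
                a[n - i]
            } else {
                a[big_n + n - i].wrapping_neg()
            });
        }
    }
    (lwe_mask, glwe_body[n])
}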

View File

@@ -1,8 +1,6 @@
#include "fast_packing_keyswitch.cuh"
#include "keyswitch.cuh"
#include "keyswitch.h"
#include <cstdint>
#include <stdio.h>
#include "keyswitch/keyswitch.h"
#include "packing_keyswitch.cuh"
/* Perform keyswitch on a batch of 32-bit input LWE ciphertexts.
 * Head out to the equivalent 64-bit operation for more details.
@@ -73,7 +71,7 @@ void cuda_packing_keyswitch_lwe_list_to_glwe_64(
uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_lwes) {
host_fast_packing_keyswitch_lwe_list_to_glwe<uint64_t, ulonglong4>(
host_packing_keyswitch_lwe_list_to_glwe<uint64_t>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint64_t *>(glwe_array_out),
static_cast<const uint64_t *>(lwe_array_in),
@@ -90,3 +88,31 @@ void cleanup_packing_keyswitch_lwe_list_to_glwe(void *stream,
static_cast<cudaStream_t>(stream),
gpu_index, gpu_memory_allocated);
}
void scratch_packing_keyswitch_lwe_list_to_glwe_128(
void *stream, uint32_t gpu_index, int8_t **fp_ks_buffer,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t num_lwes, bool allocate_gpu_memory) {
scratch_packing_keyswitch_lwe_list_to_glwe<__uint128_t>(
static_cast<cudaStream_t>(stream), gpu_index, fp_ks_buffer, lwe_dimension,
glwe_dimension, polynomial_size, num_lwes, allocate_gpu_memory);
}
/* Perform functional packing keyswitch on a batch of 128-bit input LWE
* ciphertexts.
*/
void cuda_packing_keyswitch_lwe_list_to_glwe_128(
void *stream, uint32_t gpu_index, void *glwe_array_out,
void const *lwe_array_in, void const *fp_ksk_array, int8_t *fp_ks_buffer,
uint32_t input_lwe_dimension, uint32_t output_glwe_dimension,
uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_lwes) {
host_packing_keyswitch_lwe_list_to_glwe<__uint128_t>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<__uint128_t *>(glwe_array_out),
static_cast<const __uint128_t *>(lwe_array_in),
static_cast<const __uint128_t *>(fp_ksk_array), fp_ks_buffer,
input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
base_log, level_count, num_lwes);
}

View File

@@ -28,7 +28,7 @@ template <typename Torus> uint64_t get_shared_mem_size_tgemm() {
// Initialize decomposition by performing rounding
// and decomposing one level of an array of Torus LWEs. Only
// decomposes the mask elements of the incoming LWEs.
template <typename Torus, typename TorusVec>
template <typename Torus>
__global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
uint32_t lwe_dimension,
uint32_t num_lwe, uint32_t base_log,
@@ -63,7 +63,7 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
// Continue decomposition of an array of Torus elements in place. Assumes
// that the array already contains decomposed elements and
// computes the next decomposition level in place.
template <typename Torus, typename TorusVec>
template <typename Torus>
__global__ void
decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
uint32_t num_lwe, uint32_t base_log,
@@ -101,7 +101,7 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
// This kernel adapts the 1d block-tiling GEMM
// from https://github.com/siboehm/SGEMM_CUDA,
// generalizing it to arbitrary matrix dimensions
template <typename Torus, typename TorusVec>
template <typename Torus>
__global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
int stride_B, Torus *C) {
@@ -251,8 +251,8 @@ __global__ void polynomial_accumulate_monic_monomial_mul_many_neg_and_add_C(
degree, coeffIdx, polynomial_size, 1, true);
}
template <typename Torus, typename TorusVec>
__host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
template <typename Torus>
__host__ void host_packing_keyswitch_lwe_list_to_glwe(
cudaStream_t stream, uint32_t gpu_index, Torus *glwe_out,
Torus const *lwe_array_in, Torus const *fp_ksk_array, int8_t *fp_ks_buffer,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
@@ -296,10 +296,8 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
dim3 threads_decomp(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
// decompose first level
decompose_vectorize_init<Torus, TorusVec>
<<<grid_decomp, threads_decomp, 0, stream>>>(lwe_array_in, d_mem_0,
lwe_dimension, num_lwes,
base_log, level_count);
decompose_vectorize_init<Torus><<<grid_decomp, threads_decomp, 0, stream>>>(
lwe_array_in, d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
check_cuda_error(cudaGetLastError());
// gemm to keyswitch the individual LWEs into GLWEs
@@ -310,7 +308,7 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
auto stride_KSK_buffer = glwe_accumulator_size * level_count;
uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());
@@ -318,15 +316,14 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
auto ksk_block_size = glwe_accumulator_size;
for (int li = 1; li < level_count; ++li) {
decompose_vectorize_step_inplace<Torus, TorusVec>
decompose_vectorize_step_inplace<Torus>
<<<grid_decomp, threads_decomp, 0, stream>>>(
d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
check_cuda_error(cudaGetLastError());
tgemm<Torus, TorusVec>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());
}
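
Taken together, the kernels above implement the packing keyswitch as a pipeline: decompose one level of every LWE mask, GEMM the resulting digits against the matching level block of the packing keyswitch key, accumulate, then repeat for the next level. A scalar model of the two stages, assuming one common rounding/digit convention that may differ from the kernels in detail:

/// Balanced base-2^base_log decomposition of a 128-bit torus element after
/// rounding to the closest representable value; digits[0] is the most
/// significant level, each digit lies in [-B/2, B/2).
fn decompose(t: u128, base_log: u32, level_count: u32) -> Vec<i128> {
    let rep_bits = base_log * level_count;
    assert!(rep_bits < 128);
    // Round to the nearest multiple of 2^(128 - rep_bits), keep the top bits.
    let mut state = t.wrapping_add(1u128 << (128 - rep_bits - 1)) >> (128 - rep_bits);
    let base = 1u128 << base_log;
    let half = base >> 1;
    let mut digits = vec![0i128; level_count as usize];
    for li in (0..level_count as usize).rev() {
        let rem = state & (base - 1);
        state >>= base_log;
        digits[li] = if rem >= half {
            state = state.wrapping_add(1); // carry so the digit stays balanced
            rem as i128 - base as i128
        } else {
            rem as i128
        };
    }
    digits
}

/// What each tgemm launch accumulates: C[m x n] += A[m x k] * B[k x n], where
/// A holds one decomposition level of the LWE masks and B the matching KSK
/// level block (rows strided by stride_b). Wrapping ops model torus arithmetic.
fn gemm_acc(m: usize, n: usize, k: usize, a: &[u128], b: &[u128], stride_b: usize, c: &mut [u128]) {
    for i in 0..m {
        for j in 0..n {
            let mut acc: u128 = 0;
            for p in 0..k {
                acc = acc.wrapping_add(a[i * k + p].wrapping_mul(b[p * stride_b + j]));
            }
            c[i * n + j] = c[i * n + j].wrapping_add(acc);
        }
    }
}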

View File

@@ -2,8 +2,8 @@
#define CUDA_INTEGER_COMPRESSION_CUH
#include "ciphertext.h"
#include "crypto/fast_packing_keyswitch.cuh"
#include "crypto/keyswitch.cuh"
#include "crypto/packing_keyswitch.cuh"
#include "device.h"
#include "integer/compression/compression.h"
#include "integer/compression/compression_utilities.h"
@@ -116,7 +116,7 @@ host_integer_compress(cudaStream_t const *streams, uint32_t const *gpu_indexes,
while (rem_lwes > 0) {
auto chunk_size = min(rem_lwes, mem_ptr->lwe_per_glwe);
host_fast_packing_keyswitch_lwe_list_to_glwe<Torus, ulonglong4>(
host_packing_keyswitch_lwe_list_to_glwe<Torus>(
streams[0], gpu_indexes[0], glwe_out, lwe_subset, fp_ksk[0],
fp_ks_buffer, input_lwe_dimension, compression_params.glwe_dimension,
compression_params.polynomial_size, compression_params.ks_base_log,

View File

@@ -1,10 +1,10 @@
#ifndef SETUP_AND_TEARDOWN_H
#define SETUP_AND_TEARDOWN_H
#include "keyswitch/keyswitch.h"
#include "pbs/programmable_bootstrap.h"
#include "pbs/programmable_bootstrap_multibit.h"
#include <device.h>
#include <keyswitch.h>
#include <utils.h>
void programmable_bootstrap_classical_setup(

View File

@@ -60,6 +60,19 @@ unsafe extern "C" {
log_modulus: u32,
);
}
unsafe extern "C" {
pub fn cuda_glwe_sample_extract_128(
stream: *mut ffi::c_void,
gpu_index: u32,
lwe_array_out: *mut ffi::c_void,
glwe_array_in: *const ffi::c_void,
nth_array: *const u32,
num_nths: u32,
lwe_per_glwe: u32,
glwe_dimension: u32,
polynomial_size: u32,
);
}
pub const PBS_TYPE_MULTI_BIT: PBS_TYPE = 0;
pub const PBS_TYPE_CLASSICAL: PBS_TYPE = 1;
pub type PBS_TYPE = ffi::c_uint;
@@ -1429,6 +1442,34 @@ unsafe extern "C" {
num_lwes: u32,
);
}
unsafe extern "C" {
pub fn scratch_packing_keyswitch_lwe_list_to_glwe_128(
stream: *mut ffi::c_void,
gpu_index: u32,
fp_ks_buffer: *mut *mut i8,
lwe_dimension: u32,
glwe_dimension: u32,
polynomial_size: u32,
num_lwes: u32,
allocate_gpu_memory: bool,
);
}
unsafe extern "C" {
pub fn cuda_packing_keyswitch_lwe_list_to_glwe_128(
stream: *mut ffi::c_void,
gpu_index: u32,
glwe_array_out: *mut ffi::c_void,
lwe_array_in: *const ffi::c_void,
fp_ksk_array: *const ffi::c_void,
fp_ks_buffer: *mut i8,
input_lwe_dimension: u32,
output_glwe_dimension: u32,
output_polynomial_size: u32,
base_log: u32,
level_count: u32,
num_lwes: u32,
);
}
unsafe extern "C" {
pub fn cleanup_packing_keyswitch_lwe_list_to_glwe(
stream: *mut ffi::c_void,

View File

@@ -3,7 +3,7 @@
#include "cuda/include/integer/compression/compression.h"
#include "cuda/include/integer/integer.h"
#include "cuda/include/zk/zk.h"
#include "cuda/include/keyswitch.h"
#include "cuda/include/keyswitch/keyswitch.h"
#include "cuda/include/keyswitch/ks_enums.h"
#include "cuda/include/linear_algebra.h"
#include "cuda/include/fft/fft128.h"

View File

@@ -497,7 +497,7 @@ mod cuda {
use tfhe::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
use tfhe::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
use tfhe::core_crypto::gpu::{
cuda_keyswitch_lwe_ciphertext, cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext,
cuda_keyswitch_lwe_ciphertext, cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64,
get_number_of_gpus, CudaStreams,
};
use tfhe::core_crypto::prelude::*;
@@ -796,7 +796,7 @@ mod cuda {
{
bench_group.bench_function(&bench_id, |b| {
b.iter(|| {
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext(
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64(
gpu_keys.pksk.as_ref().unwrap(),
&d_input_lwe_list,
&mut d_output_glwe,
@@ -879,7 +879,7 @@ mod cuda {
((i, input_lwe_list), output_glwe_list),
local_stream,
)| {
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext(
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64(
gpu_keys_vec[i].pksk.as_ref().unwrap(),
input_lwe_list,
output_glwe_list,

View File

@@ -1,14 +1,16 @@
use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
use crate::core_crypto::gpu::lwe_packing_keyswitch_key::CudaLwePackingKeyswitchKey;
use crate::core_crypto::gpu::{packing_keyswitch_list_async, CudaStreams};
use crate::core_crypto::gpu::{
packing_keyswitch_list_128_async, packing_keyswitch_list_64_async, CudaStreams,
};
use crate::core_crypto::prelude::{CastInto, UnsignedTorus};
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
/// be dropped until stream is synchronized
pub unsafe fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async<Scalar>(
pub unsafe fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64_async<Scalar>(
lwe_pksk: &CudaLwePackingKeyswitchKey<Scalar>,
input_lwe_ciphertext_list: &CudaLweCiphertextList<Scalar>,
output_glwe_ciphertext: &mut CudaGlweCiphertextList<Scalar>,
@@ -42,7 +44,7 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async<Scal
lwe_pksk.d_vec.gpu_index(0).get(),
);
packing_keyswitch_list_async(
packing_keyswitch_list_64_async(
streams,
&mut output_glwe_ciphertext.0.d_vec,
&input_lwe_ciphertext_list.0.d_vec,
@@ -56,7 +58,7 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async<Scal
);
}
pub fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext<Scalar>(
pub fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64<Scalar>(
lwe_pksk: &CudaLwePackingKeyswitchKey<Scalar>,
input_lwe_ciphertext_list: &CudaLweCiphertextList<Scalar>,
output_glwe_ciphertext: &mut CudaGlweCiphertextList<Scalar>,
@@ -66,7 +68,78 @@ pub fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext<Scalar>(
Scalar: UnsignedTorus + CastInto<usize>,
{
unsafe {
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async(
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64_async(
lwe_pksk,
input_lwe_ciphertext_list,
output_glwe_ciphertext,
streams,
);
}
streams.synchronize();
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
/// be dropped until stream is synchronized
pub unsafe fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_128_async<Scalar>(
lwe_pksk: &CudaLwePackingKeyswitchKey<Scalar>,
input_lwe_ciphertext_list: &CudaLweCiphertextList<Scalar>,
output_glwe_ciphertext: &mut CudaGlweCiphertextList<Scalar>,
streams: &CudaStreams,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize>,
{
let input_lwe_dimension = input_lwe_ciphertext_list.lwe_dimension();
let output_glwe_dimension = output_glwe_ciphertext.glwe_dimension();
let output_polynomial_size = output_glwe_ciphertext.polynomial_size();
assert_eq!(
streams.gpu_indexes[0],
input_lwe_ciphertext_list.0.d_vec.gpu_index(0),
"GPU error: first stream is on GPU {}, first input pointer is on GPU {}",
streams.gpu_indexes[0].get(),
input_lwe_ciphertext_list.0.d_vec.gpu_index(0).get(),
);
assert_eq!(
streams.gpu_indexes[0],
output_glwe_ciphertext.0.d_vec.gpu_index(0),
"GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
streams.gpu_indexes[0].get(),
output_glwe_ciphertext.0.d_vec.gpu_index(0).get(),
);
assert_eq!(
streams.gpu_indexes[0],
lwe_pksk.d_vec.gpu_index(0),
"GPU error: first stream is on GPU {}, first pksk pointer is on GPU {}",
streams.gpu_indexes[0].get(),
lwe_pksk.d_vec.gpu_index(0).get(),
);
packing_keyswitch_list_128_async(
streams,
&mut output_glwe_ciphertext.0.d_vec,
&input_lwe_ciphertext_list.0.d_vec,
input_lwe_dimension,
output_glwe_dimension,
output_polynomial_size,
&lwe_pksk.d_vec,
lwe_pksk.decomposition_base_log(),
lwe_pksk.decomposition_level_count(),
input_lwe_ciphertext_list.lwe_ciphertext_count(),
);
}
pub fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_128<Scalar>(
lwe_pksk: &CudaLwePackingKeyswitchKey<Scalar>,
input_lwe_ciphertext_list: &CudaLweCiphertextList<Scalar>,
output_glwe_ciphertext: &mut CudaGlweCiphertextList<Scalar>,
streams: &CudaStreams,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize>,
{
unsafe {
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_128_async(
lwe_pksk,
input_lwe_ciphertext_list,
output_glwe_ciphertext,

View File

@@ -1,5 +1,5 @@
use super::*;
use crate::core_crypto::gpu::algorithms::lwe_packing_keyswitch::cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async;
use crate::core_crypto::gpu::algorithms::lwe_packing_keyswitch::cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64_async;
use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
use crate::core_crypto::gpu::vec::GpuIndex;
@@ -104,7 +104,7 @@ where
);
unsafe {
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async(
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64_async(
&pksk,
&d_input_lwe,
&mut d_output_glwe,
@@ -197,7 +197,7 @@ where
);
unsafe {
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async(
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64_async(
&pksk,
&d_input_lwe_list,
&mut d_output_glwe,

View File

@@ -367,7 +367,7 @@ pub unsafe fn convert_lwe_keyswitch_key_async<T: UnsignedInteger>(
/// [CudaStreams::synchronize] __must__ be called as soon as synchronization is
/// required
#[allow(clippy::too_many_arguments)]
pub unsafe fn packing_keyswitch_list_async<T: UnsignedInteger>(
pub unsafe fn packing_keyswitch_list_64_async<T: UnsignedInteger>(
streams: &CudaStreams,
glwe_array_out: &mut CudaVec<T>,
lwe_array_in: &CudaVec<T>,
@@ -412,6 +412,58 @@ pub unsafe fn packing_keyswitch_list_async<T: UnsignedInteger>(
);
}
/// Applies packing keyswitch on a vector of 128-bit LWE ciphertexts
///
/// # Safety
///
/// [CudaStreams::synchronize] __must__ be called as soon as synchronization is
/// required
#[allow(clippy::too_many_arguments)]
pub unsafe fn packing_keyswitch_list_128_async<T: UnsignedInteger>(
streams: &CudaStreams,
glwe_array_out: &mut CudaVec<T>,
lwe_array_in: &CudaVec<T>,
input_lwe_dimension: LweDimension,
output_glwe_dimension: GlweDimension,
output_polynomial_size: PolynomialSize,
packing_keyswitch_key: &CudaVec<T>,
base_log: DecompositionBaseLog,
l_gadget: DecompositionLevelCount,
num_lwes: LweCiphertextCount,
) {
let mut fp_ks_buffer: *mut i8 = std::ptr::null_mut();
scratch_packing_keyswitch_lwe_list_to_glwe_128(
streams.ptr[0],
streams.gpu_indexes[0].get(),
std::ptr::addr_of_mut!(fp_ks_buffer),
input_lwe_dimension.0 as u32,
output_glwe_dimension.0 as u32,
output_polynomial_size.0 as u32,
num_lwes.0 as u32,
true,
);
cuda_packing_keyswitch_lwe_list_to_glwe_128(
streams.ptr[0],
streams.gpu_indexes[0].get(),
glwe_array_out.as_mut_c_ptr(0),
lwe_array_in.as_c_ptr(0),
packing_keyswitch_key.as_c_ptr(0),
fp_ks_buffer,
input_lwe_dimension.0 as u32,
output_glwe_dimension.0 as u32,
output_polynomial_size.0 as u32,
base_log.0 as u32,
l_gadget.0 as u32,
num_lwes.0 as u32,
);
cleanup_packing_keyswitch_lwe_list_to_glwe(
streams.ptr[0],
streams.gpu_indexes[0].get(),
std::ptr::addr_of_mut!(fp_ks_buffer),
true,
);
}
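
Note the design choice above: the 128-bit list keyswitch allocates its scratch buffer, runs the keyswitch, and frees the buffer on every call, rather than exposing scratch management to the caller. As with the other _async entry points, results are only valid after synchronization. A minimal caller sketch, with device vectors and parameters assumed already set up:

// d_glwe_out, d_lwe_in and d_pksk are CudaVec<u128>, assumed allocated
// and populated on the GPU ahead of time.
unsafe {
    packing_keyswitch_list_128_async(
        &streams,
        &mut d_glwe_out,
        &d_lwe_in,
        input_lwe_dimension,
        output_glwe_dimension,
        output_polynomial_size,
        &d_pksk,
        base_log,
        level_count,
        num_lwes,
    );
}
// Required before d_glwe_out is read or d_lwe_in is dropped.
streams.synchronize();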
/// Convert programmable bootstrap key
///
/// # Safety
@@ -440,7 +492,7 @@ pub unsafe fn convert_lwe_programmable_bootstrap_key_async<T: UnsignedInteger>(
l_gadget.0 as u32,
polynomial_size.0 as u32,
);
} else {
} else if size_of::<T>() == 8 {
cuda_convert_lwe_programmable_bootstrap_key_64(
stream_ptr,
streams.gpu_indexes[i].get(),
@@ -451,6 +503,8 @@ pub unsafe fn convert_lwe_programmable_bootstrap_key_async<T: UnsignedInteger>(
l_gadget.0 as u32,
polynomial_size.0 as u32,
);
} else {
panic!("Unsupported torus size for bsk conversion")
}
}
}
@@ -504,17 +558,33 @@ pub unsafe fn extract_lwe_samples_from_glwe_ciphertext_list_async<T: UnsignedInt
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
) {
cuda_glwe_sample_extract_64(
streams.ptr[0],
streams.gpu_indexes[0].get(),
lwe_array_out.as_mut_c_ptr(0),
glwe_array_in.as_c_ptr(0),
nth_array.as_c_ptr(0).cast::<u32>(),
num_nths,
lwe_per_glwe,
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
);
if size_of::<T>() == 16 {
cuda_glwe_sample_extract_128(
streams.ptr[0],
streams.gpu_indexes[0].get(),
lwe_array_out.as_mut_c_ptr(0),
glwe_array_in.as_c_ptr(0),
nth_array.as_c_ptr(0).cast::<u32>(),
num_nths,
lwe_per_glwe,
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
);
} else if size_of::<T>() == 8 {
cuda_glwe_sample_extract_64(
streams.ptr[0],
streams.gpu_indexes[0].get(),
lwe_array_out.as_mut_c_ptr(0),
glwe_array_in.as_c_ptr(0),
nth_array.as_c_ptr(0).cast::<u32>(),
num_nths,
lwe_per_glwe,
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
);
} else {
panic!("Unsupported torus size for glwe sample extraction")
}
}
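
With this dispatch, the same generic entry point reaches the 128-bit kernel whenever the scalar type is 16 bytes wide. A hedged caller sketch with T = u128, parameter order as suggested by the body above and all device buffers assumed already populated:

unsafe {
    extract_lwe_samples_from_glwe_ciphertext_list_async::<u128>(
        &streams,
        &mut d_lwe_array_out,
        &d_glwe_array_in,
        &d_nth_array, // CudaVec of u32 coefficient indices to extract
        num_nths,
        lwe_per_glwe,
        glwe_dimension,
        polynomial_size,
    );
}
streams.synchronize(); // outputs are valid only after synchronization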
/// # Safety