Mirror of https://github.com/zama-ai/tfhe-rs.git, synced 2026-01-09 14:47:56 -05:00
feat(gpu): add necessary entry points for 128 bit compression
@@ -78,7 +78,7 @@ fn main() {
         "cuda/include/integer/compression/compression.h",
         "cuda/include/integer/integer.h",
         "cuda/include/zk/zk.h",
-        "cuda/include/keyswitch.h",
+        "cuda/include/keyswitch/keyswitch.h",
         "cuda/include/keyswitch/ks_enums.h",
         "cuda/include/linear_algebra.h",
         "cuda/include/fft/fft128.h",

@@ -31,6 +31,11 @@ void cuda_improve_noise_modulus_switch_64(
     void const *lwe_array_in, void const *encrypted_zeros, uint32_t lwe_size,
     uint32_t num_lwes, uint32_t num_zeros, double input_variance,
     double r_sigma, double bound, uint32_t log_modulus);
+
+void cuda_glwe_sample_extract_128(
+    void *stream, uint32_t gpu_index, void *lwe_array_out,
+    void const *glwe_array_in, uint32_t const *nth_array, uint32_t num_nths,
+    uint32_t lwe_per_glwe, uint32_t glwe_dimension, uint32_t polynomial_size);
 }

 #endif

@@ -4,7 +4,7 @@
 #include "integer.h"
 #include "integer/radix_ciphertext.cuh"
 #include "integer/radix_ciphertext.h"
-#include "keyswitch.h"
+#include "keyswitch/keyswitch.h"
 #include "pbs/programmable_bootstrap.cuh"
 #include <cmath>
 #include <functional>

@@ -31,6 +31,18 @@ void cuda_packing_keyswitch_lwe_list_to_glwe_64(
     uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
     uint32_t num_lwes);

+void scratch_packing_keyswitch_lwe_list_to_glwe_128(
+    void *stream, uint32_t gpu_index, int8_t **fp_ks_buffer,
+    uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
+    uint32_t num_lwes, bool allocate_gpu_memory);
+
+void cuda_packing_keyswitch_lwe_list_to_glwe_128(
+    void *stream, uint32_t gpu_index, void *glwe_array_out,
+    void const *lwe_array_in, void const *fp_ksk_array, int8_t *fp_ks_buffer,
+    uint32_t input_lwe_dimension, uint32_t output_glwe_dimension,
+    uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
+    uint32_t num_lwes);
+
 void cleanup_packing_keyswitch_lwe_list_to_glwe(void *stream,
                                                 uint32_t gpu_index,
                                                 int8_t **fp_ks_buffer,

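Note: the three declarations above follow the backend's usual scratch/run/cleanup pattern. Below is a minimal sketch of the intended call sequence; it is not taken from the commit, the CUDA stream and all device buffers are assumed to be created elsewhere, and every name is a placeholder.

#include "keyswitch/keyswitch.h"

// Hedged sketch of the scratch/run/cleanup lifecycle; error handling omitted.
void pack_128bit_lwes_into_glwe(void *stream, uint32_t gpu_index,
                                void *d_glwe_out, void const *d_lwe_in,
                                void const *d_fp_ksk, uint32_t lwe_dim,
                                uint32_t glwe_dim, uint32_t poly_size,
                                uint32_t base_log, uint32_t level_count,
                                uint32_t num_lwes) {
  int8_t *fp_ks_buffer = nullptr;
  // 1. Allocate scratch space for the functional packing keyswitch.
  scratch_packing_keyswitch_lwe_list_to_glwe_128(
      stream, gpu_index, &fp_ks_buffer, lwe_dim, glwe_dim, poly_size, num_lwes,
      /*allocate_gpu_memory=*/true);
  // 2. Keyswitch the batch of 128-bit LWEs into a GLWE output.
  cuda_packing_keyswitch_lwe_list_to_glwe_128(
      stream, gpu_index, d_glwe_out, d_lwe_in, d_fp_ksk, fp_ks_buffer, lwe_dim,
      glwe_dim, poly_size, base_log, level_count, num_lwes);
  // 3. Release the scratch buffer.
  cleanup_packing_keyswitch_lwe_list_to_glwe(stream, gpu_index, &fp_ks_buffer,
                                             /*gpu_memory_allocated=*/true);
}

This mirrors what the new Rust wrapper packing_keyswitch_list_128_async, added further down in this commit, performs through the generated bindings.
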
@@ -96,3 +96,45 @@ void cuda_improve_noise_modulus_switch_64(
       static_cast<const uint64_t *>(encrypted_zeros), lwe_size, num_lwes,
       num_zeros, input_variance, r_sigma, bound, log_modulus);
 }
+
+void cuda_glwe_sample_extract_128(
+    void *stream, uint32_t gpu_index, void *lwe_array_out,
+    void const *glwe_array_in, uint32_t const *nth_array, uint32_t num_nths,
+    uint32_t lwe_per_glwe, uint32_t glwe_dimension, uint32_t polynomial_size) {
+
+  switch (polynomial_size) {
+  case 256:
+    host_sample_extract<__uint128_t, AmortizedDegree<256>>(
+        static_cast<cudaStream_t>(stream), gpu_index,
+        (__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
+        (uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
+    break;
+  case 512:
+    host_sample_extract<__uint128_t, AmortizedDegree<512>>(
+        static_cast<cudaStream_t>(stream), gpu_index,
+        (__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
+        (uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
+    break;
+  case 1024:
+    host_sample_extract<__uint128_t, AmortizedDegree<1024>>(
+        static_cast<cudaStream_t>(stream), gpu_index,
+        (__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
+        (uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
+    break;
+  case 2048:
+    host_sample_extract<__uint128_t, AmortizedDegree<2048>>(
+        static_cast<cudaStream_t>(stream), gpu_index,
+        (__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
+        (uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
+    break;
+  case 4096:
+    host_sample_extract<__uint128_t, AmortizedDegree<4096>>(
+        static_cast<cudaStream_t>(stream), gpu_index,
+        (__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
+        (uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
+    break;
+  default:
+    PANIC("Cuda error: unsupported polynomial size. Supported "
+          "N's are powers of two in the interval [256..4096].")
+  }
+}

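Each case above binds the polynomial size as a compile-time template parameter, which is why only the listed power-of-two sizes are instantiated. A minimal sketch of a call, not from the commit; it assumes nth_array holds one coefficient index per extracted LWE and that lwe_per_glwe gives how many samples are taken from each input GLWE (all names are placeholders):

// Hedged sketch: extract the constant coefficient of each of `num_glwes`
// 128-bit GLWEs, assuming d_nth_array was zero-filled on the device and all
// buffers were allocated by the caller.
void extract_constant_terms_128(void *stream, uint32_t gpu_index,
                                void *d_lwe_out, void const *d_glwe_in,
                                uint32_t const *d_nth_array,
                                uint32_t num_glwes, uint32_t glwe_dimension,
                                uint32_t polynomial_size) {
  cuda_glwe_sample_extract_128(stream, gpu_index, d_lwe_out, d_glwe_in,
                               d_nth_array, /*num_nths=*/num_glwes,
                               /*lwe_per_glwe=*/1, glwe_dimension,
                               polynomial_size);
}

The Rust-side extract_lwe_samples_from_glwe_ciphertext_list_async, changed later in this commit, routes to this entry point whenever the scalar type is 16 bytes wide.
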
@@ -1,8 +1,6 @@
-#include "fast_packing_keyswitch.cuh"
 #include "keyswitch.cuh"
-#include "keyswitch.h"
-#include <cstdint>
-#include <stdio.h>
+#include "keyswitch/keyswitch.h"
+#include "packing_keyswitch.cuh"

 /* Perform keyswitch on a batch of 32 bits input LWE ciphertexts.
  * See the equivalent operation on 64 bits for more details.

@@ -73,7 +71,7 @@ void cuda_packing_keyswitch_lwe_list_to_glwe_64(
     uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
     uint32_t num_lwes) {

-  host_fast_packing_keyswitch_lwe_list_to_glwe<uint64_t, ulonglong4>(
+  host_packing_keyswitch_lwe_list_to_glwe<uint64_t>(
       static_cast<cudaStream_t>(stream), gpu_index,
       static_cast<uint64_t *>(glwe_array_out),
       static_cast<const uint64_t *>(lwe_array_in),

@@ -90,3 +88,31 @@ void cleanup_packing_keyswitch_lwe_list_to_glwe(void *stream,
                                     static_cast<cudaStream_t>(stream),
                                     gpu_index, gpu_memory_allocated);
 }
+
+void scratch_packing_keyswitch_lwe_list_to_glwe_128(
+    void *stream, uint32_t gpu_index, int8_t **fp_ks_buffer,
+    uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
+    uint32_t num_lwes, bool allocate_gpu_memory) {
+  scratch_packing_keyswitch_lwe_list_to_glwe<__uint128_t>(
+      static_cast<cudaStream_t>(stream), gpu_index, fp_ks_buffer,
+      lwe_dimension, glwe_dimension, polynomial_size, num_lwes,
+      allocate_gpu_memory);
+}
+
+/* Perform functional packing keyswitch on a batch of 128 bits input LWE
+ * ciphertexts.
+ */
+void cuda_packing_keyswitch_lwe_list_to_glwe_128(
+    void *stream, uint32_t gpu_index, void *glwe_array_out,
+    void const *lwe_array_in, void const *fp_ksk_array, int8_t *fp_ks_buffer,
+    uint32_t input_lwe_dimension, uint32_t output_glwe_dimension,
+    uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
+    uint32_t num_lwes) {
+  host_packing_keyswitch_lwe_list_to_glwe<__uint128_t>(
+      static_cast<cudaStream_t>(stream), gpu_index,
+      static_cast<__uint128_t *>(glwe_array_out),
+      static_cast<const __uint128_t *>(lwe_array_in),
+      static_cast<const __uint128_t *>(fp_ksk_array), fp_ks_buffer,
+      input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
+      base_log, level_count, num_lwes);
+}

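A side note (editor's aside, not from the commit): __uint128_t is a GCC/Clang/NVCC built-in rather than a standard type, and torus arithmetic on it is ordinary wrapping arithmetic modulo 2^128. Its 16-byte size is also what the Rust-side size_of::<T>() == 16 dispatch later in this commit relies on. A tiny self-contained illustration:

#include <cstdint>

// The 128-bit torus is represented by a 16-byte integer type.
static_assert(sizeof(__uint128_t) == 16, "128-bit torus needs a 16-byte type");

// Torus elements live in Z/2^128 Z: unsigned addition simply wraps.
__uint128_t torus_add(__uint128_t a, __uint128_t b) { return a + b; }
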
@@ -28,7 +28,7 @@ template <typename Torus> uint64_t get_shared_mem_size_tgemm() {
 // Initialize decomposition by performing rounding
 // and decomposing one level of an array of Torus LWEs. Only
 // decomposes the mask elements of the incoming LWEs.
-template <typename Torus, typename TorusVec>
+template <typename Torus>
 __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
                                          uint32_t lwe_dimension,
                                          uint32_t num_lwe, uint32_t base_log,

@@ -63,7 +63,7 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
 // Continue decomposition of an array of Torus elements in place. Supposes
 // that the array contains already decomposed elements and
 // computes the new decomposed level in place.
-template <typename Torus, typename TorusVec>
+template <typename Torus>
 __global__ void
 decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
                                  uint32_t num_lwe, uint32_t base_log,

@@ -101,7 +101,7 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
 // This code is adapted by generalizing the 1d block-tiling
 // kernel from https://github.com/siboehm/SGEMM_CUDA
 // to any matrix dimension
-template <typename Torus, typename TorusVec>
+template <typename Torus>
 __global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
                       int stride_B, Torus *C) {

@@ -251,8 +251,8 @@ __global__ void polynomial_accumulate_monic_monomial_mul_many_neg_and_add_C(
       degree, coeffIdx, polynomial_size, 1, true);
 }

-template <typename Torus, typename TorusVec>
-__host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
+template <typename Torus>
+__host__ void host_packing_keyswitch_lwe_list_to_glwe(
     cudaStream_t stream, uint32_t gpu_index, Torus *glwe_out,
     Torus const *lwe_array_in, Torus const *fp_ksk_array, int8_t *fp_ks_buffer,
     uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,

@@ -296,10 +296,8 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
   dim3 threads_decomp(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);

   // decompose first level
-  decompose_vectorize_init<Torus, TorusVec>
-      <<<grid_decomp, threads_decomp, 0, stream>>>(lwe_array_in, d_mem_0,
-                                                   lwe_dimension, num_lwes,
-                                                   base_log, level_count);
+  decompose_vectorize_init<Torus><<<grid_decomp, threads_decomp, 0, stream>>>(
+      lwe_array_in, d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
   check_cuda_error(cudaGetLastError());

   // gemm to ks the individual LWEs to GLWEs

@@ -310,7 +308,7 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
   auto stride_KSK_buffer = glwe_accumulator_size * level_count;

   uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
-  tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
+  tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
       num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
       stride_KSK_buffer, d_mem_1);
   check_cuda_error(cudaGetLastError());

@@ -318,15 +316,14 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
   auto ksk_block_size = glwe_accumulator_size;

   for (int li = 1; li < level_count; ++li) {
-    decompose_vectorize_step_inplace<Torus, TorusVec>
+    decompose_vectorize_step_inplace<Torus>
         <<<grid_decomp, threads_decomp, 0, stream>>>(
             d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
     check_cuda_error(cudaGetLastError());

-    tgemm<Torus, TorusVec>
-        <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
-            num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
-            fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
+    tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
+        num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
+        fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
     check_cuda_error(cudaGetLastError());
   }

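Reading the loop above as linear algebra (an editor's gloss, under the assumption that tgemm accumulates into its output matrix, which the reuse of d_mem_1 across iterations suggests): with m = num_lwes, n = lwe_dimension and s = glwe_accumulator_size, each gadget-decomposition level \ell contributes one matrix product,

    C \mathrel{+}= D^{(\ell)} K^{(\ell)}, \qquad
    D^{(\ell)} \in \mathbb{Z}_q^{\,m \times n}, \quad
    K^{(\ell)} \in \mathbb{Z}_q^{\,n \times s}, \quad
    \ell = 0, \ldots, \texttt{level\_count} - 1,

which is exactly the tgemm call with M = num_lwes, N = glwe_accumulator_size and K = lwe_dimension, where fp_ksk_array + li * ksk_block_size selects the level block K^{(\ell)}.
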
@@ -2,8 +2,8 @@
 #define CUDA_INTEGER_COMPRESSION_CUH

 #include "ciphertext.h"
-#include "crypto/fast_packing_keyswitch.cuh"
 #include "crypto/keyswitch.cuh"
+#include "crypto/packing_keyswitch.cuh"
 #include "device.h"
 #include "integer/compression/compression.h"
 #include "integer/compression/compression_utilities.h"

|
||||
while (rem_lwes > 0) {
|
||||
auto chunk_size = min(rem_lwes, mem_ptr->lwe_per_glwe);
|
||||
|
||||
host_fast_packing_keyswitch_lwe_list_to_glwe<Torus, ulonglong4>(
|
||||
host_packing_keyswitch_lwe_list_to_glwe<Torus>(
|
||||
streams[0], gpu_indexes[0], glwe_out, lwe_subset, fp_ksk[0],
|
||||
fp_ks_buffer, input_lwe_dimension, compression_params.glwe_dimension,
|
||||
compression_params.polynomial_size, compression_params.ks_base_log,
|
||||
|
||||
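The loop above consumes the batch lwe_per_glwe ciphertexts at a time, so a batch of num_lwes LWEs compresses into ceil(num_lwes / lwe_per_glwe) GLWEs. A trivial sketch of that arithmetic (editor's addition, not code from the commit):

#include <cstdint>

// Ceiling division: number of GLWEs produced when packing num_lwes LWEs,
// lwe_per_glwe at a time.
uint32_t num_output_glwes(uint32_t num_lwes, uint32_t lwe_per_glwe) {
  return (num_lwes + lwe_per_glwe - 1) / lwe_per_glwe;
}
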
@@ -1,10 +1,10 @@
 #ifndef SETUP_AND_TEARDOWN_H
 #define SETUP_AND_TEARDOWN_H

+#include "keyswitch/keyswitch.h"
 #include "pbs/programmable_bootstrap.h"
 #include "pbs/programmable_bootstrap_multibit.h"
 #include <device.h>
-#include <keyswitch.h>
 #include <utils.h>

 void programmable_bootstrap_classical_setup(

@@ -60,6 +60,19 @@ unsafe extern "C" {
         log_modulus: u32,
     );
 }
+unsafe extern "C" {
+    pub fn cuda_glwe_sample_extract_128(
+        stream: *mut ffi::c_void,
+        gpu_index: u32,
+        lwe_array_out: *mut ffi::c_void,
+        glwe_array_in: *const ffi::c_void,
+        nth_array: *const u32,
+        num_nths: u32,
+        lwe_per_glwe: u32,
+        glwe_dimension: u32,
+        polynomial_size: u32,
+    );
+}
 pub const PBS_TYPE_MULTI_BIT: PBS_TYPE = 0;
 pub const PBS_TYPE_CLASSICAL: PBS_TYPE = 1;
 pub type PBS_TYPE = ffi::c_uint;

@@ -1429,6 +1442,34 @@ unsafe extern "C" {
         num_lwes: u32,
     );
 }
+unsafe extern "C" {
+    pub fn scratch_packing_keyswitch_lwe_list_to_glwe_128(
+        stream: *mut ffi::c_void,
+        gpu_index: u32,
+        fp_ks_buffer: *mut *mut i8,
+        lwe_dimension: u32,
+        glwe_dimension: u32,
+        polynomial_size: u32,
+        num_lwes: u32,
+        allocate_gpu_memory: bool,
+    );
+}
+unsafe extern "C" {
+    pub fn cuda_packing_keyswitch_lwe_list_to_glwe_128(
+        stream: *mut ffi::c_void,
+        gpu_index: u32,
+        glwe_array_out: *mut ffi::c_void,
+        lwe_array_in: *const ffi::c_void,
+        fp_ksk_array: *const ffi::c_void,
+        fp_ks_buffer: *mut i8,
+        input_lwe_dimension: u32,
+        output_glwe_dimension: u32,
+        output_polynomial_size: u32,
+        base_log: u32,
+        level_count: u32,
+        num_lwes: u32,
+    );
+}
 unsafe extern "C" {
     pub fn cleanup_packing_keyswitch_lwe_list_to_glwe(
         stream: *mut ffi::c_void,

@@ -3,7 +3,7 @@
 #include "cuda/include/integer/compression/compression.h"
 #include "cuda/include/integer/integer.h"
 #include "cuda/include/zk/zk.h"
-#include "cuda/include/keyswitch.h"
+#include "cuda/include/keyswitch/keyswitch.h"
 #include "cuda/include/keyswitch/ks_enums.h"
 #include "cuda/include/linear_algebra.h"
 #include "cuda/include/fft/fft128.h"

@@ -497,7 +497,7 @@ mod cuda {
     use tfhe::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
     use tfhe::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
     use tfhe::core_crypto::gpu::{
-        cuda_keyswitch_lwe_ciphertext, cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext,
+        cuda_keyswitch_lwe_ciphertext, cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64,
         get_number_of_gpus, CudaStreams,
     };
     use tfhe::core_crypto::prelude::*;

@@ -796,7 +796,7 @@ mod cuda {
     {
         bench_group.bench_function(&bench_id, |b| {
             b.iter(|| {
-                cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext(
+                cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64(
                     gpu_keys.pksk.as_ref().unwrap(),
                     &d_input_lwe_list,
                     &mut d_output_glwe,

@@ -879,7 +879,7 @@ mod cuda {
                     ((i, input_lwe_list), output_glwe_list),
                     local_stream,
                 )| {
-                    cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext(
+                    cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64(
                         gpu_keys_vec[i].pksk.as_ref().unwrap(),
                         input_lwe_list,
                         output_glwe_list,

@@ -1,14 +1,16 @@
 use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
 use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
 use crate::core_crypto::gpu::lwe_packing_keyswitch_key::CudaLwePackingKeyswitchKey;
-use crate::core_crypto::gpu::{packing_keyswitch_list_async, CudaStreams};
+use crate::core_crypto::gpu::{
+    packing_keyswitch_list_128_async, packing_keyswitch_list_64_async, CudaStreams,
+};
 use crate::core_crypto::prelude::{CastInto, UnsignedTorus};

 /// # Safety
 ///
 /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
 ///   be dropped until stream is synchronised
-pub unsafe fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async<Scalar>(
+pub unsafe fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64_async<Scalar>(
     lwe_pksk: &CudaLwePackingKeyswitchKey<Scalar>,
     input_lwe_ciphertext_list: &CudaLweCiphertextList<Scalar>,
     output_glwe_ciphertext: &mut CudaGlweCiphertextList<Scalar>,

@@ -42,7 +44,7 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async<Scal
         lwe_pksk.d_vec.gpu_index(0).get(),
     );

-    packing_keyswitch_list_async(
+    packing_keyswitch_list_64_async(
         streams,
         &mut output_glwe_ciphertext.0.d_vec,
         &input_lwe_ciphertext_list.0.d_vec,

@@ -56,7 +58,7 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async<Scal
     );
 }

-pub fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext<Scalar>(
+pub fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64<Scalar>(
     lwe_pksk: &CudaLwePackingKeyswitchKey<Scalar>,
     input_lwe_ciphertext_list: &CudaLweCiphertextList<Scalar>,
     output_glwe_ciphertext: &mut CudaGlweCiphertextList<Scalar>,

@@ -66,7 +68,78 @@ pub fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext<Scalar>(
     Scalar: UnsignedTorus + CastInto<usize>,
 {
     unsafe {
-        cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async(
+        cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64_async(
             lwe_pksk,
             input_lwe_ciphertext_list,
             output_glwe_ciphertext,
+            streams,
+        );
+    }
+    streams.synchronize();
+}
+
+/// # Safety
+///
+/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
+///   be dropped until stream is synchronised
+pub unsafe fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_128_async<Scalar>(
+    lwe_pksk: &CudaLwePackingKeyswitchKey<Scalar>,
+    input_lwe_ciphertext_list: &CudaLweCiphertextList<Scalar>,
+    output_glwe_ciphertext: &mut CudaGlweCiphertextList<Scalar>,
+    streams: &CudaStreams,
+) where
+    // CastInto required for PBS modulus switch which returns a usize
+    Scalar: UnsignedTorus + CastInto<usize>,
+{
+    let input_lwe_dimension = input_lwe_ciphertext_list.lwe_dimension();
+    let output_glwe_dimension = output_glwe_ciphertext.glwe_dimension();
+    let output_polynomial_size = output_glwe_ciphertext.polynomial_size();
+    assert_eq!(
+        streams.gpu_indexes[0],
+        input_lwe_ciphertext_list.0.d_vec.gpu_index(0),
+        "GPU error: first stream is on GPU {}, first input pointer is on GPU {}",
+        streams.gpu_indexes[0].get(),
+        input_lwe_ciphertext_list.0.d_vec.gpu_index(0).get(),
+    );
+    assert_eq!(
+        streams.gpu_indexes[0],
+        output_glwe_ciphertext.0.d_vec.gpu_index(0),
+        "GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
+        streams.gpu_indexes[0].get(),
+        output_glwe_ciphertext.0.d_vec.gpu_index(0).get(),
+    );
+    assert_eq!(
+        streams.gpu_indexes[0],
+        lwe_pksk.d_vec.gpu_index(0),
+        "GPU error: first stream is on GPU {}, first pksk pointer is on GPU {}",
+        streams.gpu_indexes[0].get(),
+        lwe_pksk.d_vec.gpu_index(0).get(),
+    );
+
+    packing_keyswitch_list_128_async(
+        streams,
+        &mut output_glwe_ciphertext.0.d_vec,
+        &input_lwe_ciphertext_list.0.d_vec,
+        input_lwe_dimension,
+        output_glwe_dimension,
+        output_polynomial_size,
+        &lwe_pksk.d_vec,
+        lwe_pksk.decomposition_base_log(),
+        lwe_pksk.decomposition_level_count(),
+        input_lwe_ciphertext_list.lwe_ciphertext_count(),
+    );
+}
+
+pub fn cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_128<Scalar>(
+    lwe_pksk: &CudaLwePackingKeyswitchKey<Scalar>,
+    input_lwe_ciphertext_list: &CudaLweCiphertextList<Scalar>,
+    output_glwe_ciphertext: &mut CudaGlweCiphertextList<Scalar>,
+    streams: &CudaStreams,
+) where
+    // CastInto required for PBS modulus switch which returns a usize
+    Scalar: UnsignedTorus + CastInto<usize>,
+{
+    unsafe {
+        cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_128_async(
             lwe_pksk,
             input_lwe_ciphertext_list,
             output_glwe_ciphertext,

@@ -1,5 +1,5 @@
 use super::*;
-use crate::core_crypto::gpu::algorithms::lwe_packing_keyswitch::cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async;
+use crate::core_crypto::gpu::algorithms::lwe_packing_keyswitch::cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64_async;
 use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
 use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
 use crate::core_crypto::gpu::vec::GpuIndex;

@@ -104,7 +104,7 @@ where
     );

     unsafe {
-        cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async(
+        cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64_async(
             &pksk,
             &d_input_lwe,
             &mut d_output_glwe,

@@ -197,7 +197,7 @@ where
     );

     unsafe {
-        cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_async(
+        cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64_async(
             &pksk,
             &d_input_lwe_list,
             &mut d_output_glwe,

@@ -367,7 +367,7 @@ pub unsafe fn convert_lwe_keyswitch_key_async<T: UnsignedInteger>(
 /// [CudaStreams::synchronize] __must__ be called as soon as synchronization is
 /// required
 #[allow(clippy::too_many_arguments)]
-pub unsafe fn packing_keyswitch_list_async<T: UnsignedInteger>(
+pub unsafe fn packing_keyswitch_list_64_async<T: UnsignedInteger>(
     streams: &CudaStreams,
     glwe_array_out: &mut CudaVec<T>,
     lwe_array_in: &CudaVec<T>,

@@ -412,6 +412,58 @@ pub unsafe fn packing_keyswitch_list_async<T: UnsignedInteger>(
     );
 }

+/// Applies packing keyswitch on a vector of 128-bit LWE ciphertexts
+///
+/// # Safety
+///
+/// [CudaStreams::synchronize] __must__ be called as soon as synchronization is
+/// required
+#[allow(clippy::too_many_arguments)]
+pub unsafe fn packing_keyswitch_list_128_async<T: UnsignedInteger>(
+    streams: &CudaStreams,
+    glwe_array_out: &mut CudaVec<T>,
+    lwe_array_in: &CudaVec<T>,
+    input_lwe_dimension: LweDimension,
+    output_glwe_dimension: GlweDimension,
+    output_polynomial_size: PolynomialSize,
+    packing_keyswitch_key: &CudaVec<T>,
+    base_log: DecompositionBaseLog,
+    l_gadget: DecompositionLevelCount,
+    num_lwes: LweCiphertextCount,
+) {
+    let mut fp_ks_buffer: *mut i8 = std::ptr::null_mut();
+    scratch_packing_keyswitch_lwe_list_to_glwe_128(
+        streams.ptr[0],
+        streams.gpu_indexes[0].get(),
+        std::ptr::addr_of_mut!(fp_ks_buffer),
+        input_lwe_dimension.0 as u32,
+        output_glwe_dimension.0 as u32,
+        output_polynomial_size.0 as u32,
+        num_lwes.0 as u32,
+        true,
+    );
+    cuda_packing_keyswitch_lwe_list_to_glwe_128(
+        streams.ptr[0],
+        streams.gpu_indexes[0].get(),
+        glwe_array_out.as_mut_c_ptr(0),
+        lwe_array_in.as_c_ptr(0),
+        packing_keyswitch_key.as_c_ptr(0),
+        fp_ks_buffer,
+        input_lwe_dimension.0 as u32,
+        output_glwe_dimension.0 as u32,
+        output_polynomial_size.0 as u32,
+        base_log.0 as u32,
+        l_gadget.0 as u32,
+        num_lwes.0 as u32,
+    );
+    cleanup_packing_keyswitch_lwe_list_to_glwe(
+        streams.ptr[0],
+        streams.gpu_indexes[0].get(),
+        std::ptr::addr_of_mut!(fp_ks_buffer),
+        true,
+    );
+}
+
 /// Convert programmable bootstrap key
 ///
 /// # Safety

@@ -440,7 +492,7 @@ pub unsafe fn convert_lwe_programmable_bootstrap_key_async<T: UnsignedInteger>(
                 l_gadget.0 as u32,
                 polynomial_size.0 as u32,
             );
-        } else {
+        } else if size_of::<T>() == 8 {
             cuda_convert_lwe_programmable_bootstrap_key_64(
                 stream_ptr,
                 streams.gpu_indexes[i].get(),

@@ -451,6 +503,8 @@ pub unsafe fn convert_lwe_programmable_bootstrap_key_async<T: UnsignedInteger>(
                 l_gadget.0 as u32,
                 polynomial_size.0 as u32,
             );
+        } else {
+            panic!("Unsupported torus size for bsk conversion")
         }
     }
 }

@@ -504,17 +558,33 @@ pub unsafe fn extract_lwe_samples_from_glwe_ciphertext_list_async<T: UnsignedInt
     glwe_dimension: GlweDimension,
     polynomial_size: PolynomialSize,
 ) {
-    cuda_glwe_sample_extract_64(
-        streams.ptr[0],
-        streams.gpu_indexes[0].get(),
-        lwe_array_out.as_mut_c_ptr(0),
-        glwe_array_in.as_c_ptr(0),
-        nth_array.as_c_ptr(0).cast::<u32>(),
-        num_nths,
-        lwe_per_glwe,
-        glwe_dimension.0 as u32,
-        polynomial_size.0 as u32,
-    );
+    if size_of::<T>() == 16 {
+        cuda_glwe_sample_extract_128(
+            streams.ptr[0],
+            streams.gpu_indexes[0].get(),
+            lwe_array_out.as_mut_c_ptr(0),
+            glwe_array_in.as_c_ptr(0),
+            nth_array.as_c_ptr(0).cast::<u32>(),
+            num_nths,
+            lwe_per_glwe,
+            glwe_dimension.0 as u32,
+            polynomial_size.0 as u32,
+        );
+    } else if size_of::<T>() == 8 {
+        cuda_glwe_sample_extract_64(
+            streams.ptr[0],
+            streams.gpu_indexes[0].get(),
+            lwe_array_out.as_mut_c_ptr(0),
+            glwe_array_in.as_c_ptr(0),
+            nth_array.as_c_ptr(0).cast::<u32>(),
+            num_nths,
+            lwe_per_glwe,
+            glwe_dimension.0 as u32,
+            polynomial_size.0 as u32,
+        );
+    } else {
+        panic!("Unsupported torus size for glwe sample extraction")
+    }
 }

 /// # Safety