chore(gpu): bench KS latency batches

Andrei Stoian
2025-11-06 18:39:20 +01:00
committed by Andrei Stoian
parent d6a0a366b9
commit e2063c8ef4
24 changed files with 1239 additions and 269 deletions

View File

@@ -776,7 +776,7 @@ build_debug_integer_short_run_gpu: install_cargo_nextest
RUSTFLAGS="$(RUSTFLAGS)" cargo test --profile debug_lto_off \
--features=integer,gpu-debug-fake-multi-gpu -p tfhe -- integer::gpu::server_key::radix::tests_long_run::test_random_op_sequence::test_gpu_short_random --list
@echo "To debug fake-multi-gpu short run tests run:"
@echo "TFHE_RS_TEST_LONG_TESTS_MINIMAL=TRUE <executable> integer::gpu::server_key::radix::tests_long_run::test_random_op_sequence::test_gpu_short_random_op_sequence_param_gpu_multi_bit_group_4_message_2_carry_2_ks_pbs_tuniform_2m128 --nocapture"
@echo "TFHE_RS_LONGRUN_TESTS_SEED=<SEED_FROM_CI> TFHE_RS_TEST_LONG_TESTS_MINIMAL=TRUE <executable> integer::gpu::server_key::radix::tests_long_run::test_random_op_sequence::test_gpu_short_random_op_sequence_param_gpu_multi_bit_group_4_message_2_carry_2_ks_pbs_tuniform_2m128 --nocapture"
@echo "Where <executable> = the one printed in the () in the 'Running unittests src/lib.rs ()' line above"
.PHONY: test_integer_compression

View File

@@ -37,6 +37,7 @@ template <typename Torus> struct int_aes_lut_buffers {
auto active_streams_and_lut = streams.active_gpu_subset(
SBOX_MAX_AND_GATES * num_aes_inputs * sbox_parallelism);
this->and_lut->broadcast_lut(active_streams_and_lut);
this->and_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
this->flush_lut = new int_radix_lut<Torus>(
streams, params, 1, AES_STATE_BITS * num_aes_inputs,
@@ -52,6 +53,7 @@ template <typename Torus> struct int_aes_lut_buffers {
auto active_streams_flush_lut =
streams.active_gpu_subset(AES_STATE_BITS * num_aes_inputs);
this->flush_lut->broadcast_lut(active_streams_flush_lut);
this->flush_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
this->carry_lut = new int_radix_lut<Torus>(
streams, params, 1, num_aes_inputs, allocate_gpu_memory, size_tracker);
@@ -65,6 +67,7 @@ template <typename Torus> struct int_aes_lut_buffers {
params.carry_modulus, carry_lambda, allocate_gpu_memory);
auto active_streams_carry_lut = streams.active_gpu_subset(num_aes_inputs);
this->carry_lut->broadcast_lut(active_streams_carry_lut);
this->carry_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
}
void release(CudaStreams streams) {

View File

@@ -13,6 +13,8 @@
#include <stdio.h>
#include "crypto/keyswitch.cuh"
class NoiseLevel {
public:
// Constants equivalent to the Rust code
@@ -336,6 +338,9 @@ struct int_radix_lut_custom_input_output {
std::vector<InputTorus *> lwe_after_ks_vec;
std::vector<OutputTorus *> lwe_after_pbs_vec;
std::vector<InputTorus *> lwe_trivial_indexes_vec;
std::vector<ks_mem<InputTorus> *>
ks_tmp_buf_vec; // buffers on each GPU to store keyswitch temporary data
std::vector<InputTorus *> lwe_aligned_vec;
bool gpu_memory_allocated;
@@ -443,6 +448,30 @@ struct int_radix_lut_custom_input_output {
allocate_gpu_memory);
}
void setup_gemm_batch_ks_temp_buffers(uint64_t &size_tracker) {
auto inputs_on_gpu =
std::min((int)num_input_blocks,
std::max(THRESHOLD_MULTI_GPU,
get_num_inputs_on_gpu(num_input_blocks, 0,
active_streams.count())));
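// Illustrative sizing (values assumed, and assuming THRESHOLD_MULTI_GPU is
// below the GEMM threshold of 128 returned by get_threshold_ks_gemm()): a LUT
// spanning 256 blocks over 2 active GPUs gives 128 inputs per GPU, so one
// ks_mem buffer is allocated per active GPU below; a 64-block LUT stays under
// the threshold and keeps the classical per-sample keyswitch with no extra
// buffer.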
if (inputs_on_gpu >= get_threshold_ks_gemm()) {
for (auto i = 0; i < active_streams.count(); ++i) {
ks_mem<InputTorus> *ks_buffer;
uint64_t sub_size_tracker = scratch_cuda_keyswitch<InputTorus>(
active_streams.stream(i), active_streams.gpu_index(i), &ks_buffer,
input_big_lwe_dimension, params.small_lwe_dimension, num_blocks,
gpu_memory_allocated);
if (i == 0) {
size_tracker += sub_size_tracker;
}
ks_tmp_buf_vec.push_back(ks_buffer);
}
}
}
void setup_mem_reuse(uint32_t num_radix_blocks,
int_radix_lut_custom_input_output *base_lut_object) {
// The base LUT object must have at least as much memory as the current one
@@ -461,6 +490,8 @@ struct int_radix_lut_custom_input_output {
lwe_after_pbs_vec = base_lut_object->lwe_after_pbs_vec;
lwe_trivial_indexes_vec = base_lut_object->lwe_trivial_indexes_vec;
ks_tmp_buf_vec = base_lut_object->ks_tmp_buf_vec;
mem_reuse = true;
}
@@ -865,6 +896,13 @@ struct int_radix_lut_custom_input_output {
}
lwe_aligned_vec.clear();
}
for (auto i = 0; i < ks_tmp_buf_vec.size(); i++) {
cleanup_cuda_keyswitch(active_streams.stream(i),
active_streams.gpu_index(i), ks_tmp_buf_vec[i],
gpu_memory_allocated);
}
ks_tmp_buf_vec.clear();
}
free(h_lut_indexes);
free(degrees);

View File

@@ -15,6 +15,10 @@ template <typename Torus> struct int_rerand_mem {
bool gpu_memory_allocated;
std::vector<ks_mem<Torus> *>
ks_tmp_buf_vec; // not allocated: ReRand does not use GEMM KS for now;
// kept empty and passed to the KS function to indicate GEMM KS is disabled
expand_job<Torus> *d_expand_jobs;
expand_job<Torus> *h_expand_jobs;
@@ -72,6 +76,13 @@ template <typename Torus> struct int_rerand_mem {
cuda_drop_with_size_tracking_async(d_expand_jobs, streams.stream(0),
streams.gpu_index(0),
gpu_memory_allocated);
for (auto i = 0; i < ks_tmp_buf_vec.size(); i++) {
cleanup_cuda_keyswitch(streams.stream(i), streams.gpu_index(i),
ks_tmp_buf_vec[i], gpu_memory_allocated);
}
ks_tmp_buf_vec.clear();
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
free(h_expand_jobs);
}

View File

@@ -12,6 +12,13 @@ void cuda_keyswitch_lwe_ciphertext_vector_32(
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
uint32_t num_samples);
void cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lwe_array_in,
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, const void *ks_tmp_buffer, bool uses_trivial_indexes);
void cuda_keyswitch_lwe_ciphertext_vector_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lwe_array_in,
@@ -24,6 +31,17 @@ uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t num_lwes, bool allocate_gpu_memory);
uint64_t scratch_cuda_keyswitch_gemm_64(void *stream, uint32_t gpu_index,
void **ks_tmp_memory,
uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out,
uint32_t num_lwes,
bool allocate_gpu_memory);
void cleanup_cuda_keyswitch_gemm_64(void *stream, uint32_t gpu_index,
void **ks_tmp_memory,
bool allocate_gpu_memory);
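// Typical call sequence for the GEMM keyswitch entry points declared above
// (sketch only; the stream, device pointers and parameters are assumed to be
// set up by the caller):
//   void *ks_tmp_buffer = nullptr;
//   uint64_t size = scratch_cuda_keyswitch_gemm_64(
//       stream, gpu_index, &ks_tmp_buffer, lwe_dimension_in,
//       lwe_dimension_out, num_samples, true);
//   cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
//       stream, gpu_index, lwe_array_out, lwe_output_indexes, lwe_array_in,
//       lwe_input_indexes, ksk, lwe_dimension_in, lwe_dimension_out, base_log,
//       level_count, num_samples, ks_tmp_buffer, true);
//   cleanup_cuda_keyswitch_gemm_64(stream, gpu_index, &ks_tmp_buffer, true);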
void cuda_packing_keyswitch_lwe_list_to_glwe_64(
void *stream, uint32_t gpu_index, void *glwe_array_out,
void const *lwe_array_in, void const *fp_ksk_array, int8_t *fp_ks_buffer,

View File

@@ -9,14 +9,16 @@ void cuda_keyswitch_lwe_ciphertext_vector_32(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void *lwe_output_indexes, void *lwe_array_in, void *lwe_input_indexes,
void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
uint32_t base_log, uint32_t level_count, uint32_t num_samples) {
host_keyswitch_lwe_ciphertext_vector<uint32_t>(
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
void *ksk_tmp_buffer, bool uses_trivial_indices) {
host_gemm_keyswitch_lwe_ciphertext_vector<uint32_t>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint32_t *>(lwe_array_out),
static_cast<uint32_t *>(lwe_output_indexes),
static_cast<uint32_t *>(lwe_array_in),
static_cast<uint32_t *>(lwe_input_indexes), static_cast<uint32_t *>(ksk),
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples);
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples,
static_cast<uint32_t *>(ksk_tmp_buffer), uses_trivial_indices);
}
/* Perform keyswitch on a batch of 64 bits input LWE ciphertexts.
@@ -35,6 +37,26 @@ void cuda_keyswitch_lwe_ciphertext_vector_32(
* This function calls a wrapper to a device kernel that performs the keyswitch
* - num_samples blocks of threads are launched
*/
void cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lwe_array_in,
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, const void *ks_tmp_buffer,
bool uses_trivial_indices) {
host_gemm_keyswitch_lwe_ciphertext_vector<uint64_t>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint64_t *>(lwe_array_out),
static_cast<const uint64_t *>(lwe_output_indexes),
static_cast<const uint64_t *>(lwe_array_in),
static_cast<const uint64_t *>(lwe_input_indexes),
static_cast<const uint64_t *>(ksk), lwe_dimension_in, lwe_dimension_out,
base_log, level_count, num_samples,
static_cast<const ks_mem<uint64_t> *>(ks_tmp_buffer)->d_buffer,
uses_trivial_indices);
}
void cuda_keyswitch_lwe_ciphertext_vector_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lwe_array_in,
@@ -44,10 +66,10 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
host_keyswitch_lwe_ciphertext_vector<uint64_t>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint64_t *>(lwe_array_out),
static_cast<const uint64_t *>(lwe_output_indexes),
static_cast<const uint64_t *>(lwe_array_in),
static_cast<const uint64_t *>(lwe_input_indexes),
static_cast<const uint64_t *>(ksk), lwe_dimension_in, lwe_dimension_out,
static_cast<uint64_t const *>(lwe_output_indexes),
static_cast<uint64_t const *>(lwe_array_in),
static_cast<uint64_t const *>(lwe_input_indexes),
static_cast<uint64_t const *>(ksk), lwe_dimension_in, lwe_dimension_out,
base_log, level_count, num_samples);
}
@@ -60,6 +82,27 @@ uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
glwe_dimension, polynomial_size, num_lwes, allocate_gpu_memory);
}
uint64_t scratch_cuda_keyswitch_gemm_64(void *stream, uint32_t gpu_index,
void **ks_tmp_buffer,
uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out,
uint32_t num_lwes,
bool allocate_gpu_memory) {
return scratch_cuda_keyswitch<uint64_t>(
static_cast<cudaStream_t>(stream), gpu_index,
(ks_mem<uint64_t> **)ks_tmp_buffer, lwe_dimension_in, lwe_dimension_out,
num_lwes, allocate_gpu_memory);
}
void cleanup_cuda_keyswitch_gemm_64(void *stream, uint32_t gpu_index,
void **ks_tmp_buffer,
bool allocate_gpu_memory) {
cleanup_cuda_keyswitch<uint64_t>(static_cast<cudaStream_t>(stream), gpu_index,
(ks_mem<uint64_t> *)*ks_tmp_buffer,
allocate_gpu_memory);
*ks_tmp_buffer = nullptr;
}
/* Perform functional packing keyswitch on a batch of 64 bits input LWE
* ciphertexts.
*/

View File

@@ -1,6 +1,8 @@
#ifndef CNCRT_KS_CUH
#define CNCRT_KS_CUH
#include <algorithm>
#include "device.h"
#include "gadget.cuh"
#include "helper_multi_gpu.h"
@@ -10,8 +12,30 @@
#include "utils/helper.cuh"
#include "utils/kernel_dimensions.cuh"
#include <thread>
#include <unistd.h>
#include <vector>
const int BLOCK_SIZE_DECOMP = 8;
const int BLOCK_SIZE_GEMM_KS = 36;
const int THREADS_GEMM_KS = 6;
inline uint64_t get_threshold_ks_gemm() { return 128; }
template <typename Torus> struct ks_mem {
Torus *d_buffer;
uint64_t num_lwes;
uint32_t lwe_dimension;
};
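// The helper below computes the size of the per-GPU scratch buffer used by
// the GEMM keyswitch. Illustration (values assumed): 1000 LWEs with a
// 2048-element input mask and a 64-bit Torus need
//   2 * 1000 * 2048 * 8 bytes = 32,768,000 bytes (~31 MiB),
// one half for the decomposed mask values and one half for the decomposer
// states kept between levels.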
template <typename Torus>
uint64_t scratch_cuda_keyswitch_size(uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out,
uint32_t num_lwes) {
GPU_ASSERT(lwe_dimension_in >= lwe_dimension_out,
"Trying to allocate KS temp buffer for invalid LWE dimensions");
return (uint64_t)num_lwes * lwe_dimension_in * sizeof(Torus) * 2;
}
template <typename Torus>
__device__ Torus *get_ith_block(Torus *ksk, int i, int level,
uint32_t lwe_dimension_out,
@@ -22,6 +46,169 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
return ptr;
}
// Initialize decomposition by performing rounding
// and decomposing one level of an array of Torus LWEs. Only
// decomposes the mask elements of the incoming LWEs.
template <typename Torus>
__global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
uint32_t lwe_dimension,
uint32_t num_lwe, uint32_t base_log,
uint32_t level_count) {
// index of this LWE ct in the buffer
auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
// index of the LWE sample in the LWE ct
auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
return;
// Input LWE array is [mask_0, .., mask_lwe_dim, message] and
// we only decompose the mask. Thus the stride for reading
// is lwe_dimension + 1, while for writing it is lwe_dimension
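// The output buffer is laid out as num_lwe * lwe_dimension decomposed mask
// values followed by num_lwe * lwe_dimension decomposer states, which matches
// the factor of two in scratch_cuda_keyswitch_size.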
auto read_val_idx = lwe_idx * (lwe_dimension + 1) + lwe_sample_idx;
auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
auto write_state_idx =
num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;
Torus a_i = lwe_in[read_val_idx];
Torus state = init_decomposer_state(a_i, base_log, level_count);
Torus mod_b_mask = (1ll << base_log) - 1ll;
lwe_out[write_val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
__syncthreads();
lwe_out[write_state_idx] = state;
}
// Decompose an array of LWEs with indirection through lwe_input_indices
// The LWE array can contain total_lwe LWEs, where total_lwe can differ from
// num_lwe. Every index must be smaller than total_lwe. num_lwe is the number
// of LWEs to decompose. The output buffer must have space for num_lwe LWEs,
// which are gathered according to the input indices.
template <typename Torus>
__global__ void decompose_vectorize_init_with_indices(
Torus const *lwe_in, const Torus *__restrict__ lwe_input_indices,
Torus *lwe_out, uint32_t lwe_dimension, uint32_t num_lwe, uint32_t base_log,
uint32_t level_count) {
// index of this LWE ct in the buffer
auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
// index of the LWE sample in the LWE ct
auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
return;
// Input LWE array is [mask_0, .., mask_lwe_dim, message] and
// we only decompose the mask. Thus the stride for reading
// is lwe_dimension + 1, while for writing it is lwe_dimension
auto read_val_idx =
lwe_input_indices[lwe_idx] * (lwe_dimension + 1) + lwe_sample_idx;
auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
auto write_state_idx =
num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;
Torus a_i = lwe_in[read_val_idx];
Torus state = init_decomposer_state(a_i, base_log, level_count);
Torus mod_b_mask = (1ll << base_log) - 1ll;
lwe_out[write_val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
__syncthreads();
lwe_out[write_state_idx] = state;
}
// Continue decomposition of an array of Torus elements in place. Supposes
// that the array contains already decomposed elements and
// computes the new decomposed level in place.
template <typename Torus>
__global__ void
decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
uint32_t num_lwe, uint32_t base_log,
uint32_t level_count) {
// index of this LWE ct in the buffer
auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
// index of the LWE sample in the LWE ct
auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
return;
auto val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
auto state_idx = num_lwe * lwe_dimension + val_idx;
Torus state = buffer_in[state_idx];
__syncthreads();
Torus mod_b_mask = (1ll << base_log) - 1ll;
buffer_in[val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
__syncthreads();
buffer_in[state_idx] = state;
}
template <typename Torus>
__global__ void keyswitch_gemm_copy_negated_message_with_indices(
const Torus *__restrict__ lwe_in,
const Torus *__restrict__ lwe_input_indices, Torus *__restrict__ lwe_out,
const Torus *__restrict__ lwe_output_indices,
uint32_t lwe_dimension_in, uint32_t num_lwes, uint32_t lwe_dimension_out) {
uint32_t lwe_id = blockIdx.x * blockDim.x + threadIdx.x;
if (lwe_id >= num_lwes)
return;
uint32_t lwe_in_idx = lwe_input_indices[lwe_id];
uint32_t lwe_out_idx = lwe_output_indices[lwe_id];
lwe_out[lwe_out_idx * (lwe_dimension_out + 1) + lwe_dimension_out] =
-lwe_in[lwe_in_idx * (lwe_dimension_in + 1) + lwe_dimension_in];
}
// Finishes the KS computation by negating all elements in the array
// using output indices. The array contains -b + SUM(a_i x LWE_i)
// and this final step computes b - SUM(a_i x LWE_i)
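// Written out, with KSK[i][l] denoting the (lwe_dimension_out + 1)-element
// row of the keyswitch key for input coefficient i and decomposition level l,
// the GEMM pipeline computes (up to decomposition rounding)
//   ct_out = (0, ..., 0, b) - sum_i sum_l decomp_l(a_i) * KSK[i][l]
// The body b is injected (negated) by the copy kernel above, the sums are
// accumulated by the tgemm calls, and the kernel below flips the final sign.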
template <typename Torus>
__global__ void keyswitch_negate_with_output_indices(
Torus *buffer_in, const Torus *__restrict__ lwe_output_indices,
uint32_t lwe_size, uint32_t num_lwe) {
// index of this LWE ct in the buffer
auto lwe_sample_idx = blockIdx.x * blockDim.x + threadIdx.x;
// index of the LWE sample in the LWE ct
auto lwe_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_size)
return;
auto val_idx = lwe_output_indices[lwe_idx] * lwe_size + lwe_sample_idx;
Torus val = buffer_in[val_idx];
buffer_in[val_idx] = -val;
}
template <typename Torus>
__global__ void keyswitch_zero_output_with_output_indices(
Torus *buffer_in, const Torus *__restrict__ lwe_output_indices,
uint32_t lwe_size, uint32_t num_lwe) {
// index of this LWE ct in the buffer
auto lwe_sample_idx = blockIdx.x * blockDim.x + threadIdx.x;
// index of the LWE sample in the LWE ct
auto lwe_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_size)
return;
auto val_idx = lwe_output_indices[lwe_idx] * lwe_size + lwe_sample_idx;
buffer_in[val_idx] = 0;
}
/*
* keyswitch kernel
* Each thread handles a piece of the following equation:
@@ -142,14 +329,141 @@ __host__ void host_keyswitch_lwe_ciphertext_vector(
}
template <typename Torus>
void execute_keyswitch_async(CudaStreams streams,
const LweArrayVariant<Torus> &lwe_array_out,
__host__ void host_gemm_keyswitch_lwe_ciphertext_vector(
cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
Torus const *lwe_output_indices, Torus const *lwe_array_in,
Torus const *lwe_input_indices, Torus const *ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, Torus *fp_tmp_buffer, bool uses_trivial_indices) {
cuda_set_device(gpu_index);
check_cuda_error(cudaGetLastError());
auto d_mem_0 = fp_tmp_buffer; // keeps decomposed value
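// Pipeline: (1) zero the outputs, (2) copy the negated input bodies into the
// output bodies, (3) decompose the input masks level by level, (4) accumulate
// each level into the outputs with a GEMM against the keyswitch key, and
// (5) negate the accumulated outputs.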
// Zero the output ciphertexts: they are used to accumulate the GEMM results
if (uses_trivial_indices) {
cuda_memset_async(lwe_array_out, 0,
num_samples * (lwe_dimension_out + 1) * sizeof(Torus),
stream, gpu_index);
} else {
// zero only the outputs referenced by the output indices
dim3 grid_zero(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_DECOMP),
CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
dim3 threads_zero(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
keyswitch_zero_output_with_output_indices<Torus>
<<<grid_zero, threads_zero, 0, stream>>>(
lwe_array_out, lwe_output_indices, lwe_dimension_out + 1,
num_samples);
}
check_cuda_error(cudaGetLastError());
dim3 grid_copy(CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
dim3 threads_copy(BLOCK_SIZE_DECOMP);
// lwe_array_out is num_samples x (lwe_dimension_out + 1). Copy the bodies
// lwe_array_in[:, lwe_dimension_in] to lwe_array_out[:, lwe_dimension_out]
// and negate them
keyswitch_gemm_copy_negated_message_with_indices<Torus>
<<<grid_copy, threads_copy, 0, stream>>>(
lwe_array_in, lwe_input_indices, lwe_array_out, lwe_output_indices,
lwe_dimension_in, num_samples, lwe_dimension_out);
check_cuda_error(cudaGetLastError());
// Decompose the LWE masks.
// The body is not decomposed: each input LWE has lwe_dimension_in + 1
// elements and the decomposition kernels only cover the first
// lwe_dimension_in mask elements.
dim3 grid_decomp(CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP),
CEIL_DIV(lwe_dimension_in, BLOCK_SIZE_DECOMP));
dim3 threads_decomp(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
// Shared memory requirement is 4096, 8192, and 16384 bytes respectively for
// 32, 64, and 128-bit Torus elements
// Sanity check: the shared memory size is a constant defined by the algorithm
GPU_ASSERT(shared_mem_size <= 1024 * sizeof(Torus),
"GEMM kernel error: shared memory required might be too large");
auto stride_KSK_buffer = (lwe_dimension_out + 1) * level_count;
// GEMM that accumulates the decomposed masks times the keyswitch key
dim3 grid_gemm(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_GEMM_KS),
CEIL_DIV(num_samples, BLOCK_SIZE_GEMM_KS));
dim3 threads_gemm(BLOCK_SIZE_GEMM_KS * THREADS_GEMM_KS);
// decompose first level (skips the body in the input buffer)
decompose_vectorize_init_with_indices<Torus>
<<<grid_decomp, threads_decomp, 0, stream>>>(
lwe_array_in, lwe_input_indices, fp_tmp_buffer, lwe_dimension_in,
num_samples, base_log, level_count);
check_cuda_error(cudaGetLastError());
if (uses_trivial_indices) {
tgemm<Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
ksk, stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1);
check_cuda_error(cudaGetLastError());
} else {
tgemm_with_indices<Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
ksk, stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1,
lwe_output_indices);
check_cuda_error(cudaGetLastError());
}
auto ksk_block_size = (lwe_dimension_out + 1);
for (int li = 1; li < level_count; ++li) {
decompose_vectorize_step_inplace<Torus>
<<<grid_decomp, threads_decomp, 0, stream>>>(
fp_tmp_buffer, lwe_dimension_in, num_samples, base_log,
level_count);
check_cuda_error(cudaGetLastError());
if (uses_trivial_indices) {
tgemm<Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
ksk + li * ksk_block_size, stride_KSK_buffer, lwe_array_out,
lwe_dimension_out + 1);
check_cuda_error(cudaGetLastError());
} else {
tgemm_with_indices<Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
ksk + li * ksk_block_size, stride_KSK_buffer, lwe_array_out,
lwe_dimension_out + 1, lwe_output_indices);
check_cuda_error(cudaGetLastError());
}
}
// Final negation pass over the accumulated outputs
dim3 grid_negate(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_DECOMP),
CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
dim3 threads_negate(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
// Negate all outputs in the LWE
keyswitch_negate_with_output_indices<Torus>
<<<grid_negate, threads_negate, 0, stream>>>(
lwe_array_out, lwe_output_indices, lwe_dimension_out + 1,
num_samples);
check_cuda_error(cudaGetLastError());
}
template <typename Torus>
void execute_keyswitch_async(
CudaStreams streams, const LweArrayVariant<Torus> &lwe_array_out,
const LweArrayVariant<Torus> &lwe_output_indexes,
const LweArrayVariant<Torus> &lwe_array_in,
const LweArrayVariant<Torus> &lwe_input_indexes,
Torus *const *ksks, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log,
uint32_t level_count, uint32_t num_samples) {
const LweArrayVariant<Torus> &lwe_input_indexes, Torus *const *ksks,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
uint32_t level_count, uint32_t num_samples, bool uses_trivial_indices,
const std::vector<ks_mem<Torus> *> &fp_tmp_buffer) {
/// If the number of radix blocks is lower than the number of GPUs, not all
/// GPUs will be active and there will be 1 input per GPU
@@ -164,12 +478,39 @@ void execute_keyswitch_async(CudaStreams streams,
Torus *current_lwe_input_indexes =
get_variant_element(lwe_input_indexes, i);
if (!fp_tmp_buffer.empty() &&
num_samples_on_gpu >= get_threshold_ks_gemm()) {
GPU_ASSERT(fp_tmp_buffer.size() >= streams.count(),
"GEMM KS buffers (%ld) were not initialized for this number of "
"streams (%d)",
fp_tmp_buffer.size(), streams.count());
GPU_ASSERT(fp_tmp_buffer[i]->num_lwes >= num_samples_on_gpu,
"KS temp buffer not big enough");
GPU_ASSERT(fp_tmp_buffer[i]->lwe_dimension ==
std::max(lwe_dimension_in, lwe_dimension_out),
"KS temp buffer was created for a different input LWE size: "
"%d vs (in:%d, out:%d)",
fp_tmp_buffer[i]->lwe_dimension, lwe_dimension_in,
lwe_dimension_out);
// Compute Keyswitch
host_gemm_keyswitch_lwe_ciphertext_vector<Torus>(
streams.stream(i), streams.gpu_index(i), current_lwe_array_out,
current_lwe_output_indexes, current_lwe_array_in,
current_lwe_input_indexes, ksks[i], lwe_dimension_in,
lwe_dimension_out, base_log, level_count, num_samples_on_gpu,
fp_tmp_buffer[i]->d_buffer, uses_trivial_indices);
} else {
// Compute Keyswitch
host_keyswitch_lwe_ciphertext_vector<Torus>(
streams.stream(i), streams.gpu_index(i), current_lwe_array_out,
current_lwe_output_indexes, current_lwe_array_in,
current_lwe_input_indexes, ksks[i], lwe_dimension_in, lwe_dimension_out,
base_log, level_count, num_samples_on_gpu);
current_lwe_input_indexes, ksks[i], lwe_dimension_in,
lwe_dimension_out, base_log, level_count, num_samples_on_gpu);
}
}
}
@@ -259,4 +600,32 @@ __global__ void accumulate_glwes(Torus *glwe_out, Torus *glwe_array_in,
}
}
template <typename Torus>
uint64_t scratch_cuda_keyswitch(cudaStream_t stream, uint32_t gpu_index,
ks_mem<Torus> **ks_tmp_memory,
uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t num_lwes,
bool allocate_gpu_memory) {
uint64_t sub_size_tracker = 0;
uint64_t buffer_size = scratch_cuda_keyswitch_size<Torus>(
lwe_dimension_in, lwe_dimension_out, num_lwes);
*ks_tmp_memory = new ks_mem<Torus>;
(*ks_tmp_memory)->d_buffer = (Torus *)cuda_malloc_with_size_tracking_async(
buffer_size, stream, gpu_index, sub_size_tracker, allocate_gpu_memory);
(*ks_tmp_memory)->lwe_dimension =
std::max(lwe_dimension_in, lwe_dimension_out);
(*ks_tmp_memory)->num_lwes = num_lwes;
return sub_size_tracker;
}
template <typename Torus>
void cleanup_cuda_keyswitch(cudaStream_t stream, uint32_t gpu_index,
ks_mem<Torus> *ks_tmp_memory,
bool allocate_gpu_memory) {
cuda_drop_with_size_tracking_async(ks_tmp_memory->d_buffer, stream, gpu_index,
allocate_gpu_memory);
delete ks_tmp_memory;
}
#endif

View File

@@ -18,73 +18,6 @@
#define CEIL_DIV(M, N) ((M) + (N)-1) / (N)
const int BLOCK_SIZE_DECOMP = 8;
// Initialize decomposition by performing rounding
// and decomposing one level of an array of Torus LWEs. Only
// decomposes the mask elements of the incoming LWEs.
template <typename Torus>
__global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
uint32_t lwe_dimension,
uint32_t num_lwe, uint32_t base_log,
uint32_t level_count) {
// index of this LWE ct in the buffer
auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
// index of the LWE sample in the LWE ct
auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
return;
// Input LWE array is [mask_0, .., mask_lwe_dim, message] and
// we only decompose the mask. Thus the stride for reading
// is lwe_dimension + 1, while for writing it is lwe_dimension
auto read_val_idx = lwe_idx * (lwe_dimension + 1) + lwe_sample_idx;
auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
auto write_state_idx =
num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;
Torus a_i = lwe_in[read_val_idx];
Torus state = init_decomposer_state(a_i, base_log, level_count);
Torus mod_b_mask = (1ll << base_log) - 1ll;
lwe_out[write_val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
__syncthreads();
lwe_out[write_state_idx] = state;
}
// Continue decomposition of an array of Torus elements in place. Supposes
// that the array contains already decomposed elements and
// computes the new decomposed level in place.
template <typename Torus>
__global__ void
decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
uint32_t num_lwe, uint32_t base_log,
uint32_t level_count) {
// index of this LWE ct in the buffer
auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
// index of the LWE sample in the LWE ct
auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
return;
auto val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
auto state_idx = num_lwe * lwe_dimension + val_idx;
Torus state = buffer_in[state_idx];
__syncthreads();
Torus mod_b_mask = (1ll << base_log) - 1ll;
buffer_in[val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
__syncthreads();
buffer_in[state_idx] = state;
}
// Finish the keyswitching operation and prepare GLWEs for accumulation.
// 1. Finish the keyswitching computation partially performed with a GEMM:
// - negate the dot product between the GLWE and KSK polynomial
@@ -209,7 +142,8 @@ __host__ void host_packing_keyswitch_lwe_list_to_glwe(
GPU_ASSERT(shared_mem_size <= 1024 * sizeof(Torus),
"GEMM kernel error: shared memory required might be too large");
tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
tgemm<Torus, BLOCK_SIZE_GEMM, THREADS_GEMM>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
stride_KSK_buffer, d_mem_1, glwe_accumulator_size);
check_cuda_error(cudaGetLastError());
@@ -222,7 +156,8 @@ __host__ void host_packing_keyswitch_lwe_list_to_glwe(
d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
check_cuda_error(cudaGetLastError());
tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
tgemm<Torus, BLOCK_SIZE_GEMM, THREADS_GEMM>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1,
glwe_accumulator_size);

View File

@@ -547,7 +547,7 @@ __host__ void integer_radix_apply_univariate_lookup_table(
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
(Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks,
big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
num_radix_blocks);
num_radix_blocks, lut->using_trivial_lwe_indexes, lut->ks_tmp_buf_vec);
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
/// dimension to a big LWE dimension
@@ -575,7 +575,8 @@ __host__ void integer_radix_apply_univariate_lookup_table(
execute_keyswitch_async<Torus>(
active_streams, lwe_after_ks_vec, lwe_trivial_indexes_vec,
lwe_array_in_vec, lwe_trivial_indexes_vec, ksks, big_lwe_dimension,
small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks);
small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks, true,
lut->ks_tmp_buf_vec);
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
/// dimension to a big LWE dimension
@@ -649,7 +650,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table(
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
(Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks,
big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
num_radix_blocks);
num_radix_blocks, lut->using_trivial_lwe_indexes, lut->ks_tmp_buf_vec);
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
/// dimension to a big LWE dimension
@@ -677,7 +678,8 @@ __host__ void integer_radix_apply_many_univariate_lookup_table(
execute_keyswitch_async<Torus>(
active_streams, lwe_after_ks_vec, lwe_trivial_indexes_vec,
lwe_array_in_vec, lwe_trivial_indexes_vec, ksks, big_lwe_dimension,
small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks);
small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks, true,
lut->ks_tmp_buf_vec);
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
/// dimension to a big LWE dimension
@@ -767,7 +769,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table(
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
(Torus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks,
big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
num_radix_blocks);
num_radix_blocks, lut->using_trivial_lwe_indexes, lut->ks_tmp_buf_vec);
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
/// dimension to a big LWE dimension
@@ -792,7 +794,8 @@ __host__ void integer_radix_apply_bivariate_lookup_table(
execute_keyswitch_async<Torus>(
active_streams, lwe_after_ks_vec, lwe_trivial_indexes_vec,
lwe_array_in_vec, lwe_trivial_indexes_vec, ksks, big_lwe_dimension,
small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks);
small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks, true,
lut->ks_tmp_buf_vec);
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
/// dimension to a big LWE dimension
@@ -1521,7 +1524,8 @@ void host_full_propagate_inplace(CudaStreams streams,
streams.get_ith(0), (Torus *)(mem_ptr->tmp_small_lwe_vector->ptr),
mem_ptr->lut->lwe_trivial_indexes, (Torus *)cur_input_block.ptr,
mem_ptr->lut->lwe_trivial_indexes, ksks, params.big_lwe_dimension,
params.small_lwe_dimension, params.ks_base_log, params.ks_level, 1);
params.small_lwe_dimension, params.ks_base_log, params.ks_level, 1,
mem_ptr->lut->using_trivial_lwe_indexes, mem_ptr->lut->ks_tmp_buf_vec);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), mem_ptr->tmp_small_lwe_vector,
@@ -2356,7 +2360,8 @@ __host__ void integer_radix_apply_noise_squashing(
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
(InputTorus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks,
lut->input_big_lwe_dimension, small_lwe_dimension, ks_base_log,
ks_level, lwe_array_out->num_radix_blocks);
ks_level, lwe_array_out->num_radix_blocks,
lut->using_trivial_lwe_indexes, lut->ks_tmp_buf_vec);
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
/// dimension to a big LWE dimension
@@ -2386,7 +2391,7 @@ __host__ void integer_radix_apply_noise_squashing(
active_streams, lwe_after_ks_vec, lwe_trivial_indexes_vec,
lwe_array_in_vec, lwe_trivial_indexes_vec, ksks,
lut->input_big_lwe_dimension, small_lwe_dimension, ks_base_log,
ks_level, lwe_array_out->num_radix_blocks);
ks_level, lwe_array_out->num_radix_blocks, true, lut->ks_tmp_buf_vec);
/// int_noise_squashing_lut doesn't support a different output or lut
/// indexing than the trivial

View File

@@ -389,13 +389,16 @@ __host__ void host_integer_partial_sum_ciphertexts_vec(
needs_processing);
auto active_streams = streams.active_gpu_subset(total_ciphertexts);
GPU_ASSERT(total_ciphertexts <= mem_ptr->luts_message_carry->num_blocks,
"Partial sum: total ciphertexts exceed the message/carry LUT block count");
if (active_streams.count() == 1) {
execute_keyswitch_async<Torus>(
streams.get_ith(0), (Torus *)small_lwe_vector->ptr, d_pbs_indexes_in,
(Torus *)current_blocks->ptr, d_pbs_indexes_in, ksks,
big_lwe_dimension, small_lwe_dimension, mem_ptr->params.ks_base_log,
mem_ptr->params.ks_level, total_messages);
mem_ptr->params.ks_level, total_messages, false,
mem_ptr->luts_message_carry->ks_tmp_buf_vec);
execute_pbs_async<Torus, Torus>(
streams.get_ith(0), (Torus *)current_blocks->ptr, d_pbs_indexes_out,
@@ -446,7 +449,8 @@ __host__ void host_integer_partial_sum_ciphertexts_vec(
streams.get_ith(0), (Torus *)small_lwe_vector->ptr, d_pbs_indexes_in,
(Torus *)radix_lwe_out->ptr, d_pbs_indexes_in, ksks,
big_lwe_dimension, small_lwe_dimension, mem_ptr->params.ks_base_log,
mem_ptr->params.ks_level, num_radix_blocks);
mem_ptr->params.ks_level, num_radix_blocks, false,
mem_ptr->luts_message_carry->ks_tmp_buf_vec);
execute_pbs_async<Torus, Torus>(
streams.get_ith(0), (Torus *)current_blocks->ptr, d_pbs_indexes_out,

View File

@@ -59,7 +59,7 @@ void rerand_inplace(
execute_keyswitch_async<Torus>(
streams.get_ith(0), ksed_zero_lwes, lwe_trivial_indexes, zero_lwes,
lwe_trivial_indexes, ksk, input_dimension, output_dimension, ks_base_log,
ks_level, num_lwes);
ks_level, num_lwes, true, mem_ptr->ks_tmp_buf_vec);
// Add ks output to ct
// Check sizes

View File

@@ -102,14 +102,14 @@ template <typename Torus> uint64_t get_shared_mem_size_tgemm() {
// This code is adapted by generalizing the 1d block-tiling
// kernel from https://github.com/siboehm/SGEMM_CUDA
// to any matrix dimension
template <typename Torus>
template <typename Torus, int BLOCK_SIZE, int THREADS>
__global__ void tgemm(uint M, uint N, uint K, const Torus *A, const Torus *B,
uint stride_B, Torus *C, uint stride_C) {
const int BM = BLOCK_SIZE_GEMM;
const int BN = BLOCK_SIZE_GEMM;
const int BK = THREADS_GEMM;
const int TM = THREADS_GEMM;
const int BM = BLOCK_SIZE;
const int BN = BLOCK_SIZE;
const int BK = THREADS;
const int TM = THREADS;
const uint cRow = blockIdx.y;
const uint cCol = blockIdx.x;
@@ -192,4 +192,103 @@ __global__ void tgemm(uint M, uint N, uint K, const Torus *A, const Torus *B,
}
}
// Multiply matrices A, B of size (M, K), (K, N) respectively
// with K as the inner dimension.
//
// A block of threads processes blocks of size (BLOCK_SIZE, BLOCK_SIZE),
// splitting them into multiple tiles: (BLOCK_SIZE, THREADS)-shaped tiles of
// values from A, and (THREADS, BLOCK_SIZE)-shaped tiles of values from B.
//
// This code is adapted by generalizing the 1d block-tiling
// kernel from https://github.com/siboehm/SGEMM_CUDA
// to any matrix dimension
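// Unlike tgemm, the output rows are addressed indirectly: the result for
// logical row r is accumulated into C[C_indices[r] * stride_C + col], which
// lets the keyswitch write through non-trivial output indexes.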
template <typename Torus, int BLOCK_SIZE, int THREADS>
__global__ void tgemm_with_indices(uint M, uint N, uint K, const Torus *A,
const Torus *B, uint stride_B, Torus *C,
uint stride_C,
const Torus *__restrict__ C_indices) {
const int BM = BLOCK_SIZE;
const int BN = BLOCK_SIZE;
const int BK = THREADS;
const int TM = THREADS;
const uint cRow = blockIdx.y;
const uint cCol = blockIdx.x;
const int threadCol = threadIdx.x % BN;
const int threadRow = threadIdx.x / BN;
// Allocate space for the current block tile in shared memory
__shared__ Torus As[BM * BK];
__shared__ Torus Bs[BK * BN];
// Initialize the pointers to the input blocks from A, B
// Tiles from these blocks are loaded to shared memory
A += cRow * BM * K;
B += cCol * BN;
// Each thread will handle multiple sub-blocks
const uint innerColA = threadIdx.x % BK;
const uint innerRowA = threadIdx.x / BK;
const uint innerColB = threadIdx.x % BN;
const uint innerRowB = threadIdx.x / BN;
// allocate thread-local cache for results in registerfile
Torus threadResults[TM] = {0};
auto row_A = cRow * BM + innerRowA;
auto col_B = cCol * BN + innerColB;
// For each thread, loop over block tiles
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
auto col_A = bkIdx + innerColA;
auto row_B = bkIdx + innerRowB;
if (row_A < M && col_A < K) {
As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA];
} else {
As[innerRowA * BK + innerColA] = 0;
}
if (col_B < N && row_B < K) {
Bs[innerRowB * BN + innerColB] = B[innerRowB * stride_B + innerColB];
} else {
Bs[innerRowB * BN + innerColB] = 0;
}
__syncthreads();
// Advance blocktile for the next iteration of this loop
A += BK;
B += BK * stride_B;
// calculate per-thread results
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
// we make the dotproduct loop the outside loop, which facilitates
// reuse of the Bs entry, which we can cache in a tmp var.
Torus tmp = Bs[dotIdx * BN + threadCol];
for (uint resIdx = 0; resIdx < TM; ++resIdx) {
threadResults[resIdx] +=
As[(threadRow * TM + resIdx) * BK + dotIdx] * tmp;
}
}
__syncthreads();
}
// write out the results
for (uint resIdx = 0; resIdx < TM; ++resIdx) {
int outRow = cRow * BM + threadRow * TM + resIdx;
int outCol = cCol * BN + threadCol;
if (outRow >= M)
continue;
if (outCol >= N)
continue;
C[C_indices[outRow] * stride_C + outCol] += threadResults[resIdx];
}
}
#endif // CUDA_MULT_H

View File

@@ -89,7 +89,8 @@ __host__ void host_wrapping_polynomial_mul_one_to_many(
PANIC("GEMM kernel error: shared memory required might be too large");
// Write the output with a stride of the GLWE total number of values
tgemm<Torus><<<grid_gemm, threads_gemm, sharedMemSize, stream>>>(
tgemm<Torus, BLOCK_SIZE_GEMM, THREADS_GEMM>
<<<grid_gemm, threads_gemm, sharedMemSize, stream>>>(
n_rhs, polynomial_size, polynomial_size, poly_rhs, (Torus *)circulant,
polynomial_size, result, (polynomial_size * (glwe_dimension + 1)));
check_cuda_error(cudaGetLastError());

View File

@@ -2,6 +2,7 @@
#define HELPER_CUH
#include <cstdint>
#include <sstream>
#include <stdio.h>
#include <type_traits>
@@ -64,4 +65,71 @@ void print_body(const char *name, T *src, int n, int lwe_dimension, T delta) {
printf("\n");
}
template <typename Torus>
void print_2d_csv_to_file(const std::vector<Torus> &v, int col_size,
const char *fname) {
FILE *fp = fopen(fname, "wt");
for (int i = 0; i < v.size() / col_size; ++i) {
for (int j = 0; j < col_size; ++j) {
fprintf(fp, "%lu%c", v[i * col_size + j],
(j == col_size - 1) ? '\n' : ',');
}
}
fclose(fp);
}
template <typename Torus>
__host__ void dump_2d_gpu_to_file(const Torus *ptr, int row_size, int col_size,
const char *fname_prefix, int rand_prefix,
cudaStream_t stream, uint32_t gpu_index) {
// #ifndef NDEBUG
std::vector<Torus> buf_cpu(row_size * col_size);
char fname[4096];
snprintf(fname, 4096, "%s_%d_%d_%d.csv", fname_prefix, row_size, col_size,
rand_prefix);
cuda_memcpy_async_to_cpu((void *)&buf_cpu[0], ptr,
buf_cpu.size() * sizeof(Torus), stream, gpu_index);
cuda_synchronize_device(gpu_index);
print_2d_csv_to_file(buf_cpu, col_size, fname);
// #endif
}
template <typename Torus>
__host__ void compare_2d_arrays(const Torus *ptr1, const Torus *ptr2,
int row_size, int col_size, cudaStream_t stream,
uint32_t gpu_index) {
// #ifndef NDEBUG
std::vector<Torus> buf_cpu1(row_size * col_size),
buf_cpu2(row_size * col_size);
cuda_memcpy_async_to_cpu((void *)&buf_cpu1[0], ptr1,
buf_cpu1.size() * sizeof(Torus), stream, gpu_index);
cuda_memcpy_async_to_cpu((void *)&buf_cpu2[0], ptr2,
buf_cpu2.size() * sizeof(Torus), stream, gpu_index);
cuda_synchronize_device(gpu_index);
std::vector<uint32_t> non_matching_indexes;
for (int i = 0; i < buf_cpu1.size(); ++i) {
if (buf_cpu1[i] != buf_cpu2[i]) {
non_matching_indexes.push_back(i);
}
}
if (!non_matching_indexes.empty()) {
std::stringstream ss;
for (int i = 0; i < std::min(non_matching_indexes.size(), (size_t)10);
++i) {
ss << " difference at " << non_matching_indexes[i] << ": "
<< buf_cpu1[non_matching_indexes[i]] << " vs "
<< buf_cpu2[non_matching_indexes[i]] << " at index "
<< non_matching_indexes[i] << "\n";
}
GPU_ASSERT(non_matching_indexes.empty(),
"Correctness error for matrices %d x %d: \n%s", row_size,
col_size, ss.str().c_str());
}
}
#endif

View File

@@ -79,7 +79,8 @@ __host__ void host_expand_without_verification(
streams.get_ith(0), ksed_small_to_big_expanded_lwes,
lwe_trivial_indexes_vec[0], expanded_lwes, lwe_trivial_indexes_vec[0],
casting_keys, casting_input_dimension, casting_output_dimension,
casting_ks_base_log, casting_ks_level, num_lwes);
casting_ks_base_log, casting_ks_level, num_lwes,
lut->using_trivial_lwe_indexes, lut->ks_tmp_buf_vec);
// In this case, the next keyswitch will use the compute ksk
ksks = compute_ksks;

View File

@@ -50,8 +50,8 @@ void keyswitch_setup(cudaStream_t stream, uint32_t gpu_index, Seed *seed,
uint64_t **d_lwe_ct_in_array,
uint64_t **d_lwe_input_indexes,
uint64_t **d_lwe_ct_out_array,
uint64_t **d_lwe_output_indexes, int input_lwe_dimension,
int output_lwe_dimension,
uint64_t **d_lwe_output_indexes, void **ks_tmp_buffer,
int input_lwe_dimension, int output_lwe_dimension,
DynamicDistribution lwe_noise_distribution,
int ksk_base_log, int ksk_level, int message_modulus,
int carry_modulus, int *payload_modulus, uint64_t *delta,
@@ -62,7 +62,7 @@ void keyswitch_teardown(cudaStream_t stream, uint32_t gpu_index,
uint64_t *d_lwe_ct_in_array,
uint64_t *lwe_input_indexes,
uint64_t *d_lwe_ct_out_array,
uint64_t *lwe_output_indexes);
uint64_t *lwe_output_indexes, void **ks_tmp_buffer);
void fft_setup(cudaStream_t stream, uint32_t gpu_index, double **poly1,
double **poly2, double2 **h_cpoly1, double2 **h_cpoly2,

View File

@@ -260,7 +260,7 @@ void keyswitch_setup(
cudaStream_t stream, uint32_t gpu_index, Seed *seed, uint64_t **lwe_sk_in_array,
uint64_t **lwe_sk_out_array, uint64_t **d_ksk_array, uint64_t **plaintexts,
uint64_t **d_lwe_ct_in_array, uint64_t **d_lwe_input_indexes,
uint64_t **d_lwe_ct_out_array, uint64_t **d_lwe_output_indexes,
uint64_t **d_lwe_ct_out_array, uint64_t **d_lwe_output_indexes, void** ks_tmp_buffer,
int input_lwe_dimension, int output_lwe_dimension,
DynamicDistribution lwe_noise_distribution, int ksk_base_log, int ksk_level,
int message_modulus, int carry_modulus, int *payload_modulus,
@@ -295,6 +295,11 @@ void keyswitch_setup(
uint64_t *lwe_ct_in_array =
(uint64_t *)malloc((input_lwe_dimension + 1) * number_of_inputs *
repetitions * samples * sizeof(uint64_t));
scratch_cuda_keyswitch_gemm_64(
(void*)stream, gpu_index, ks_tmp_buffer,
input_lwe_dimension, output_lwe_dimension, number_of_inputs * repetitions * samples, true);
// Create the input/output ciphertexts
for (int r = 0; r < repetitions; r++) {
uint64_t *lwe_sk_in =
@@ -343,7 +348,7 @@ void keyswitch_teardown(cudaStream_t stream, uint32_t gpu_index, uint64_t *lwe_s
uint64_t *plaintexts, uint64_t *d_lwe_ct_in_array,
uint64_t *d_lwe_input_indexes,
uint64_t *d_lwe_ct_out_array,
uint64_t *d_lwe_output_indexes) {
uint64_t *d_lwe_output_indexes, void** ks_tmp_buffer) {
cuda_synchronize_stream(stream, gpu_index);
free(lwe_sk_in_array);
@@ -355,8 +360,14 @@ void keyswitch_teardown(cudaStream_t stream, uint32_t gpu_index, uint64_t *lwe_s
cuda_drop_async(d_lwe_ct_out_array, stream, gpu_index);
cuda_drop_async(d_lwe_input_indexes, stream, gpu_index);
cuda_drop_async(d_lwe_output_indexes, stream, gpu_index);
cleanup_cuda_keyswitch_gemm_64((void*)stream, gpu_index,
ks_tmp_buffer,
true);
cuda_synchronize_stream(stream, gpu_index);
cuda_destroy_stream(stream, gpu_index);
}
void fft_setup(cudaStream_t stream, uint32_t gpu_index, double **_poly1, double **_poly2,

View File

@@ -45,6 +45,7 @@ protected:
uint64_t *lwe_out_ct;
uint64_t *lwe_input_indexes;
uint64_t *lwe_output_indexes;
void *ks_tmp_buffer;
// Data stays at gpu 0
uint32_t gpu_index = 0;
@@ -83,7 +84,7 @@ public:
keyswitch_setup(streams[0], gpu_index, &seed, &lwe_sk_in_array,
&lwe_sk_out_array, &d_ksk_array, &plaintexts,
&d_lwe_ct_in_array, &lwe_input_indexes, &d_lwe_ct_out_array,
&lwe_output_indexes, input_lwe_dimension,
&lwe_output_indexes, &ks_tmp_buffer, input_lwe_dimension,
output_lwe_dimension, noise_distribution, ksk_base_log,
ksk_level, message_modulus, carry_modulus, &payload_modulus,
&delta, number_of_inputs, REPETITIONS, SAMPLES);
@@ -94,7 +95,7 @@ public:
keyswitch_teardown(streams[0], gpu_index, lwe_sk_in_array, lwe_sk_out_array,
d_ksk_array, plaintexts, d_lwe_ct_in_array,
lwe_input_indexes, d_lwe_ct_out_array,
lwe_output_indexes);
lwe_output_indexes, &ks_tmp_buffer);
if (active_gpu_count > 1) {
for (uint gpu_i = 1; gpu_i < active_gpu_count; gpu_i++) {
cuda_destroy_stream(streams[gpu_i], gpu_i);
@@ -136,10 +137,11 @@ TEST_P(KeyswitchMultiGPUTestPrimitives_u64, keyswitch) {
d_lwe_ct_out_array + (ptrdiff_t)(output_lwe_start_index);
// Execute keyswitch
cuda_keyswitch_lwe_ciphertext_vector_64(
cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
streams[gpu_i], gpu_i, d_lwe_ct_out, lwe_output_indexes,
d_lwe_ct_in_slice, lwe_input_indexes, d_ksk, input_lwe_dimension,
output_lwe_dimension, ksk_base_log, ksk_level, num_inputs);
output_lwe_dimension, ksk_base_log, ksk_level, num_inputs,
ks_tmp_buffer, false);
}
for (uint gpu_i = 0; gpu_i < active_gpu_count; gpu_i++) {
cuda_synchronize_stream(streams[gpu_i], gpu_i);
@@ -195,6 +197,7 @@ protected:
uint64_t *lwe_out_ct;
uint64_t *lwe_input_indexes;
uint64_t *lwe_output_indexes;
void *ks_tmp_buffer;
public:
// Test arithmetic functions
@@ -217,7 +220,7 @@ public:
keyswitch_setup(stream, gpu_index, &seed, &lwe_sk_in_array,
&lwe_sk_out_array, &d_ksk_array, &plaintexts,
&d_lwe_ct_in_array, &lwe_input_indexes, &d_lwe_ct_out_array,
&lwe_output_indexes, input_lwe_dimension,
&lwe_output_indexes, &ks_tmp_buffer, input_lwe_dimension,
output_lwe_dimension, noise_distribution, ksk_base_log,
ksk_level, message_modulus, carry_modulus, &payload_modulus,
&delta, number_of_inputs, REPETITIONS, SAMPLES);
@@ -227,7 +230,7 @@ public:
keyswitch_teardown(stream, gpu_index, lwe_sk_in_array, lwe_sk_out_array,
d_ksk_array, plaintexts, d_lwe_ct_in_array,
lwe_input_indexes, d_lwe_ct_out_array,
lwe_output_indexes);
lwe_output_indexes, &ks_tmp_buffer);
}
};
@@ -245,11 +248,12 @@ TEST_P(KeyswitchTestPrimitives_u64, keyswitch) {
(ptrdiff_t)((r * SAMPLES * number_of_inputs + s * number_of_inputs) *
(input_lwe_dimension + 1));
// Execute keyswitch
cuda_keyswitch_lwe_ciphertext_vector_64(
cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
stream, gpu_index, (void *)d_lwe_ct_out_array,
(void *)lwe_output_indexes, (void *)d_lwe_ct_in,
(void *)lwe_input_indexes, (void *)d_ksk, input_lwe_dimension,
output_lwe_dimension, ksk_base_log, ksk_level, number_of_inputs);
output_lwe_dimension, ksk_base_log, ksk_level, number_of_inputs,
ks_tmp_buffer, false);
// Copy result back
cuda_memcpy_async_to_cpu(lwe_out_ct, d_lwe_ct_out_array,

View File

@@ -2519,6 +2519,24 @@ unsafe extern "C" {
num_samples: u32,
);
}
unsafe extern "C" {
pub fn cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
stream: *mut ffi::c_void,
gpu_index: u32,
lwe_array_out: *mut ffi::c_void,
lwe_output_indexes: *const ffi::c_void,
lwe_array_in: *const ffi::c_void,
lwe_input_indexes: *const ffi::c_void,
ksk: *const ffi::c_void,
lwe_dimension_in: u32,
lwe_dimension_out: u32,
base_log: u32,
level_count: u32,
num_samples: u32,
ks_tmp_buffer: *const ffi::c_void,
uses_trivial_indexes: bool,
);
}
unsafe extern "C" {
pub fn cuda_keyswitch_lwe_ciphertext_vector_64(
stream: *mut ffi::c_void,
@@ -2547,6 +2565,25 @@ unsafe extern "C" {
allocate_gpu_memory: bool,
) -> u64;
}
unsafe extern "C" {
pub fn scratch_cuda_keyswitch_gemm_64(
stream: *mut ffi::c_void,
gpu_index: u32,
ks_tmp_memory: *mut *mut ffi::c_void,
lwe_dimension_in: u32,
lwe_dimension_out: u32,
num_lwes: u32,
allocate_gpu_memory: bool,
) -> u64;
}
unsafe extern "C" {
pub fn cleanup_cuda_keyswitch_gemm_64(
stream: *mut ffi::c_void,
gpu_index: u32,
ks_tmp_memory: *mut *mut ffi::c_void,
allocate_gpu_memory: bool,
);
}
unsafe extern "C" {
pub fn cuda_packing_keyswitch_lwe_list_to_glwe_64(
stream: *mut ffi::c_void,

View File

@@ -386,7 +386,7 @@ mod cuda {
.keyswitch_key(ksk_big_to_small)
.build();
let bench_id;
let mut bench_id;
match get_bench_type() {
BenchmarkType::Latency => {
@@ -423,22 +423,47 @@ mod cuda {
&mut output_ct_gpu,
&cuda_indexes.d_input,
&cuda_indexes.d_output,
true,
&streams,
false,
);
black_box(&mut ct_gpu);
})
});
}
let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
write_to_json(
&bench_id,
*params,
name,
"ks",
&OperatorType::Atomic,
bit_size,
vec![bit_size],
);
}
BenchmarkType::Throughput => {
let gpu_keys_vec = cuda_local_keys_core(&cpu_keys, None);
let gpu_count = get_number_of_gpus() as usize;
bench_id = format!("{bench_name}::throughput::{name}");
let blocks: usize = 1;
let elements = throughput_num_threads(blocks, 1);
let elements_per_stream = elements as usize / gpu_count;
bench_group.throughput(Throughput::Elements(elements));
for uses_gemm_ks in [false, true] {
for uses_simple_indices in [false, true] {
let indices_str = if uses_simple_indices {
"simple"
} else {
"complex"
};
let gemm_str = if uses_gemm_ks { "gemm" } else { "classical" };
bench_id = format!(
"{bench_name}::throughput::{gemm_str}::{indices_str}_indices::{name}",
);
let blocks: usize = 256;
let elements = gpu_count * blocks;
let elements_per_stream = elements / gpu_count;
bench_group.throughput(Throughput::Elements(elements as u64));
bench_group.sample_size(50);
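// Sweep both keyswitch implementations (classical vs GEMM) and both index
// patterns, so the GEMM path is measured with and without its trivial-index
// fast path.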
bench_group.bench_function(&bench_id, |b| {
let setup_encrypted_values = || {
@@ -491,9 +516,16 @@ mod cuda {
})
.collect::<Vec<_>>();
let h_indexes = (0..(elements / gpu_count as u64))
.map(CastFrom::cast_from)
let indexes_range: Vec<u64> = if uses_simple_indices {
(0..(elements / gpu_count) as u64).collect()
} else {
(0..(elements / gpu_count) as u64).rev().collect()
};
let h_indexes = indexes_range
.iter()
.map(|v| CastFrom::cast_from(*v))
.collect::<Vec<_>>();
let cuda_indexes_vec = (0..gpu_count)
.map(|i| CudaIndexes::new(&h_indexes, &local_streams[i], 0))
.collect::<Vec<_>>();
@@ -504,28 +536,35 @@ mod cuda {
b.iter_batched(
setup_encrypted_values,
|(input_cts, mut output_cts, cuda_indexes_vec, local_streams)| {
|(
input_cts,
mut output_cts,
cuda_indexes_vec,
local_streams,
)| {
(0..gpu_count)
.into_par_iter()
.zip(input_cts.par_iter())
.zip(output_cts.par_iter_mut())
.zip(local_streams.par_iter())
.for_each(|(((i, input_ct), output_ct), local_stream)| {
.for_each(
|(((i, input_ct), output_ct), local_stream)| {
cuda_keyswitch_lwe_ciphertext(
gpu_keys_vec[i].ksk.as_ref().unwrap(),
input_ct,
output_ct,
&cuda_indexes_vec[i].d_input,
&cuda_indexes_vec[i].d_output,
uses_simple_indices,
local_stream,
uses_gemm_ks,
);
})
},
)
},
criterion::BatchSize::SmallInput,
)
});
}
};
let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
write_to_json(
@@ -539,6 +578,10 @@ mod cuda {
);
}
}
}
};
}
}
fn cuda_packing_keyswitch<
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<u64> + Serialize,

View File

@@ -630,7 +630,9 @@ mod cuda {
&mut output_ks_ct_gpu,
&cuda_indexes.d_input,
&cuda_indexes.d_output,
true,
&streams,
false,
);
cuda_programmable_bootstrap_lwe_ciphertext(
&output_ks_ct_gpu,
@@ -782,7 +784,9 @@ mod cuda {
output_ks_ct,
&cuda_indexes_vec[i].d_input,
&cuda_indexes_vec[i].d_output,
true,
local_stream,
false,
);
cuda_programmable_bootstrap_lwe_ciphertext(
output_ks_ct,
@@ -937,7 +941,9 @@ mod cuda {
&mut output_ks_ct_gpu,
&cuda_indexes.d_input,
&cuda_indexes.d_output,
true,
&streams,
false,
);
cuda_multi_bit_programmable_bootstrap_lwe_ciphertext(
&output_ks_ct_gpu,
@@ -1088,7 +1094,9 @@ mod cuda {
output_ks_ct,
&cuda_indexes_vec[i].d_input,
&cuda_indexes_vec[i].d_output,
true,
local_stream,
false,
);
cuda_multi_bit_programmable_bootstrap_lwe_ciphertext(
output_ks_ct,

View File

@@ -1,20 +1,28 @@
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
use crate::core_crypto::gpu::vec::CudaVec;
use crate::core_crypto::gpu::{keyswitch_async, CudaStreams};
use crate::core_crypto::gpu::{
keyswitch_async, keyswitch_async_gemm, scratch_cuda_keyswitch_gemm_64, CudaStreams,
};
use crate::core_crypto::prelude::UnsignedInteger;
use std::cmp::min;
use tfhe_cuda_backend::bindings::cleanup_cuda_keyswitch_gemm_64;
use tfhe_cuda_backend::ffi;
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
/// be dropped until stream is synchronized
#[allow(clippy::too_many_arguments)]
pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
lwe_keyswitch_key: &CudaLweKeyswitchKey<Scalar>,
input_lwe_ciphertext: &CudaLweCiphertextList<Scalar>,
output_lwe_ciphertext: &mut CudaLweCiphertextList<Scalar>,
input_indexes: &CudaVec<Scalar>,
output_indexes: &CudaVec<Scalar>,
uses_trivial_indices: bool,
streams: &CudaStreams,
use_gemm_ks: bool,
) where
Scalar: UnsignedInteger,
{
@@ -70,6 +78,50 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
output_indexes.gpu_index(0).get(),
);
let mut ks_tmp_buffer: *mut ffi::c_void = std::ptr::null_mut();
let num_lwes_to_ks = min(
input_indexes.len,
input_lwe_ciphertext.lwe_ciphertext_count().0,
);
assert_eq!(
input_indexes.len, output_indexes.len,
"The number of input and output indexes must be the same for LWE keyswitch"
);
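// When the GEMM path is selected, a scratch buffer is allocated here and
// released right after the keyswitch; higher-level callers (e.g. the integer
// LUT structures) keep a persistent per-GPU buffer in ks_tmp_buf_vec instead.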
if use_gemm_ks {
cuda_scratch_keyswitch_lwe_ciphertext_async::<Scalar>(
streams,
std::ptr::addr_of_mut!(ks_tmp_buffer),
lwe_keyswitch_key.input_key_lwe_size().to_lwe_dimension().0 as u32,
lwe_keyswitch_key.output_key_lwe_size().to_lwe_dimension().0 as u32,
num_lwes_to_ks as u32,
true,
);
keyswitch_async_gemm(
streams,
&mut output_lwe_ciphertext.0.d_vec,
output_indexes,
&input_lwe_ciphertext.0.d_vec,
input_indexes,
lwe_keyswitch_key.input_key_lwe_size().to_lwe_dimension(),
lwe_keyswitch_key.output_key_lwe_size().to_lwe_dimension(),
&lwe_keyswitch_key.d_vec,
lwe_keyswitch_key.decomposition_base_log(),
lwe_keyswitch_key.decomposition_level_count(),
num_lwes_to_ks as u32,
ks_tmp_buffer,
uses_trivial_indices,
);
cleanup_cuda_keyswitch_async::<Scalar>(
streams,
std::ptr::addr_of_mut!(ks_tmp_buffer),
true,
);
} else {
keyswitch_async(
streams,
&mut output_lwe_ciphertext.0.d_vec,
@@ -81,17 +133,21 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
&lwe_keyswitch_key.d_vec,
lwe_keyswitch_key.decomposition_base_log(),
lwe_keyswitch_key.decomposition_level_count(),
input_lwe_ciphertext.lwe_ciphertext_count().0 as u32,
num_lwes_to_ks as u32,
);
}
}
#[allow(clippy::too_many_arguments)]
pub fn cuda_keyswitch_lwe_ciphertext<Scalar>(
lwe_keyswitch_key: &CudaLweKeyswitchKey<Scalar>,
input_lwe_ciphertext: &CudaLweCiphertextList<Scalar>,
output_lwe_ciphertext: &mut CudaLweCiphertextList<Scalar>,
input_indexes: &CudaVec<Scalar>,
output_indexes: &CudaVec<Scalar>,
uses_trivial_indices: bool,
streams: &CudaStreams,
use_gemm_ks: bool,
) where
Scalar: UnsignedInteger,
{
@@ -102,8 +158,52 @@ pub fn cuda_keyswitch_lwe_ciphertext<Scalar>(
output_lwe_ciphertext,
input_indexes,
output_indexes,
uses_trivial_indices,
streams,
use_gemm_ks,
);
}
streams.synchronize();
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
/// be dropped until stream is synchronized
pub unsafe fn cuda_scratch_keyswitch_lwe_ciphertext_async<Scalar>(
streams: &CudaStreams,
ks_tmp_buffer: *mut *mut ffi::c_void,
lwe_dimension_in: u32,
lwe_dimension_out: u32,
num_lwes: u32,
allocate_gpu_memory: bool,
) where
Scalar: UnsignedInteger,
{
scratch_cuda_keyswitch_gemm_64(
streams.ptr[0],
streams.gpu_indexes[0].get(),
ks_tmp_buffer,
lwe_dimension_in,
lwe_dimension_out,
num_lwes,
allocate_gpu_memory,
);
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
/// be dropped until stream is synchronized
pub unsafe fn cleanup_cuda_keyswitch_async<Scalar>(
streams: &CudaStreams,
ks_tmp_buffer: *mut *mut ffi::c_void,
allocate_gpu_memory: bool,
) {
cleanup_cuda_keyswitch_gemm_64(
streams.ptr[0],
streams.gpu_indexes[0].get(),
ks_tmp_buffer,
allocate_gpu_memory,
);
}
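The updated wrapper keeps its original entry point but now takes two extra flags: `uses_trivial_indices`, which tells the backend the index vectors are the identity mapping, and `use_gemm_ks`, which routes the batch through the GEMM keyswitch (the temporary buffer is allocated and freed internally). A minimal caller-side sketch, assuming the keyswitch key, ciphertext lists and index vectors are already on the GPU; the helper name is made up for illustration:

use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
use crate::core_crypto::gpu::vec::CudaVec;
use crate::core_crypto::gpu::{cuda_keyswitch_lwe_ciphertext, CudaStreams};
use crate::core_crypto::prelude::UnsignedInteger;

/// Hypothetical helper: run the same batch through both code paths so the
/// two outputs can be compared, mirroring the GPU test further below.
#[allow(clippy::too_many_arguments)]
fn keyswitch_both_paths<Scalar: UnsignedInteger>(
    ksk: &CudaLweKeyswitchKey<Scalar>,
    input: &CudaLweCiphertextList<Scalar>,
    out_classic: &mut CudaLweCiphertextList<Scalar>,
    out_gemm: &mut CudaLweCiphertextList<Scalar>,
    input_indexes: &CudaVec<Scalar>,
    output_indexes: &CudaVec<Scalar>,
    uses_trivial_indices: bool,
    streams: &CudaStreams,
) {
    // Classic per-sample keyswitch kernel (use_gemm_ks = false)
    cuda_keyswitch_lwe_ciphertext(
        ksk, input, out_classic, input_indexes, output_indexes,
        uses_trivial_indices, streams, false,
    );
    // GEMM batch keyswitch: scratch allocation and cleanup happen inside the wrapper
    cuda_keyswitch_lwe_ciphertext(
        ksk, input, out_gemm, input_indexes, output_indexes,
        uses_trivial_indices, streams, true,
    );
}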

View File

@@ -5,6 +5,8 @@ use crate::core_crypto::gpu::vec::{CudaVec, GpuIndex};
use crate::core_crypto::gpu::{cuda_keyswitch_lwe_ciphertext, CudaStreams};
use crate::core_crypto::prelude::misc::check_encrypted_content_respects_mod;
use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::thread_rng;
fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(
params: ClassicTestParams<Scalar>,
@@ -19,6 +21,57 @@ fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(
let ks_decomp_base_log = params.ks_base_log;
let ks_decomp_level_count = params.ks_level;
base_lwe_encrypt_ks_decrypt_custom_mod(
lwe_dimension,
lwe_noise_distribution,
ciphertext_modulus,
message_modulus_log,
encoding_with_padding,
glwe_dimension,
polynomial_size,
ks_decomp_base_log,
ks_decomp_level_count,
);
}
fn lwe_encrypt_ks_decrypt_custom_mod_mb<Scalar: UnsignedTorus + CastFrom<usize>>(
params: &MultiBitTestParams<Scalar>,
) {
let lwe_dimension = params.input_lwe_dimension;
let lwe_noise_distribution = DynamicDistribution::new_gaussian_from_std_dev(StandardDev(0f64));
let ciphertext_modulus = params.ciphertext_modulus;
let message_modulus_log = params.message_modulus_log;
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
let glwe_dimension = params.glwe_dimension;
let polynomial_size = params.polynomial_size;
let ks_decomp_base_log = params.decomp_base_log;
let ks_decomp_level_count = params.decomp_level_count;
base_lwe_encrypt_ks_decrypt_custom_mod(
lwe_dimension,
lwe_noise_distribution,
ciphertext_modulus,
message_modulus_log,
encoding_with_padding,
glwe_dimension,
polynomial_size,
ks_decomp_base_log,
ks_decomp_level_count,
);
}
#[allow(clippy::too_many_arguments)]
fn base_lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(
lwe_dimension: LweDimension,
lwe_noise_distribution: DynamicDistribution<Scalar>,
ciphertext_modulus: CiphertextModulus<Scalar>,
message_modulus_log: MessageModulusLog,
encoding_with_padding: Scalar,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
ks_decomp_base_log: DecompositionBaseLog,
ks_decomp_level_count: DecompositionLevelCount,
) {
let stream = CudaStreams::new_single_gpu(GpuIndex::new(0));
let mut rsc = TestResources::new();
@@ -61,65 +114,144 @@ fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(
while msg != Scalar::ZERO {
msg = msg.wrapping_sub(Scalar::ONE);
for _ in 0..NB_TESTS {
let plaintext = Plaintext(msg * delta);
for test_idx in 0..NB_TESTS {
let num_blocks = test_idx * test_idx * 3 + 1;
let ct = allocate_and_encrypt_new_lwe_ciphertext(
&big_lwe_sk,
plaintext,
lwe_noise_distribution,
let plaintext_list = (0..num_blocks)
.map(|i| (Scalar::cast_from(i) % msg_modulus) * delta)
.collect_vec();
let plaintext_list = PlaintextList::from_container(plaintext_list);
let mut input_ct_list = LweCiphertextList::new(
Scalar::ZERO,
big_lwe_sk.lwe_dimension().to_lwe_size(),
LweCiphertextCount(num_blocks),
ciphertext_modulus,
);
encrypt_lwe_ciphertext_list(
&big_lwe_sk,
&mut input_ct_list,
&plaintext_list,
lwe_noise_distribution,
&mut rsc.encryption_random_generator,
);
let input_ct_list_gpu =
CudaLweCiphertextList::from_lwe_ciphertext_list(&input_ct_list, &stream);
let output_ct_list = LweCiphertextList::new(
Scalar::ZERO,
ksk_big_to_small.output_key_lwe_dimension().to_lwe_size(),
input_ct_list.lwe_ciphertext_count(),
ksk_big_to_small.ciphertext_modulus(),
);
let mut output_ct_list_gpu =
CudaLweCiphertextList::from_lwe_ciphertext_list(&output_ct_list, &stream);
let mut output_ct_list_gpu_gemm =
CudaLweCiphertextList::from_lwe_ciphertext_list(&output_ct_list, &stream);
assert!(check_encrypted_content_respects_mod(
&ct,
&input_ct_list,
ciphertext_modulus
));
let d_ct = CudaLweCiphertextList::from_lwe_ciphertext(&ct, &stream);
let mut d_output_ct = CudaLweCiphertextList::new(
ksk_big_to_small.output_key_lwe_dimension(),
LweCiphertextCount(1),
ciphertext_modulus,
&stream,
);
let num_blocks = d_ct.0.lwe_ciphertext_count.0;
let use_trivial_indexes = test_idx % 2 == 0;
let num_blocks_to_ks = if use_trivial_indexes {
if test_idx % 4 == 0 {
num_blocks
} else {
num_blocks / 2
}
} else {
num_blocks
};
let lwe_indexes_usize = (0..num_blocks).collect_vec();
let lwe_indexes = lwe_indexes_usize
let mut lwe_indexes = lwe_indexes_usize.clone();
let mut lwe_indexes_out = lwe_indexes.clone();
if !use_trivial_indexes {
lwe_indexes.shuffle(&mut thread_rng());
lwe_indexes_out.shuffle(&mut thread_rng());
}
let h_lwe_indexes: Vec<Scalar> = lwe_indexes
.iter()
.take(num_blocks_to_ks)
.map(|&x| <usize as CastInto<Scalar>>::cast_into(x))
.collect_vec();
let h_lwe_indexes_out: Vec<Scalar> = lwe_indexes_out
.iter()
.take(num_blocks_to_ks)
.map(|&x| <usize as CastInto<Scalar>>::cast_into(x))
.collect_vec();
let mut d_input_indexes =
unsafe { CudaVec::<Scalar>::new_async(num_blocks, &stream, 0) };
unsafe { CudaVec::<Scalar>::new_async(num_blocks_to_ks, &stream, 0) };
let mut d_output_indexes =
unsafe { CudaVec::<Scalar>::new_async(num_blocks, &stream, 0) };
unsafe { d_input_indexes.copy_from_cpu_async(&lwe_indexes, &stream, 0) };
unsafe { d_output_indexes.copy_from_cpu_async(&lwe_indexes, &stream, 0) };
unsafe { CudaVec::<Scalar>::new_async(num_blocks_to_ks, &stream, 0) };
unsafe { d_input_indexes.copy_from_cpu_async(&h_lwe_indexes, &stream, 0) };
unsafe { d_output_indexes.copy_from_cpu_async(&h_lwe_indexes_out, &stream, 0) };
cuda_keyswitch_lwe_ciphertext(
&d_ksk_big_to_small,
&d_ct,
&mut d_output_ct,
&input_ct_list_gpu,
&mut output_ct_list_gpu,
&d_input_indexes,
&d_output_indexes,
use_trivial_indexes,
&stream,
false,
);
let output_ct = d_output_ct.into_lwe_ciphertext(&stream);
cuda_keyswitch_lwe_ciphertext(
&d_ksk_big_to_small,
&input_ct_list_gpu,
&mut output_ct_list_gpu_gemm,
&d_input_indexes,
&d_output_indexes,
use_trivial_indexes,
&stream,
true,
);
// Fill in the expected output: only the LWEs at the output indices will be
// non-zero. The test checks that all other slots remain zero.
let mut ref_vec = vec![Scalar::ZERO; num_blocks];
for i in 0..num_blocks_to_ks {
ref_vec[lwe_indexes_out[i]] =
round_decode(*plaintext_list.get(lwe_indexes[i]).0, delta);
}
assert_eq!(output_ct_list_gpu.lwe_ciphertext_count().0, num_blocks);
// The output has `num_blocks` LWEs, but only those at the output indices are
// actually set. We loop over all LWEs in the output buffer.
let output_ct_list_cpu = output_ct_list_gpu_gemm.to_lwe_ciphertext_list(&stream);
output_ct_list_gpu
.to_lwe_ciphertext_list(&stream)
.iter()
.zip(0..num_blocks)
.for_each(|(lwe_ct_out, i)| {
assert!(check_encrypted_content_respects_mod(
&output_ct,
&lwe_ct_out,
ciphertext_modulus
));
let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &output_ct);
let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &lwe_ct_out);
let lwe_ct_out_gemm = output_ct_list_cpu.get(i);
let decrypted_gemm = decrypt_lwe_ciphertext(&lwe_sk, &lwe_ct_out_gemm);
let decoded = round_decode(decrypted.0, delta) % msg_modulus;
let decoded_gemm = round_decode(decrypted_gemm.0, delta) % msg_modulus;
assert_eq!(msg, decoded);
assert_eq!(ref_vec[i], decoded);
assert_eq!(ref_vec[i], decoded_gemm);
});
}
}
}
create_gpu_parameterized_test!(lwe_encrypt_ks_decrypt_custom_mod);
create_gpu_multi_bit_parameterized_test!(lwe_encrypt_ks_decrypt_custom_mod_mb);
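The subtle part of the updated test is the index semantics: the LWE at input slot `lwe_indexes[i]` is keyswitched into output slot `lwe_indexes_out[i]`, and output slots that are never targeted must stay at zero. A plain-Rust restatement of the reference construction used above, as a hypothetical standalone helper operating on already-decoded values rather than ciphertexts:

/// Hypothetical helper mirroring the `ref_vec` construction in the test:
/// the value at input slot `in_idx[i]` ends up in output slot `out_idx[i]`;
/// untouched output slots keep their initial zero.
fn expected_decoded_outputs(
    decoded_inputs: &[u64],
    in_idx: &[usize],
    out_idx: &[usize],
    num_blocks: usize,
) -> Vec<u64> {
    let mut reference = vec![0u64; num_blocks];
    for (&src, &dst) in in_idx.iter().zip(out_idx.iter()) {
        reference[dst] = decoded_inputs[src];
    }
    reference
}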

View File

@@ -16,6 +16,7 @@ use std::any::{Any, TypeId};
use std::ffi::c_void;
use tfhe_cuda_backend::bindings::*;
use tfhe_cuda_backend::cuda_bind::*;
use tfhe_cuda_backend::ffi;
pub struct CudaStreams {
pub ptr: Vec<*mut c_void>,
@@ -477,13 +478,53 @@ pub fn get_programmable_bootstrap_multi_bit_size_on_gpu(
size_tracker
}
/// Keyswitch on a vector of LWE ciphertexts
/// Keyswitch on a vector of LWE ciphertexts using the GEMM batch KS approach
///
/// # Safety
///
/// [CudaStreams::synchronize] __must__ be called as soon as synchronization is
/// required
#[allow(clippy::too_many_arguments)]
pub unsafe fn keyswitch_async_gemm<T: UnsignedInteger>(
streams: &CudaStreams,
lwe_array_out: &mut CudaVec<T>,
lwe_out_indexes: &CudaVec<T>,
lwe_array_in: &CudaVec<T>,
lwe_in_indexes: &CudaVec<T>,
input_lwe_dimension: LweDimension,
output_lwe_dimension: LweDimension,
keyswitch_key: &CudaVec<T>,
base_log: DecompositionBaseLog,
l_gadget: DecompositionLevelCount,
num_samples: u32,
ks_tmp_buffer: *const ffi::c_void,
uses_trivial_indices: bool,
) {
cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
streams.ptr[0],
streams.gpu_indexes[0].get(),
lwe_array_out.as_mut_c_ptr(0),
lwe_out_indexes.as_c_ptr(0),
lwe_array_in.as_c_ptr(0),
lwe_in_indexes.as_c_ptr(0),
keyswitch_key.as_c_ptr(0),
input_lwe_dimension.0 as u32,
output_lwe_dimension.0 as u32,
base_log.0 as u32,
l_gadget.0 as u32,
num_samples,
ks_tmp_buffer,
uses_trivial_indices,
);
}
/// Keyswitch on a vector of LWE ciphertexts. This path performs better for small
/// batches of LWEs (up to 128 LWEs on H100, up to 64 on L40, up to 16 on 4090).
/// # Safety
///
/// [CudaStreams::synchronize] __must__ be called as soon as synchronization is
/// required
#[allow(clippy::too_many_arguments)]
pub unsafe fn keyswitch_async<T: UnsignedInteger>(
streams: &CudaStreams,
lwe_array_out: &mut CudaVec<T>,
@@ -512,7 +553,6 @@ pub unsafe fn keyswitch_async<T: UnsignedInteger>(
num_samples,
);
}
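Since the crossover point between the two kernels is GPU dependent (the doc comment above quotes up to 128 LWEs on H100, 64 on L40 and 16 on 4090), a caller that wants to pick the `use_gemm_ks` flag itself could gate it on the batch size. A sketch with a placeholder threshold, not a value exported by the backend:

// Hypothetical heuristic: prefer the GEMM batch keyswitch above some
// GPU-dependent batch size. The constant below is a placeholder, not an
// API of the backend.
const GEMM_KS_BATCH_THRESHOLD_PLACEHOLDER: usize = 128;

fn should_use_gemm_ks(num_lwes: usize) -> bool {
    num_lwes > GEMM_KS_BATCH_THRESHOLD_PLACEHOLDER
}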
/// Convert keyswitch key
///
/// # Safety