Mirror of https://github.com/zama-ai/tfhe-rs.git, synced 2026-01-08 22:28:01 -05:00

chore(gpu): bench KS latency batches

Committed by: Andrei Stoian
Parent: d6a0a366b9
Commit: e2063c8ef4

Changed files include: Makefile (2)
@@ -776,7 +776,7 @@ build_debug_integer_short_run_gpu: install_cargo_nextest
	RUSTFLAGS="$(RUSTFLAGS)" cargo test --profile debug_lto_off \
		--features=integer,gpu-debug-fake-multi-gpu -p tfhe -- integer::gpu::server_key::radix::tests_long_run::test_random_op_sequence::test_gpu_short_random --list
	@echo "To debug fake-multi-gpu short run tests run:"
	@echo "TFHE_RS_TEST_LONG_TESTS_MINIMAL=TRUE <executable> integer::gpu::server_key::radix::tests_long_run::test_random_op_sequence::test_gpu_short_random_op_sequence_param_gpu_multi_bit_group_4_message_2_carry_2_ks_pbs_tuniform_2m128 --nocapture"
	@echo "TFHE_RS_LONGRUN_TESTS_SEED=<SEED_FROM_CI> TFHE_RS_TEST_LONG_TESTS_MINIMAL=TRUE <executable> integer::gpu::server_key::radix::tests_long_run::test_random_op_sequence::test_gpu_short_random_op_sequence_param_gpu_multi_bit_group_4_message_2_carry_2_ks_pbs_tuniform_2m128 --nocapture"
	@echo "Where <executable> = the one printed in the () in the 'Running unittests src/lib.rs ()' line above"

.PHONY: test_integer_compression
@@ -37,6 +37,7 @@ template <typename Torus> struct int_aes_lut_buffers {
    auto active_streams_and_lut = streams.active_gpu_subset(
        SBOX_MAX_AND_GATES * num_aes_inputs * sbox_parallelism);
    this->and_lut->broadcast_lut(active_streams_and_lut);
    this->and_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);

    this->flush_lut = new int_radix_lut<Torus>(
        streams, params, 1, AES_STATE_BITS * num_aes_inputs,
@@ -52,6 +53,7 @@ template <typename Torus> struct int_aes_lut_buffers {
    auto active_streams_flush_lut =
        streams.active_gpu_subset(AES_STATE_BITS * num_aes_inputs);
    this->flush_lut->broadcast_lut(active_streams_flush_lut);
    this->flush_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);

    this->carry_lut = new int_radix_lut<Torus>(
        streams, params, 1, num_aes_inputs, allocate_gpu_memory, size_tracker);
@@ -65,6 +67,1 @@ template <typename Torus> struct int_aes_lut_buffers {
        params.carry_modulus, carry_lambda, allocate_gpu_memory);
    auto active_streams_carry_lut = streams.active_gpu_subset(num_aes_inputs);
    this->carry_lut->broadcast_lut(active_streams_carry_lut);
    this->carry_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
  }

  void release(CudaStreams streams) {
@@ -13,6 +13,8 @@

#include <stdio.h>

#include "crypto/keyswitch.cuh"

class NoiseLevel {
public:
  // Constants equivalent to the Rust code
@@ -336,6 +338,9 @@ struct int_radix_lut_custom_input_output {
  std::vector<InputTorus *> lwe_after_ks_vec;
  std::vector<OutputTorus *> lwe_after_pbs_vec;
  std::vector<InputTorus *> lwe_trivial_indexes_vec;
  std::vector<ks_mem<InputTorus> *>
      ks_tmp_buf_vec; // buffers on each GPU to store keyswitch temporary data

  std::vector<InputTorus *> lwe_aligned_vec;

  bool gpu_memory_allocated;
@@ -443,6 +448,30 @@ struct int_radix_lut_custom_input_output {
                     allocate_gpu_memory);
  }

  void setup_gemm_batch_ks_temp_buffers(uint64_t &size_tracker) {

    auto inputs_on_gpu =
        std::min((int)num_input_blocks,
                 std::max(THRESHOLD_MULTI_GPU,
                          get_num_inputs_on_gpu(num_input_blocks, 0,
                                                active_streams.count())));

    if (inputs_on_gpu >= get_threshold_ks_gemm()) {
      for (auto i = 0; i < active_streams.count(); ++i) {
        ks_mem<InputTorus> *ks_buffer;
        uint64_t sub_size_tracker = scratch_cuda_keyswitch<InputTorus>(
            active_streams.stream(i), active_streams.gpu_index(i), &ks_buffer,
            input_big_lwe_dimension, params.small_lwe_dimension, num_blocks,
            gpu_memory_allocated);

        if (i == 0) {
          size_tracker += sub_size_tracker;
        }
        ks_tmp_buf_vec.push_back(ks_buffer);
      }
    }
  }
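For orientation, here is a worked example of when this allocation actually triggers. The concrete block counts and GPU count are hypothetical; only get_threshold_ks_gemm() = 128 and the min/max expression come from this diff.

// Illustrative only -- 256, 32 and "2 GPUs" are assumed values, not from the commit.
// With num_input_blocks = 256 spread over 2 active streams,
// get_num_inputs_on_gpu(256, 0, 2) would give 128 blocks for GPU 0, so
// inputs_on_gpu = min(256, max(THRESHOLD_MULTI_GPU, 128)) >= 128
// and one ks_mem<InputTorus> scratch buffer is pushed per GPU.
// With num_input_blocks = 32, inputs_on_gpu = min(32, ...) <= 32 < 128,
// ks_tmp_buf_vec stays empty, and execute_keyswitch_async later falls back
// to the classic (non-GEMM) keyswitch for this LUT.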
|
||||
|
||||
void setup_mem_reuse(uint32_t num_radix_blocks,
|
||||
int_radix_lut_custom_input_output *base_lut_object) {
|
||||
// base lut object should have bigger or equal memory than current one
|
||||
@@ -461,6 +490,8 @@ struct int_radix_lut_custom_input_output {
|
||||
lwe_after_pbs_vec = base_lut_object->lwe_after_pbs_vec;
|
||||
lwe_trivial_indexes_vec = base_lut_object->lwe_trivial_indexes_vec;
|
||||
|
||||
ks_tmp_buf_vec = base_lut_object->ks_tmp_buf_vec;
|
||||
|
||||
mem_reuse = true;
|
||||
}
|
||||
|
||||
@@ -865,6 +896,13 @@ struct int_radix_lut_custom_input_output {
|
||||
}
|
||||
lwe_aligned_vec.clear();
|
||||
}
|
||||
|
||||
for (auto i = 0; i < ks_tmp_buf_vec.size(); i++) {
|
||||
cleanup_cuda_keyswitch(active_streams.stream(i),
|
||||
active_streams.gpu_index(i), ks_tmp_buf_vec[i],
|
||||
gpu_memory_allocated);
|
||||
}
|
||||
ks_tmp_buf_vec.clear();
|
||||
}
|
||||
free(h_lut_indexes);
|
||||
free(degrees);
|
||||
|
||||
@@ -15,6 +15,10 @@ template <typename Torus> struct int_rerand_mem {
|
||||
|
||||
bool gpu_memory_allocated;
|
||||
|
||||
std::vector<ks_mem<Torus> *>
|
||||
ks_tmp_buf_vec; // not allocated, ReRand not using GEMM KS for now
|
||||
// kept empty to pass to the KS function indicating GEMM KS disabled
|
||||
|
||||
expand_job<Torus> *d_expand_jobs;
|
||||
expand_job<Torus> *h_expand_jobs;
|
||||
|
||||
@@ -72,6 +76,13 @@ template <typename Torus> struct int_rerand_mem {
|
||||
cuda_drop_with_size_tracking_async(d_expand_jobs, streams.stream(0),
|
||||
streams.gpu_index(0),
|
||||
gpu_memory_allocated);
|
||||
|
||||
for (auto i = 0; i < ks_tmp_buf_vec.size(); i++) {
|
||||
cleanup_cuda_keyswitch(streams.stream(i), streams.gpu_index(i),
|
||||
ks_tmp_buf_vec[i], gpu_memory_allocated);
|
||||
}
|
||||
ks_tmp_buf_vec.clear();
|
||||
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
free(h_expand_jobs);
|
||||
}
|
||||
|
||||
@@ -12,6 +12,13 @@ void cuda_keyswitch_lwe_ciphertext_vector_32(
|
||||
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
|
||||
uint32_t num_samples);
|
||||
|
||||
void cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
    void *stream, uint32_t gpu_index, void *lwe_array_out,
    void const *lwe_output_indexes, void const *lwe_array_in,
    void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
    uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
    uint32_t num_samples, const void *ks_tmp_buffer, bool uses_trivial_indexes);
|
||||
|
||||
void cuda_keyswitch_lwe_ciphertext_vector_64(
|
||||
void *stream, uint32_t gpu_index, void *lwe_array_out,
|
||||
void const *lwe_output_indexes, void const *lwe_array_in,
|
||||
@@ -24,6 +31,17 @@ uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t num_lwes, bool allocate_gpu_memory);
|
||||
|
||||
uint64_t scratch_cuda_keyswitch_gemm_64(void *stream, uint32_t gpu_index,
                                        void **ks_tmp_memory,
                                        uint32_t lwe_dimension_in,
                                        uint32_t lwe_dimension_out,
                                        uint32_t num_lwes,
                                        bool allocate_gpu_memory);

void cleanup_cuda_keyswitch_gemm_64(void *stream, uint32_t gpu_index,
                                    void **ks_tmp_memory,
                                    bool allocate_gpu_memory);
|
||||
|
||||
void cuda_packing_keyswitch_lwe_list_to_glwe_64(
|
||||
void *stream, uint32_t gpu_index, void *glwe_array_out,
|
||||
void const *lwe_array_in, void const *fp_ksk_array, int8_t *fp_ks_buffer,
|
||||
|
||||
@@ -9,14 +9,16 @@ void cuda_keyswitch_lwe_ciphertext_vector_32(
|
||||
void *stream, uint32_t gpu_index, void *lwe_array_out,
|
||||
void *lwe_output_indexes, void *lwe_array_in, void *lwe_input_indexes,
|
||||
void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples) {
|
||||
host_keyswitch_lwe_ciphertext_vector<uint32_t>(
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
void *ksk_tmp_buffer, bool uses_trivial_indices) {
|
||||
host_gemm_keyswitch_lwe_ciphertext_vector<uint32_t>(
|
||||
static_cast<cudaStream_t>(stream), gpu_index,
|
||||
static_cast<uint32_t *>(lwe_array_out),
|
||||
static_cast<uint32_t *>(lwe_output_indexes),
|
||||
static_cast<uint32_t *>(lwe_array_in),
|
||||
static_cast<uint32_t *>(lwe_input_indexes), static_cast<uint32_t *>(ksk),
|
||||
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples);
|
||||
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples,
|
||||
static_cast<uint32_t *>(ksk_tmp_buffer), uses_trivial_indices);
|
||||
}
|
||||
|
||||
/* Perform keyswitch on a batch of 64 bits input LWE ciphertexts.
|
||||
@@ -35,6 +37,26 @@ void cuda_keyswitch_lwe_ciphertext_vector_32(
|
||||
* This function calls a wrapper to a device kernel that performs the keyswitch
|
||||
* - num_samples blocks of threads are launched
|
||||
*/
|
||||
void cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
|
||||
void *stream, uint32_t gpu_index, void *lwe_array_out,
|
||||
void const *lwe_output_indexes, void const *lwe_array_in,
|
||||
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
|
||||
uint32_t num_samples, const void *ks_tmp_buffer,
|
||||
bool uses_trivial_indices) {
|
||||
|
||||
host_gemm_keyswitch_lwe_ciphertext_vector<uint64_t>(
|
||||
static_cast<cudaStream_t>(stream), gpu_index,
|
||||
static_cast<uint64_t *>(lwe_array_out),
|
||||
static_cast<const uint64_t *>(lwe_output_indexes),
|
||||
static_cast<const uint64_t *>(lwe_array_in),
|
||||
static_cast<const uint64_t *>(lwe_input_indexes),
|
||||
static_cast<const uint64_t *>(ksk), lwe_dimension_in, lwe_dimension_out,
|
||||
base_log, level_count, num_samples,
|
||||
static_cast<const ks_mem<uint64_t> *>(ks_tmp_buffer)->d_buffer,
|
||||
uses_trivial_indices);
|
||||
}
|
||||
|
||||
void cuda_keyswitch_lwe_ciphertext_vector_64(
|
||||
void *stream, uint32_t gpu_index, void *lwe_array_out,
|
||||
void const *lwe_output_indexes, void const *lwe_array_in,
|
||||
@@ -44,10 +66,10 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
|
||||
host_keyswitch_lwe_ciphertext_vector<uint64_t>(
|
||||
static_cast<cudaStream_t>(stream), gpu_index,
|
||||
static_cast<uint64_t *>(lwe_array_out),
|
||||
static_cast<const uint64_t *>(lwe_output_indexes),
|
||||
static_cast<const uint64_t *>(lwe_array_in),
|
||||
static_cast<const uint64_t *>(lwe_input_indexes),
|
||||
static_cast<const uint64_t *>(ksk), lwe_dimension_in, lwe_dimension_out,
|
||||
static_cast<uint64_t const *>(lwe_output_indexes),
|
||||
static_cast<uint64_t const *>(lwe_array_in),
|
||||
static_cast<uint64_t const *>(lwe_input_indexes),
|
||||
static_cast<uint64_t const *>(ksk), lwe_dimension_in, lwe_dimension_out,
|
||||
base_log, level_count, num_samples);
|
||||
}
|
||||
|
||||
@@ -60,6 +82,27 @@ uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
|
||||
glwe_dimension, polynomial_size, num_lwes, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_keyswitch_gemm_64(void *stream, uint32_t gpu_index,
|
||||
void **ks_tmp_buffer,
|
||||
uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out,
|
||||
uint32_t num_lwes,
|
||||
bool allocate_gpu_memory) {
|
||||
return scratch_cuda_keyswitch<uint64_t>(
|
||||
static_cast<cudaStream_t>(stream), gpu_index,
|
||||
(ks_mem<uint64_t> **)ks_tmp_buffer, lwe_dimension_in, lwe_dimension_out,
|
||||
num_lwes, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void cleanup_cuda_keyswitch_gemm_64(void *stream, uint32_t gpu_index,
|
||||
void **ks_tmp_buffer,
|
||||
bool allocate_gpu_memory) {
|
||||
cleanup_cuda_keyswitch<uint64_t>(static_cast<cudaStream_t>(stream), gpu_index,
|
||||
(ks_mem<uint64_t> *)*ks_tmp_buffer,
|
||||
allocate_gpu_memory);
|
||||
*ks_tmp_buffer = nullptr;
|
||||
}
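Taken together, the three entry points above form a scratch / run / cleanup triple. Below is a minimal host-side sketch of that flow; it is not code from this commit, run_gemm_ks_once and the parameter names are placeholders, and the device pointers (output/input LWE arrays, index arrays, KSK) are assumed to be allocated and populated the same way the test and benchmark changes further down do it.

// Hedged usage sketch of the new GEMM keyswitch C API; not code from the commit.
void run_gemm_ks_once(void *stream, uint32_t gpu_index, void *d_out,
                      void *d_out_idx, void *d_in, void *d_in_idx, void *d_ksk,
                      uint32_t lwe_dim_in, uint32_t lwe_dim_out,
                      uint32_t base_log, uint32_t level_count,
                      uint32_t num_lwes) {
  void *ks_tmp_buffer = nullptr;

  // 1) Allocate the per-GPU scratch buffer (returns the tracked byte count).
  uint64_t scratch_bytes = scratch_cuda_keyswitch_gemm_64(
      stream, gpu_index, &ks_tmp_buffer, lwe_dim_in, lwe_dim_out, num_lwes,
      true);
  (void)scratch_bytes;

  // 2) Run the GEMM-based keyswitch; `false` means the index arrays are
  //    honored rather than assumed trivial.
  cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
      stream, gpu_index, d_out, d_out_idx, d_in, d_in_idx, d_ksk, lwe_dim_in,
      lwe_dim_out, base_log, level_count, num_lwes, ks_tmp_buffer, false);

  // 3) Release the scratch buffer once the stream work is done.
  cleanup_cuda_keyswitch_gemm_64(stream, gpu_index, &ks_tmp_buffer, true);
}

In the setup/teardown helpers changed later in this diff, the scratch allocation lives in keyswitch_setup and the cleanup in keyswitch_teardown, so the repeated keyswitch calls in between reuse one buffer.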
|
||||
|
||||
/* Perform functional packing keyswitch on a batch of 64 bits input LWE
|
||||
* ciphertexts.
|
||||
*/
|
||||
|
||||
@@ -1,6 +1,8 @@
#ifndef CNCRT_KS_CUH
#define CNCRT_KS_CUH

#include <algorithm>

#include "device.h"
#include "gadget.cuh"
#include "helper_multi_gpu.h"
@@ -10,8 +12,30 @@
#include "utils/helper.cuh"
#include "utils/kernel_dimensions.cuh"
#include <thread>
#include <unistd.h>
#include <vector>

const int BLOCK_SIZE_DECOMP = 8;
const int BLOCK_SIZE_GEMM_KS = 36;
const int THREADS_GEMM_KS = 6;

inline uint64_t get_threshold_ks_gemm() { return 128; }

template <typename Torus> struct ks_mem {
  Torus *d_buffer;
  uint64_t num_lwes;
  uint32_t lwe_dimension;
};

template <typename Torus>
uint64_t scratch_cuda_keyswitch_size(uint32_t lwe_dimension_in,
                                     uint32_t lwe_dimension_out,
                                     uint32_t num_lwes) {
  GPU_ASSERT(lwe_dimension_in >= lwe_dimension_out,
             "Trying to allocate KS temp buffer for invalid LWE dimensions");
  return (uint64_t)num_lwes * lwe_dimension_in * sizeof(Torus) * 2;
}
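A quick sanity check on this size formula, with assumed parameters (the dimensions below are placeholders, not values from the diff): the buffer holds two planes of num_lwes x lwe_dimension_in elements, one for the currently decomposed level and one for the running decomposer state.

// Assumed example: 128 LWEs, input LWE dimension 2048, Torus = uint64_t.
//   scratch_cuda_keyswitch_size<uint64_t>(2048, 742, 128)
//     = 128 * 2048 * sizeof(uint64_t) * 2
//     = 4,194,304 bytes = 4 MiB of device scratch per GPU.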
|
||||
|
||||
template <typename Torus>
|
||||
__device__ Torus *get_ith_block(Torus *ksk, int i, int level,
|
||||
uint32_t lwe_dimension_out,
|
||||
@@ -22,6 +46,169 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
|
||||
return ptr;
|
||||
}
|
||||
|
||||
// Initialize decomposition by performing rounding
|
||||
// and decomposing one level of an array of Torus LWEs. Only
|
||||
// decomposes the mask elements of the incoming LWEs.
|
||||
template <typename Torus>
|
||||
__global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
|
||||
uint32_t lwe_dimension,
|
||||
uint32_t num_lwe, uint32_t base_log,
|
||||
uint32_t level_count) {
|
||||
|
||||
// index of this LWE ct in the buffer
|
||||
auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// index of the LWE sample in the LWE ct
|
||||
auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
||||
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
|
||||
return;
|
||||
|
||||
// Input LWE array is [mask_0, .., mask_lwe_dim, message] and
|
||||
// we only decompose the mask. Thus the stride for reading
|
||||
// is lwe_dimension + 1, while for writing it is lwe_dimension
|
||||
auto read_val_idx = lwe_idx * (lwe_dimension + 1) + lwe_sample_idx;
|
||||
auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
|
||||
auto write_state_idx =
|
||||
num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;
|
||||
|
||||
Torus a_i = lwe_in[read_val_idx];
|
||||
|
||||
Torus state = init_decomposer_state(a_i, base_log, level_count);
|
||||
|
||||
Torus mod_b_mask = (1ll << base_log) - 1ll;
|
||||
lwe_out[write_val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
|
||||
__syncthreads();
|
||||
lwe_out[write_state_idx] = state;
|
||||
}
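To make the indexing above easier to follow, this is the layout of the scratch buffer written by the kernel (a reading aid, not code from the commit):

// lwe_out layout in Torus elements, as implied by write_val_idx / write_state_idx:
//   [0 .. num_lwe*lwe_dimension)                          decomposed level of every mask
//   [num_lwe*lwe_dimension .. 2*num_lwe*lwe_dimension)    decomposer states, consumed by
//                                                         decompose_vectorize_step_inplace
// which is why scratch_cuda_keyswitch_size reserves
// num_lwes * lwe_dimension_in * sizeof(Torus) * 2 bytes.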
|
||||
|
||||
// Decompose an array of LWEs with indirection through lwe_input_indices
|
||||
// The LWE array can contain total_lwe LWEs where total_lwe can be different
|
||||
// from num_lwe. The maximum index should be <= total_lwe. num_lwe is the number
|
||||
// of LWEs to decompose The output buffer should have space for num_lwe LWEs.
|
||||
// These will be sorted according to the input indices.
|
||||
template <typename Torus>
|
||||
__global__ void decompose_vectorize_init_with_indices(
|
||||
Torus const *lwe_in, const Torus *__restrict__ lwe_input_indices,
|
||||
Torus *lwe_out, uint32_t lwe_dimension, uint32_t num_lwe, uint32_t base_log,
|
||||
uint32_t level_count) {
|
||||
|
||||
// index of this LWE ct in the buffer
|
||||
auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// index of the LWE sample in the LWE ct
|
||||
auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
||||
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
|
||||
return;
|
||||
|
||||
// Input LWE array is [mask_0, .., mask_lwe_dim, message] and
|
||||
// we only decompose the mask. Thus the stride for reading
|
||||
// is lwe_dimension + 1, while for writing it is lwe_dimension
|
||||
auto read_val_idx =
|
||||
lwe_input_indices[lwe_idx] * (lwe_dimension + 1) + lwe_sample_idx;
|
||||
auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
|
||||
auto write_state_idx =
|
||||
num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;
|
||||
|
||||
Torus a_i = lwe_in[read_val_idx];
|
||||
|
||||
Torus state = init_decomposer_state(a_i, base_log, level_count);
|
||||
|
||||
Torus mod_b_mask = (1ll << base_log) - 1ll;
|
||||
lwe_out[write_val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
|
||||
__syncthreads();
|
||||
lwe_out[write_state_idx] = state;
|
||||
}
|
||||
|
||||
// Continue decomposition of an array of Torus elements in place. Supposes
|
||||
// that the array contains already decomposed elements and
|
||||
// computes the new decomposed level in place.
|
||||
template <typename Torus>
|
||||
__global__ void
|
||||
decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
|
||||
uint32_t num_lwe, uint32_t base_log,
|
||||
uint32_t level_count) {
|
||||
|
||||
// index of this LWE ct in the buffer
|
||||
auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// index of the LWE sample in the LWE ct
|
||||
auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
||||
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
|
||||
return;
|
||||
|
||||
auto val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
|
||||
auto state_idx = num_lwe * lwe_dimension + val_idx;
|
||||
|
||||
Torus state = buffer_in[state_idx];
|
||||
__syncthreads();
|
||||
|
||||
Torus mod_b_mask = (1ll << base_log) - 1ll;
|
||||
|
||||
buffer_in[val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
|
||||
__syncthreads();
|
||||
buffer_in[state_idx] = state;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__global__ void keyswitch_gemm_copy_negated_message_with_indices(
|
||||
const Torus *__restrict__ lwe_in,
|
||||
const Torus *__restrict__ lwe_input_indices, Torus *__restrict__ lwe_out,
|
||||
const Torus *__restrict__ lwe_output_indices,
|
||||
|
||||
uint32_t lwe_dimension_in, uint32_t num_lwes, uint32_t lwe_dimension_out) {
|
||||
|
||||
uint32_t lwe_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (lwe_id >= num_lwes)
|
||||
return;
|
||||
|
||||
uint32_t lwe_in_idx = lwe_input_indices[lwe_id];
|
||||
uint32_t lwe_out_idx = lwe_output_indices[lwe_id];
|
||||
|
||||
lwe_out[lwe_out_idx * (lwe_dimension_out + 1) + lwe_dimension_out] =
|
||||
-lwe_in[lwe_in_idx * (lwe_dimension_in + 1) + lwe_dimension_in];
|
||||
}
|
||||
|
||||
// Finishes the KS computation by negating all elements in the array
|
||||
// using output indices. The array contains -b + SUM(a_i x LWE_i)
|
||||
// and this final step computes b - SUM(a_i x LWE_i)
|
||||
template <typename Torus>
|
||||
__global__ void keyswitch_negate_with_output_indices(
|
||||
Torus *buffer_in, const Torus *__restrict__ lwe_output_indices,
|
||||
uint32_t lwe_size, uint32_t num_lwe) {
|
||||
|
||||
// index of this LWE ct in the buffer
|
||||
auto lwe_sample_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// index of the LWE sample in the LWE ct
|
||||
auto lwe_idx = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
||||
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_size)
|
||||
return;
|
||||
|
||||
auto val_idx = lwe_output_indices[lwe_idx] * lwe_size + lwe_sample_idx;
|
||||
|
||||
Torus val = buffer_in[val_idx];
|
||||
buffer_in[val_idx] = -val;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__global__ void keyswitch_zero_output_with_output_indices(
|
||||
Torus *buffer_in, const Torus *__restrict__ lwe_output_indices,
|
||||
uint32_t lwe_size, uint32_t num_lwe) {
|
||||
|
||||
// index of this LWE ct in the buffer
|
||||
auto lwe_sample_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// index of the LWE sample in the LWE ct
|
||||
auto lwe_idx = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
||||
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_size)
|
||||
return;
|
||||
|
||||
auto val_idx = lwe_output_indices[lwe_idx] * lwe_size + lwe_sample_idx;
|
||||
|
||||
buffer_in[val_idx] = 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* keyswitch kernel
|
||||
* Each thread handles a piece of the following equation:
|
||||
@@ -142,14 +329,141 @@ __host__ void host_keyswitch_lwe_ciphertext_vector(
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
void execute_keyswitch_async(CudaStreams streams,
|
||||
const LweArrayVariant<Torus> &lwe_array_out,
|
||||
const LweArrayVariant<Torus> &lwe_output_indexes,
|
||||
const LweArrayVariant<Torus> &lwe_array_in,
|
||||
const LweArrayVariant<Torus> &lwe_input_indexes,
|
||||
Torus *const *ksks, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t num_samples) {
|
||||
__host__ void host_gemm_keyswitch_lwe_ciphertext_vector(
|
||||
cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
|
||||
Torus const *lwe_output_indices, Torus const *lwe_array_in,
|
||||
Torus const *lwe_input_indices, Torus const *ksk, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
|
||||
uint32_t num_samples, Torus *fp_tmp_buffer, bool uses_trivial_indices) {
|
||||
cuda_set_device(gpu_index);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
auto d_mem_0 = fp_tmp_buffer; // keeps decomposed value
|
||||
|
||||
// Set the scratch buffer to 0 as it is used to accumulate
|
||||
// decomposition temporary results
|
||||
if (uses_trivial_indices) {
|
||||
cuda_memset_async(lwe_array_out, 0,
|
||||
num_samples * (lwe_dimension_out + 1) * sizeof(Torus),
|
||||
stream, gpu_index);
|
||||
} else {
|
||||
// gemm to ks the individual LWEs to GLWEs
|
||||
dim3 grid_zero(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_DECOMP),
|
||||
CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
|
||||
dim3 threads_zero(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
|
||||
|
||||
keyswitch_zero_output_with_output_indices<Torus>
|
||||
<<<grid_zero, threads_zero, 0, stream>>>(
|
||||
lwe_array_out, lwe_output_indices, lwe_dimension_out + 1,
|
||||
num_samples);
|
||||
}
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
dim3 grid_copy(CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
|
||||
dim3 threads_copy(BLOCK_SIZE_DECOMP);
|
||||
|
||||
// lwe_array_out is num_samples x (lwe_dimension_out + 1). copy the bodies
|
||||
// lwe_array_in[:,lwe_dimension_in] to lwe_array_out[:,lwe_dimension_out]
|
||||
// and negate
|
||||
keyswitch_gemm_copy_negated_message_with_indices<Torus>
|
||||
<<<grid_copy, threads_copy, 0, stream>>>(
|
||||
lwe_array_in, lwe_input_indices, lwe_array_out, lwe_output_indices,
|
||||
lwe_dimension_in, num_samples, lwe_dimension_out);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
// decompose LWEs
|
||||
// don't decompose LWE body - the LWE has lwe_size + 1 elements. The last
|
||||
// element, the body is ignored by rounding down the number of blocks assuming
|
||||
// here that the LWE dimension is a multiple of the block size
|
||||
dim3 grid_decomp(CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP),
|
||||
CEIL_DIV(lwe_dimension_in, BLOCK_SIZE_DECOMP));
|
||||
dim3 threads_decomp(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
|
||||
|
||||
uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
|
||||
// Shared memory requirement is 4096, 8192, and 16384 bytes respectively for
|
||||
// 32, 64, and 128-bit Torus elements
|
||||
// Sanity check: the shared memory size is a constant defined by the algorithm
|
||||
GPU_ASSERT(shared_mem_size <= 1024 * sizeof(Torus),
|
||||
"GEMM kernel error: shared memory required might be too large");
|
||||
|
||||
auto stride_KSK_buffer = (lwe_dimension_out + 1) * level_count;
|
||||
|
||||
// gemm to ks the individual LWEs to GLWEs
|
||||
dim3 grid_gemm(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_GEMM_KS),
|
||||
CEIL_DIV(num_samples, BLOCK_SIZE_GEMM_KS));
|
||||
dim3 threads_gemm(BLOCK_SIZE_GEMM_KS * THREADS_GEMM_KS);
|
||||
|
||||
// decompose first level (skips the body in the input buffer)
|
||||
decompose_vectorize_init_with_indices<Torus>
|
||||
<<<grid_decomp, threads_decomp, 0, stream>>>(
|
||||
lwe_array_in, lwe_input_indices, fp_tmp_buffer, lwe_dimension_in,
|
||||
num_samples, base_log, level_count);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
if (uses_trivial_indices) {
|
||||
tgemm<Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
|
||||
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
|
||||
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
|
||||
ksk, stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
} else {
|
||||
tgemm_with_indices<Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
|
||||
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
|
||||
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
|
||||
ksk, stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1,
|
||||
lwe_output_indices);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
}
|
||||
|
||||
auto ksk_block_size = (lwe_dimension_out + 1);
|
||||
|
||||
for (int li = 1; li < level_count; ++li) {
|
||||
decompose_vectorize_step_inplace<Torus>
|
||||
<<<grid_decomp, threads_decomp, 0, stream>>>(
|
||||
fp_tmp_buffer, lwe_dimension_in, num_samples, base_log,
|
||||
level_count);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
if (uses_trivial_indices) {
|
||||
tgemm<Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
|
||||
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
|
||||
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
|
||||
ksk + li * ksk_block_size, stride_KSK_buffer, lwe_array_out,
|
||||
lwe_dimension_out + 1);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
} else {
|
||||
tgemm_with_indices<Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
|
||||
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
|
||||
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
|
||||
ksk + li * ksk_block_size, stride_KSK_buffer, lwe_array_out,
|
||||
lwe_dimension_out + 1, lwe_output_indices);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
// gemm to ks the individual LWEs to GLWEs
|
||||
dim3 grid_negate(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_DECOMP),
|
||||
CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
|
||||
dim3 threads_negate(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
|
||||
// Negate all outputs in the LWE
|
||||
keyswitch_negate_with_output_indices<Torus>
|
||||
<<<grid_negate, threads_negate, 0, stream>>>(
|
||||
lwe_array_out, lwe_output_indices, lwe_dimension_out + 1,
|
||||
num_samples);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
}
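To summarize the control flow of the host function above (a reading aid only, restating the kernel launches already shown):

// host_gemm_keyswitch_lwe_ciphertext_vector, step by step:
//   1. Zero the output LWEs: cuda_memset_async when the indices are trivial,
//      otherwise keyswitch_zero_output_with_output_indices.
//   2. Copy the negated input bodies into the output bodies
//      (keyswitch_gemm_copy_negated_message_with_indices).
//   3. Decompose the first level of all input masks into the scratch buffer
//      (decompose_vectorize_init_with_indices).
//   4. Accumulate decomposed-mask x KSK products into the outputs with a
//      tiled GEMM (tgemm or tgemm_with_indices).
//   5. For each remaining level: decompose in place
//      (decompose_vectorize_step_inplace), then GEMM-accumulate again against
//      the next KSK block (ksk + li * ksk_block_size).
//   6. Negate every output element (keyswitch_negate_with_output_indices),
//      turning the accumulated -b + sum(a_i * KSK_i) into the keyswitched LWE.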
|
||||
|
||||
template <typename Torus>
|
||||
void execute_keyswitch_async(
|
||||
CudaStreams streams, const LweArrayVariant<Torus> &lwe_array_out,
|
||||
const LweArrayVariant<Torus> &lwe_output_indexes,
|
||||
const LweArrayVariant<Torus> &lwe_array_in,
|
||||
const LweArrayVariant<Torus> &lwe_input_indexes, Torus *const *ksks,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t num_samples, bool uses_trivial_indices,
|
||||
const std::vector<ks_mem<Torus> *> &fp_tmp_buffer) {
|
||||
|
||||
/// If the number of radix blocks is lower than the number of GPUs, not all
|
||||
/// GPUs will be active and there will be 1 input per GPU
|
||||
@@ -164,12 +478,39 @@ void execute_keyswitch_async(CudaStreams streams,
|
||||
Torus *current_lwe_input_indexes =
|
||||
get_variant_element(lwe_input_indexes, i);
|
||||
|
||||
// Compute Keyswitch
|
||||
host_keyswitch_lwe_ciphertext_vector<Torus>(
|
||||
streams.stream(i), streams.gpu_index(i), current_lwe_array_out,
|
||||
current_lwe_output_indexes, current_lwe_array_in,
|
||||
current_lwe_input_indexes, ksks[i], lwe_dimension_in, lwe_dimension_out,
|
||||
base_log, level_count, num_samples_on_gpu);
|
||||
if (!fp_tmp_buffer.empty() &&
|
||||
num_samples_on_gpu >= get_threshold_ks_gemm()) {
|
||||
GPU_ASSERT(fp_tmp_buffer.size() >= streams.count(),
|
||||
"GEMM KS Buffers %ld were not initialized for this amount of "
|
||||
"streams, %d",
|
||||
fp_tmp_buffer.size(), streams.count());
|
||||
|
||||
GPU_ASSERT(fp_tmp_buffer[i]->num_lwes >= num_samples_on_gpu,
|
||||
"KS temp buffer not big enough");
|
||||
|
||||
GPU_ASSERT(fp_tmp_buffer[i]->lwe_dimension ==
|
||||
std::max(lwe_dimension_in, lwe_dimension_out),
|
||||
"KS temp buffer was created for a different input LWE size: "
|
||||
"%d vs (in:%d, out:%d)",
|
||||
fp_tmp_buffer[i]->lwe_dimension, lwe_dimension_in,
|
||||
lwe_dimension_out);
|
||||
|
||||
// Compute Keyswitch
|
||||
host_gemm_keyswitch_lwe_ciphertext_vector<Torus>(
|
||||
streams.stream(i), streams.gpu_index(i), current_lwe_array_out,
|
||||
current_lwe_output_indexes, current_lwe_array_in,
|
||||
current_lwe_input_indexes, ksks[i], lwe_dimension_in,
|
||||
lwe_dimension_out, base_log, level_count, num_samples_on_gpu,
|
||||
fp_tmp_buffer[i]->d_buffer, uses_trivial_indices);
|
||||
|
||||
} else {
|
||||
// Compute Keyswitch
|
||||
host_keyswitch_lwe_ciphertext_vector<Torus>(
|
||||
streams.stream(i), streams.gpu_index(i), current_lwe_array_out,
|
||||
current_lwe_output_indexes, current_lwe_array_in,
|
||||
current_lwe_input_indexes, ksks[i], lwe_dimension_in,
|
||||
lwe_dimension_out, base_log, level_count, num_samples_on_gpu);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -259,4 +600,32 @@ __global__ void accumulate_glwes(Torus *glwe_out, Torus *glwe_array_in,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
uint64_t scratch_cuda_keyswitch(cudaStream_t stream, uint32_t gpu_index,
|
||||
ks_mem<Torus> **ks_tmp_memory,
|
||||
uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t num_lwes,
|
||||
bool allocate_gpu_memory) {
|
||||
uint64_t sub_size_tracker = 0;
|
||||
uint64_t buffer_size = scratch_cuda_keyswitch_size<Torus>(
|
||||
lwe_dimension_in, lwe_dimension_out, num_lwes);
|
||||
|
||||
*ks_tmp_memory = new ks_mem<Torus>;
|
||||
(*ks_tmp_memory)->d_buffer = (Torus *)cuda_malloc_with_size_tracking_async(
|
||||
buffer_size, stream, gpu_index, sub_size_tracker, allocate_gpu_memory);
|
||||
(*ks_tmp_memory)->lwe_dimension =
|
||||
std::max(lwe_dimension_in, lwe_dimension_out);
|
||||
(*ks_tmp_memory)->num_lwes = num_lwes;
|
||||
return sub_size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
void cleanup_cuda_keyswitch(cudaStream_t stream, uint32_t gpu_index,
|
||||
ks_mem<Torus> *ks_tmp_memory,
|
||||
bool allocate_gpu_memory) {
|
||||
cuda_drop_with_size_tracking_async(ks_tmp_memory->d_buffer, stream, gpu_index,
|
||||
allocate_gpu_memory);
|
||||
delete ks_tmp_memory;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -18,73 +18,6 @@
|
||||
|
||||
#define CEIL_DIV(M, N) ((M) + (N)-1) / (N)
|
||||
|
||||
const int BLOCK_SIZE_DECOMP = 8;
|
||||
|
||||
// Initialize decomposition by performing rounding
|
||||
// and decomposing one level of an array of Torus LWEs. Only
|
||||
// decomposes the mask elements of the incoming LWEs.
|
||||
template <typename Torus>
|
||||
__global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
|
||||
uint32_t lwe_dimension,
|
||||
uint32_t num_lwe, uint32_t base_log,
|
||||
uint32_t level_count) {
|
||||
|
||||
// index of this LWE ct in the buffer
|
||||
auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// index of the LWE sample in the LWE ct
|
||||
auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
||||
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
|
||||
return;
|
||||
|
||||
// Input LWE array is [mask_0, .., mask_lwe_dim, message] and
|
||||
// we only decompose the mask. Thus the stride for reading
|
||||
// is lwe_dimension + 1, while for writing it is lwe_dimension
|
||||
auto read_val_idx = lwe_idx * (lwe_dimension + 1) + lwe_sample_idx;
|
||||
auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
|
||||
auto write_state_idx =
|
||||
num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;
|
||||
|
||||
Torus a_i = lwe_in[read_val_idx];
|
||||
|
||||
Torus state = init_decomposer_state(a_i, base_log, level_count);
|
||||
|
||||
Torus mod_b_mask = (1ll << base_log) - 1ll;
|
||||
lwe_out[write_val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
|
||||
__syncthreads();
|
||||
lwe_out[write_state_idx] = state;
|
||||
}
|
||||
|
||||
// Continue decomposition of an array of Torus elements in place. Supposes
|
||||
// that the array contains already decomposed elements and
|
||||
// computes the new decomposed level in place.
|
||||
template <typename Torus>
|
||||
__global__ void
|
||||
decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
|
||||
uint32_t num_lwe, uint32_t base_log,
|
||||
uint32_t level_count) {
|
||||
|
||||
// index of this LWE ct in the buffer
|
||||
auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// index of the LWE sample in the LWE ct
|
||||
auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
||||
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
|
||||
return;
|
||||
|
||||
auto val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
|
||||
auto state_idx = num_lwe * lwe_dimension + val_idx;
|
||||
|
||||
Torus state = buffer_in[state_idx];
|
||||
__syncthreads();
|
||||
|
||||
Torus mod_b_mask = (1ll << base_log) - 1ll;
|
||||
|
||||
buffer_in[val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
|
||||
__syncthreads();
|
||||
buffer_in[state_idx] = state;
|
||||
}
|
||||
|
||||
// Finish the keyswitching operation and prepare GLWEs for accumulation.
|
||||
// 1. Finish the keyswitching computation partially performed with a GEMM:
|
||||
// - negate the dot product between the GLWE and KSK polynomial
|
||||
@@ -209,9 +142,10 @@ __host__ void host_packing_keyswitch_lwe_list_to_glwe(
|
||||
GPU_ASSERT(shared_mem_size <= 1024 * sizeof(Torus),
|
||||
"GEMM kernel error: shared memory required might be too large");
|
||||
|
||||
tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
|
||||
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
|
||||
stride_KSK_buffer, d_mem_1, glwe_accumulator_size);
|
||||
tgemm<Torus, BLOCK_SIZE_GEMM, THREADS_GEMM>
|
||||
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
|
||||
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
|
||||
stride_KSK_buffer, d_mem_1, glwe_accumulator_size);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
auto ksk_block_size = glwe_accumulator_size;
|
||||
@@ -222,10 +156,11 @@ __host__ void host_packing_keyswitch_lwe_list_to_glwe(
|
||||
d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
|
||||
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
|
||||
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1,
|
||||
glwe_accumulator_size);
|
||||
tgemm<Torus, BLOCK_SIZE_GEMM, THREADS_GEMM>
|
||||
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
|
||||
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
|
||||
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1,
|
||||
glwe_accumulator_size);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
}
|
||||
|
||||
|
||||
@@ -547,7 +547,7 @@ __host__ void integer_radix_apply_univariate_lookup_table(
|
||||
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
|
||||
(Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
|
||||
num_radix_blocks);
|
||||
num_radix_blocks, lut->using_trivial_lwe_indexes, lut->ks_tmp_buf_vec);
|
||||
|
||||
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
|
||||
/// dimension to a big LWE dimension
|
||||
@@ -575,7 +575,8 @@ __host__ void integer_radix_apply_univariate_lookup_table(
|
||||
execute_keyswitch_async<Torus>(
|
||||
active_streams, lwe_after_ks_vec, lwe_trivial_indexes_vec,
|
||||
lwe_array_in_vec, lwe_trivial_indexes_vec, ksks, big_lwe_dimension,
|
||||
small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks);
|
||||
small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks, true,
|
||||
lut->ks_tmp_buf_vec);
|
||||
|
||||
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
|
||||
/// dimension to a big LWE dimension
|
||||
@@ -649,7 +650,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table(
|
||||
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
|
||||
(Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
|
||||
num_radix_blocks);
|
||||
num_radix_blocks, lut->using_trivial_lwe_indexes, lut->ks_tmp_buf_vec);
|
||||
|
||||
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
|
||||
/// dimension to a big LWE dimension
|
||||
@@ -677,7 +678,8 @@ __host__ void integer_radix_apply_many_univariate_lookup_table(
|
||||
execute_keyswitch_async<Torus>(
|
||||
active_streams, lwe_after_ks_vec, lwe_trivial_indexes_vec,
|
||||
lwe_array_in_vec, lwe_trivial_indexes_vec, ksks, big_lwe_dimension,
|
||||
small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks);
|
||||
small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks, true,
|
||||
lut->ks_tmp_buf_vec);
|
||||
|
||||
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
|
||||
/// dimension to a big LWE dimension
|
||||
@@ -767,7 +769,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table(
|
||||
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
|
||||
(Torus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
|
||||
num_radix_blocks);
|
||||
num_radix_blocks, lut->using_trivial_lwe_indexes, lut->ks_tmp_buf_vec);
|
||||
|
||||
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
|
||||
/// dimension to a big LWE dimension
|
||||
@@ -792,7 +794,8 @@ __host__ void integer_radix_apply_bivariate_lookup_table(
|
||||
execute_keyswitch_async<Torus>(
|
||||
active_streams, lwe_after_ks_vec, lwe_trivial_indexes_vec,
|
||||
lwe_array_in_vec, lwe_trivial_indexes_vec, ksks, big_lwe_dimension,
|
||||
small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks);
|
||||
small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks, true,
|
||||
lut->ks_tmp_buf_vec);
|
||||
|
||||
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
|
||||
/// dimension to a big LWE dimension
|
||||
@@ -1521,7 +1524,8 @@ void host_full_propagate_inplace(CudaStreams streams,
|
||||
streams.get_ith(0), (Torus *)(mem_ptr->tmp_small_lwe_vector->ptr),
|
||||
mem_ptr->lut->lwe_trivial_indexes, (Torus *)cur_input_block.ptr,
|
||||
mem_ptr->lut->lwe_trivial_indexes, ksks, params.big_lwe_dimension,
|
||||
params.small_lwe_dimension, params.ks_base_log, params.ks_level, 1);
|
||||
params.small_lwe_dimension, params.ks_base_log, params.ks_level, 1,
|
||||
mem_ptr->lut->using_trivial_lwe_indexes, mem_ptr->lut->ks_tmp_buf_vec);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), mem_ptr->tmp_small_lwe_vector,
|
||||
@@ -2356,7 +2360,8 @@ __host__ void integer_radix_apply_noise_squashing(
|
||||
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
|
||||
(InputTorus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks,
|
||||
lut->input_big_lwe_dimension, small_lwe_dimension, ks_base_log,
|
||||
ks_level, lwe_array_out->num_radix_blocks);
|
||||
ks_level, lwe_array_out->num_radix_blocks,
|
||||
lut->using_trivial_lwe_indexes, lut->ks_tmp_buf_vec);
|
||||
|
||||
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
|
||||
/// dimension to a big LWE dimension
|
||||
@@ -2386,7 +2391,7 @@ __host__ void integer_radix_apply_noise_squashing(
|
||||
active_streams, lwe_after_ks_vec, lwe_trivial_indexes_vec,
|
||||
lwe_array_in_vec, lwe_trivial_indexes_vec, ksks,
|
||||
lut->input_big_lwe_dimension, small_lwe_dimension, ks_base_log,
|
||||
ks_level, lwe_array_out->num_radix_blocks);
|
||||
ks_level, lwe_array_out->num_radix_blocks, true, lut->ks_tmp_buf_vec);
|
||||
|
||||
/// int_noise_squashing_lut doesn't support a different output or lut
|
||||
/// indexing than the trivial
|
||||
|
||||
@@ -389,13 +389,16 @@ __host__ void host_integer_partial_sum_ciphertexts_vec(
|
||||
needs_processing);
|
||||
|
||||
auto active_streams = streams.active_gpu_subset(total_ciphertexts);
|
||||
GPU_ASSERT(total_ciphertexts <= mem_ptr->luts_message_carry->num_blocks,
|
||||
"SUM CT");
|
||||
|
||||
if (active_streams.count() == 1) {
|
||||
execute_keyswitch_async<Torus>(
|
||||
streams.get_ith(0), (Torus *)small_lwe_vector->ptr, d_pbs_indexes_in,
|
||||
(Torus *)current_blocks->ptr, d_pbs_indexes_in, ksks,
|
||||
big_lwe_dimension, small_lwe_dimension, mem_ptr->params.ks_base_log,
|
||||
mem_ptr->params.ks_level, total_messages);
|
||||
mem_ptr->params.ks_level, total_messages, false,
|
||||
mem_ptr->luts_message_carry->ks_tmp_buf_vec);
|
||||
|
||||
execute_pbs_async<Torus, Torus>(
|
||||
streams.get_ith(0), (Torus *)current_blocks->ptr, d_pbs_indexes_out,
|
||||
@@ -446,7 +449,8 @@ __host__ void host_integer_partial_sum_ciphertexts_vec(
|
||||
streams.get_ith(0), (Torus *)small_lwe_vector->ptr, d_pbs_indexes_in,
|
||||
(Torus *)radix_lwe_out->ptr, d_pbs_indexes_in, ksks,
|
||||
big_lwe_dimension, small_lwe_dimension, mem_ptr->params.ks_base_log,
|
||||
mem_ptr->params.ks_level, num_radix_blocks);
|
||||
mem_ptr->params.ks_level, num_radix_blocks, false,
|
||||
mem_ptr->luts_message_carry->ks_tmp_buf_vec);
|
||||
|
||||
execute_pbs_async<Torus, Torus>(
|
||||
streams.get_ith(0), (Torus *)current_blocks->ptr, d_pbs_indexes_out,
|
||||
|
||||
@@ -59,7 +59,7 @@ void rerand_inplace(
|
||||
execute_keyswitch_async<Torus>(
|
||||
streams.get_ith(0), ksed_zero_lwes, lwe_trivial_indexes, zero_lwes,
|
||||
lwe_trivial_indexes, ksk, input_dimension, output_dimension, ks_base_log,
|
||||
ks_level, num_lwes);
|
||||
ks_level, num_lwes, true, mem_ptr->ks_tmp_buf_vec);
|
||||
|
||||
// Add ks output to ct
|
||||
// Check sizes
|
||||
|
||||
@@ -102,14 +102,14 @@ template <typename Torus> uint64_t get_shared_mem_size_tgemm() {
|
||||
// This code is adapted by generalizing the 1d block-tiling
|
||||
// kernel from https://github.com/siboehm/SGEMM_CUDA
|
||||
// to any matrix dimension
|
||||
template <typename Torus>
|
||||
template <typename Torus, int BLOCK_SIZE, int THREADS>
|
||||
__global__ void tgemm(uint M, uint N, uint K, const Torus *A, const Torus *B,
|
||||
uint stride_B, Torus *C, uint stride_C) {
|
||||
|
||||
const int BM = BLOCK_SIZE_GEMM;
|
||||
const int BN = BLOCK_SIZE_GEMM;
|
||||
const int BK = THREADS_GEMM;
|
||||
const int TM = THREADS_GEMM;
|
||||
const int BM = BLOCK_SIZE;
|
||||
const int BN = BLOCK_SIZE;
|
||||
const int BK = THREADS;
|
||||
const int TM = THREADS;
|
||||
|
||||
const uint cRow = blockIdx.y;
|
||||
const uint cCol = blockIdx.x;
|
||||
@@ -192,4 +192,103 @@ __global__ void tgemm(uint M, uint N, uint K, const Torus *A, const Torus *B,
|
||||
}
|
||||
}
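A note on the design change just above: templating the block geometry lets the two keyswitch flavours share one GEMM kernel while picking different tile shapes at compile time. Both instantiations appear elsewhere in this diff:

// tgemm<Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>  -- LWE keyswitch path (36 / 6)
// tgemm<Torus, BLOCK_SIZE_GEMM, THREADS_GEMM>        -- packing keyswitch path
// so the GEMM tile size can be tuned per use case without duplicating the kernel body.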
|
||||
|
||||
// Multiply matrices A, B of size (M, K), (K, N) respectively
|
||||
// with K as the inner dimension.
|
||||
//
|
||||
// A block of threads processeds blocks of size (BLOCK_SIZE_GEMM,
|
||||
// BLOCK_SIZE_GEMM) splitting them in multiple tiles: (BLOCK_SIZE_GEMM,
|
||||
// THREADS_GEMM)-shaped tiles of values from A, and a (THREADS_GEMM,
|
||||
// BLOCK_SIZE_GEMM)-shaped tiles of values from B.
|
||||
//
|
||||
// This code is adapted by generalizing the 1d block-tiling
|
||||
// kernel from https://github.com/siboehm/SGEMM_CUDA
|
||||
// to any matrix dimension
|
||||
template <typename Torus, int BLOCK_SIZE, int THREADS>
|
||||
__global__ void tgemm_with_indices(uint M, uint N, uint K, const Torus *A,
|
||||
const Torus *B, uint stride_B, Torus *C,
|
||||
uint stride_C,
|
||||
const Torus *__restrict__ C_indices) {
|
||||
|
||||
const int BM = BLOCK_SIZE;
|
||||
const int BN = BLOCK_SIZE;
|
||||
const int BK = THREADS;
|
||||
const int TM = THREADS;
|
||||
|
||||
const uint cRow = blockIdx.y;
|
||||
const uint cCol = blockIdx.x;
|
||||
|
||||
const int threadCol = threadIdx.x % BN;
|
||||
const int threadRow = threadIdx.x / BN;
|
||||
|
||||
// Allocate space for the current block tile in shared memory
|
||||
__shared__ Torus As[BM * BK];
|
||||
__shared__ Torus Bs[BK * BN];
|
||||
|
||||
// Initialize the pointers to the input blocks from A, B
|
||||
// Tiles from these blocks are loaded to shared memory
|
||||
|
||||
A += cRow * BM * K;
|
||||
B += cCol * BN;
|
||||
|
||||
// Each thread will handle multiple sub-blocks
|
||||
const uint innerColA = threadIdx.x % BK;
|
||||
const uint innerRowA = threadIdx.x / BK;
|
||||
const uint innerColB = threadIdx.x % BN;
|
||||
const uint innerRowB = threadIdx.x / BN;
|
||||
|
||||
// allocate thread-local cache for results in registerfile
|
||||
Torus threadResults[TM] = {0};
|
||||
|
||||
auto row_A = cRow * BM + innerRowA;
|
||||
auto col_B = cCol * BN + innerColB;
|
||||
|
||||
// For each thread, loop over block tiles
|
||||
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
|
||||
auto col_A = bkIdx + innerColA;
|
||||
auto row_B = bkIdx + innerRowB;
|
||||
|
||||
if (row_A < M && col_A < K) {
|
||||
As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA];
|
||||
} else {
|
||||
As[innerRowA * BK + innerColA] = 0;
|
||||
}
|
||||
|
||||
if (col_B < N && row_B < K) {
|
||||
Bs[innerRowB * BN + innerColB] = B[innerRowB * stride_B + innerColB];
|
||||
} else {
|
||||
Bs[innerRowB * BN + innerColB] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Advance blocktile for the next iteration of this loop
|
||||
A += BK;
|
||||
B += BK * stride_B;
|
||||
|
||||
// calculate per-thread results
|
||||
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
|
||||
// we make the dotproduct loop the outside loop, which facilitates
|
||||
// reuse of the Bs entry, which we can cache in a tmp var.
|
||||
Torus tmp = Bs[dotIdx * BN + threadCol];
|
||||
for (uint resIdx = 0; resIdx < TM; ++resIdx) {
|
||||
threadResults[resIdx] +=
|
||||
As[(threadRow * TM + resIdx) * BK + dotIdx] * tmp;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// write out the results
|
||||
for (uint resIdx = 0; resIdx < TM; ++resIdx) {
|
||||
int outRow = cRow * BM + threadRow * TM + resIdx;
|
||||
int outCol = cCol * BN + threadCol;
|
||||
|
||||
if (outRow >= M)
|
||||
continue;
|
||||
if (outCol >= N)
|
||||
continue;
|
||||
|
||||
C[C_indices[outRow] * stride_C + outCol] += threadResults[resIdx];
|
||||
}
|
||||
}
|
||||
|
||||
#endif // CUDA_MULT_H
|
||||
|
||||
@@ -89,9 +89,10 @@ __host__ void host_wrapping_polynomial_mul_one_to_many(
|
||||
PANIC("GEMM kernel error: shared memory required might be too large");
|
||||
|
||||
// Write the output with a stride of the GLWE total number of values
|
||||
tgemm<Torus><<<grid_gemm, threads_gemm, sharedMemSize, stream>>>(
|
||||
n_rhs, polynomial_size, polynomial_size, poly_rhs, (Torus *)circulant,
|
||||
polynomial_size, result, (polynomial_size * (glwe_dimension + 1)));
|
||||
tgemm<Torus, BLOCK_SIZE_GEMM, THREADS_GEMM>
|
||||
<<<grid_gemm, threads_gemm, sharedMemSize, stream>>>(
|
||||
n_rhs, polynomial_size, polynomial_size, poly_rhs, (Torus *)circulant,
|
||||
polynomial_size, result, (polynomial_size * (glwe_dimension + 1)));
|
||||
check_cuda_error(cudaGetLastError());
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define HELPER_CUH
|
||||
|
||||
#include <cstdint>
|
||||
#include <sstream>
|
||||
#include <stdio.h>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -64,4 +65,71 @@ void print_body(const char *name, T *src, int n, int lwe_dimension, T delta) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
void print_2d_csv_to_file(const std::vector<Torus> &v, int col_size,
|
||||
const char *fname) {
|
||||
FILE *fp = fopen(fname, "wt");
|
||||
for (int i = 0; i < v.size() / col_size; ++i) {
|
||||
for (int j = 0; j < col_size; ++j) {
|
||||
fprintf(fp, "%lu%c", v[i * col_size + j],
|
||||
(j == col_size - 1) ? '\n' : ',');
|
||||
}
|
||||
}
|
||||
fclose(fp);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__host__ void dump_2d_gpu_to_file(const Torus *ptr, int row_size, int col_size,
|
||||
const char *fname_prefix, int rand_prefix,
|
||||
cudaStream_t stream, uint32_t gpu_index) {
|
||||
// #ifndef NDEBUG
|
||||
std::vector<Torus> buf_cpu(row_size * col_size);
|
||||
|
||||
char fname[4096];
|
||||
snprintf(fname, 4096, "%s_%d_%d_%d.csv", fname_prefix, row_size, col_size,
|
||||
rand_prefix);
|
||||
|
||||
cuda_memcpy_async_to_cpu((void *)&buf_cpu[0], ptr,
|
||||
buf_cpu.size() * sizeof(Torus), stream, gpu_index);
|
||||
cuda_synchronize_device(gpu_index);
|
||||
print_2d_csv_to_file(buf_cpu, col_size, fname);
|
||||
// #endif
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__host__ void compare_2d_arrays(const Torus *ptr1, const Torus *ptr2,
|
||||
int row_size, int col_size, cudaStream_t stream,
|
||||
uint32_t gpu_index) {
|
||||
// #ifndef NDEBUG
|
||||
std::vector<Torus> buf_cpu1(row_size * col_size),
|
||||
buf_cpu2(row_size * col_size);
|
||||
|
||||
cuda_memcpy_async_to_cpu((void *)&buf_cpu1[0], ptr1,
|
||||
buf_cpu1.size() * sizeof(Torus), stream, gpu_index);
|
||||
cuda_memcpy_async_to_cpu((void *)&buf_cpu2[0], ptr2,
|
||||
buf_cpu2.size() * sizeof(Torus), stream, gpu_index);
|
||||
cuda_synchronize_device(gpu_index);
|
||||
|
||||
std::vector<uint32_t> non_matching_indexes;
|
||||
for (int i = 0; i < buf_cpu1.size(); ++i) {
|
||||
if (buf_cpu1[i] != buf_cpu2[i]) {
|
||||
non_matching_indexes.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (!non_matching_indexes.empty()) {
|
||||
std::stringstream ss;
|
||||
for (int i = 0; i < std::min(non_matching_indexes.size(), (size_t)10);
|
||||
++i) {
|
||||
ss << " difference at " << non_matching_indexes[i] << ": "
|
||||
<< buf_cpu1[non_matching_indexes[i]] << " vs "
|
||||
<< buf_cpu2[non_matching_indexes[i]] << " at index "
|
||||
<< non_matching_indexes[i] << "\n";
|
||||
}
|
||||
GPU_ASSERT(non_matching_indexes.empty(),
|
||||
"Correctness error for matrices %d x %d: \n%s", row_size,
|
||||
col_size, ss.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -79,7 +79,8 @@ __host__ void host_expand_without_verification(
|
||||
streams.get_ith(0), ksed_small_to_big_expanded_lwes,
|
||||
lwe_trivial_indexes_vec[0], expanded_lwes, lwe_trivial_indexes_vec[0],
|
||||
casting_keys, casting_input_dimension, casting_output_dimension,
|
||||
casting_ks_base_log, casting_ks_level, num_lwes);
|
||||
casting_ks_base_log, casting_ks_level, num_lwes,
|
||||
lut->using_trivial_lwe_indexes, lut->ks_tmp_buf_vec);
|
||||
|
||||
// In this case, the next keyswitch will use the compute ksk
|
||||
ksks = compute_ksks;
|
||||
|
||||
@@ -50,8 +50,8 @@ void keyswitch_setup(cudaStream_t stream, uint32_t gpu_index, Seed *seed,
|
||||
uint64_t **d_lwe_ct_in_array,
|
||||
uint64_t **d_lwe_input_indexes,
|
||||
uint64_t **d_lwe_ct_out_array,
|
||||
uint64_t **d_lwe_output_indexes, int input_lwe_dimension,
|
||||
int output_lwe_dimension,
|
||||
uint64_t **d_lwe_output_indexes, void **ks_tmp_buffer,
|
||||
int input_lwe_dimension, int output_lwe_dimension,
|
||||
DynamicDistribution lwe_noise_distribution,
|
||||
int ksk_base_log, int ksk_level, int message_modulus,
|
||||
int carry_modulus, int *payload_modulus, uint64_t *delta,
|
||||
@@ -62,7 +62,7 @@ void keyswitch_teardown(cudaStream_t stream, uint32_t gpu_index,
|
||||
uint64_t *d_lwe_ct_in_array,
|
||||
uint64_t *lwe_input_indexes,
|
||||
uint64_t *d_lwe_ct_out_array,
|
||||
uint64_t *lwe_output_indexes);
|
||||
uint64_t *lwe_output_indexes, void **ks_tmp_buffer);
|
||||
|
||||
void fft_setup(cudaStream_t stream, uint32_t gpu_index, double **poly1,
|
||||
double **poly2, double2 **h_cpoly1, double2 **h_cpoly2,
|
||||
|
||||
@@ -260,7 +260,7 @@ void keyswitch_setup(
|
||||
cudaStream_t stream, uint32_t gpu_index, Seed *seed, uint64_t **lwe_sk_in_array,
|
||||
uint64_t **lwe_sk_out_array, uint64_t **d_ksk_array, uint64_t **plaintexts,
|
||||
uint64_t **d_lwe_ct_in_array, uint64_t **d_lwe_input_indexes,
|
||||
uint64_t **d_lwe_ct_out_array, uint64_t **d_lwe_output_indexes,
|
||||
uint64_t **d_lwe_ct_out_array, uint64_t **d_lwe_output_indexes, void** ks_tmp_buffer,
|
||||
int input_lwe_dimension, int output_lwe_dimension,
|
||||
DynamicDistribution lwe_noise_distribution, int ksk_base_log, int ksk_level,
|
||||
int message_modulus, int carry_modulus, int *payload_modulus,
|
||||
@@ -295,6 +295,11 @@ void keyswitch_setup(
|
||||
uint64_t *lwe_ct_in_array =
|
||||
(uint64_t *)malloc((input_lwe_dimension + 1) * number_of_inputs *
|
||||
repetitions * samples * sizeof(uint64_t));
|
||||
|
||||
scratch_cuda_keyswitch_gemm_64(
|
||||
(void*)stream, gpu_index, ks_tmp_buffer,
|
||||
input_lwe_dimension, output_lwe_dimension, number_of_inputs * repetitions * samples, true);
|
||||
|
||||
// Create the input/output ciphertexts
|
||||
for (int r = 0; r < repetitions; r++) {
|
||||
uint64_t *lwe_sk_in =
|
||||
@@ -343,7 +348,7 @@ void keyswitch_teardown(cudaStream_t stream, uint32_t gpu_index, uint64_t *lwe_s
|
||||
uint64_t *plaintexts, uint64_t *d_lwe_ct_in_array,
|
||||
uint64_t *d_lwe_input_indexes,
|
||||
uint64_t *d_lwe_ct_out_array,
|
||||
uint64_t *d_lwe_output_indexes) {
|
||||
uint64_t *d_lwe_output_indexes, void** ks_tmp_buffer) {
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
|
||||
free(lwe_sk_in_array);
|
||||
@@ -355,8 +360,14 @@ void keyswitch_teardown(cudaStream_t stream, uint32_t gpu_index, uint64_t *lwe_s
|
||||
cuda_drop_async(d_lwe_ct_out_array, stream, gpu_index);
|
||||
cuda_drop_async(d_lwe_input_indexes, stream, gpu_index);
|
||||
cuda_drop_async(d_lwe_output_indexes, stream, gpu_index);
|
||||
|
||||
cleanup_cuda_keyswitch_gemm_64((void*)stream, gpu_index,
|
||||
ks_tmp_buffer,
|
||||
true);
|
||||
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
cuda_destroy_stream(stream, gpu_index);
|
||||
|
||||
}
|
||||
|
||||
void fft_setup(cudaStream_t stream, uint32_t gpu_index, double **_poly1, double **_poly2,
|
||||
|
||||
@@ -45,6 +45,7 @@ protected:
uint64_t *lwe_out_ct;
uint64_t *lwe_input_indexes;
uint64_t *lwe_output_indexes;
void *ks_tmp_buffer;

// Data stays at gpu 0
uint32_t gpu_index = 0;
@@ -83,7 +84,7 @@ public:
keyswitch_setup(streams[0], gpu_index, &seed, &lwe_sk_in_array,
&lwe_sk_out_array, &d_ksk_array, &plaintexts,
&d_lwe_ct_in_array, &lwe_input_indexes, &d_lwe_ct_out_array,
&lwe_output_indexes, input_lwe_dimension,
&lwe_output_indexes, &ks_tmp_buffer, input_lwe_dimension,
output_lwe_dimension, noise_distribution, ksk_base_log,
ksk_level, message_modulus, carry_modulus, &payload_modulus,
&delta, number_of_inputs, REPETITIONS, SAMPLES);
@@ -94,7 +95,7 @@ public:
keyswitch_teardown(streams[0], gpu_index, lwe_sk_in_array, lwe_sk_out_array,
d_ksk_array, plaintexts, d_lwe_ct_in_array,
lwe_input_indexes, d_lwe_ct_out_array,
lwe_output_indexes);
lwe_output_indexes, &ks_tmp_buffer);
if (active_gpu_count > 1) {
for (uint gpu_i = 1; gpu_i < active_gpu_count; gpu_i++) {
cuda_destroy_stream(streams[gpu_i], gpu_i);
@@ -136,10 +137,11 @@ TEST_P(KeyswitchMultiGPUTestPrimitives_u64, keyswitch) {
d_lwe_ct_out_array + (ptrdiff_t)(output_lwe_start_index);

// Execute keyswitch
cuda_keyswitch_lwe_ciphertext_vector_64(
cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
streams[gpu_i], gpu_i, d_lwe_ct_out, lwe_output_indexes,
d_lwe_ct_in_slice, lwe_input_indexes, d_ksk, input_lwe_dimension,
output_lwe_dimension, ksk_base_log, ksk_level, num_inputs);
output_lwe_dimension, ksk_base_log, ksk_level, num_inputs,
ks_tmp_buffer, false);
}
for (uint gpu_i = 0; gpu_i < active_gpu_count; gpu_i++) {
cuda_synchronize_stream(streams[gpu_i], gpu_i);
@@ -195,6 +197,7 @@ protected:
uint64_t *lwe_out_ct;
uint64_t *lwe_input_indexes;
uint64_t *lwe_output_indexes;
void *ks_tmp_buffer;

public:
// Test arithmetic functions
@@ -217,7 +220,7 @@ public:
keyswitch_setup(stream, gpu_index, &seed, &lwe_sk_in_array,
&lwe_sk_out_array, &d_ksk_array, &plaintexts,
&d_lwe_ct_in_array, &lwe_input_indexes, &d_lwe_ct_out_array,
&lwe_output_indexes, input_lwe_dimension,
&lwe_output_indexes, &ks_tmp_buffer, input_lwe_dimension,
output_lwe_dimension, noise_distribution, ksk_base_log,
ksk_level, message_modulus, carry_modulus, &payload_modulus,
&delta, number_of_inputs, REPETITIONS, SAMPLES);
@@ -227,7 +230,7 @@ public:
keyswitch_teardown(stream, gpu_index, lwe_sk_in_array, lwe_sk_out_array,
d_ksk_array, plaintexts, d_lwe_ct_in_array,
lwe_input_indexes, d_lwe_ct_out_array,
lwe_output_indexes);
lwe_output_indexes, &ks_tmp_buffer);
}
};

@@ -245,11 +248,12 @@ TEST_P(KeyswitchTestPrimitives_u64, keyswitch) {
(ptrdiff_t)((r * SAMPLES * number_of_inputs + s * number_of_inputs) *
(input_lwe_dimension + 1));
// Execute keyswitch
cuda_keyswitch_lwe_ciphertext_vector_64(
cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
stream, gpu_index, (void *)d_lwe_ct_out_array,
(void *)lwe_output_indexes, (void *)d_lwe_ct_in,
(void *)lwe_input_indexes, (void *)d_ksk, input_lwe_dimension,
output_lwe_dimension, ksk_base_log, ksk_level, number_of_inputs);
output_lwe_dimension, ksk_base_log, ksk_level, number_of_inputs,
ks_tmp_buffer, false);

// Copy result back
cuda_memcpy_async_to_cpu(lwe_out_ct, d_lwe_ct_out_array,

@@ -2519,6 +2519,24 @@ unsafe extern "C" {
num_samples: u32,
);
}
unsafe extern "C" {
pub fn cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
stream: *mut ffi::c_void,
gpu_index: u32,
lwe_array_out: *mut ffi::c_void,
lwe_output_indexes: *const ffi::c_void,
lwe_array_in: *const ffi::c_void,
lwe_input_indexes: *const ffi::c_void,
ksk: *const ffi::c_void,
lwe_dimension_in: u32,
lwe_dimension_out: u32,
base_log: u32,
level_count: u32,
num_samples: u32,
ks_tmp_buffer: *const ffi::c_void,
uses_trivial_indexes: bool,
);
}
unsafe extern "C" {
pub fn cuda_keyswitch_lwe_ciphertext_vector_64(
stream: *mut ffi::c_void,
@@ -2547,6 +2565,25 @@ unsafe extern "C" {
allocate_gpu_memory: bool,
) -> u64;
}
unsafe extern "C" {
pub fn scratch_cuda_keyswitch_gemm_64(
stream: *mut ffi::c_void,
gpu_index: u32,
ks_tmp_memory: *mut *mut ffi::c_void,
lwe_dimension_in: u32,
lwe_dimension_out: u32,
num_lwes: u32,
allocate_gpu_memory: bool,
) -> u64;
}
unsafe extern "C" {
pub fn cleanup_cuda_keyswitch_gemm_64(
stream: *mut ffi::c_void,
gpu_index: u32,
ks_tmp_memory: *mut *mut ffi::c_void,
allocate_gpu_memory: bool,
);
}
unsafe extern "C" {
pub fn cuda_packing_keyswitch_lwe_list_to_glwe_64(
stream: *mut ffi::c_void,

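Taken together, the three new bindings follow a scratch/run/cleanup lifecycle, mirroring what keyswitch_setup/keyswitch_teardown above and the Rust wrapper further below do. A minimal sketch of that sequence through the raw FFI, assuming the stream, key and ciphertext device pointers already exist; the gemm_ks_once helper name and the trivial-index choice are illustrative only, not part of this patch:

use tfhe_cuda_backend::bindings::{
    cleanup_cuda_keyswitch_gemm_64, cuda_keyswitch_gemm_lwe_ciphertext_vector_64,
    scratch_cuda_keyswitch_gemm_64,
};
use tfhe_cuda_backend::ffi;

unsafe fn gemm_ks_once(
    stream: *mut ffi::c_void,
    gpu_index: u32,
    lwe_array_out: *mut ffi::c_void,
    lwe_output_indexes: *const ffi::c_void,
    lwe_array_in: *const ffi::c_void,
    lwe_input_indexes: *const ffi::c_void,
    ksk: *const ffi::c_void,
    lwe_dimension_in: u32,
    lwe_dimension_out: u32,
    base_log: u32,
    level_count: u32,
    num_lwes: u32,
) {
    // Allocate the temporary buffers the GEMM keyswitch needs for this batch size.
    let mut ks_tmp_buffer: *mut ffi::c_void = std::ptr::null_mut();
    scratch_cuda_keyswitch_gemm_64(
        stream, gpu_index, &mut ks_tmp_buffer,
        lwe_dimension_in, lwe_dimension_out, num_lwes, true,
    );
    // Run the batched keyswitch; the trailing `true` declares the index arrays trivial (0..num_lwes).
    cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
        stream, gpu_index, lwe_array_out, lwe_output_indexes, lwe_array_in,
        lwe_input_indexes, ksk, lwe_dimension_in, lwe_dimension_out,
        base_log, level_count, num_lwes, ks_tmp_buffer, true,
    );
    // Release the temporary buffers on the same stream once the keyswitch is queued.
    cleanup_cuda_keyswitch_gemm_64(stream, gpu_index, &mut ks_tmp_buffer, true);
}
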
@@ -386,7 +386,7 @@ mod cuda {
.keyswitch_key(ksk_big_to_small)
.build();

let bench_id;
let mut bench_id;

match get_bench_type() {
BenchmarkType::Latency => {
@@ -423,120 +423,163 @@ mod cuda {
&mut output_ct_gpu,
&cuda_indexes.d_input,
&cuda_indexes.d_output,
true,
&streams,
false,
);

black_box(&mut ct_gpu);
})
});
}

let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
write_to_json(
&bench_id,
*params,
name,
"ks",
&OperatorType::Atomic,
bit_size,
vec![bit_size],
);
}
BenchmarkType::Throughput => {
let gpu_keys_vec = cuda_local_keys_core(&cpu_keys, None);
let gpu_count = get_number_of_gpus() as usize;

bench_id = format!("{bench_name}::throughput::{name}");
let blocks: usize = 1;
let elements = throughput_num_threads(blocks, 1);
let elements_per_stream = elements as usize / gpu_count;
bench_group.throughput(Throughput::Elements(elements));
bench_group.sample_size(50);
bench_group.bench_function(&bench_id, |b| {
let setup_encrypted_values = || {
let local_streams = cuda_local_streams_core();

let plaintext_list = PlaintextList::new(
Scalar::ZERO,
PlaintextCount(elements_per_stream),
for uses_gemm_ks in [false, true] {
for uses_simple_indices in [false, true] {
let indices_str = if uses_simple_indices {
"simple"
} else {
"complex"
};
let gemm_str = if uses_gemm_ks { "gemm" } else { "classical" };
bench_id = format!(
"{bench_name}::throughput::{gemm_str}::{indices_str}_indices::{name}",
);

let input_cts = (0..gpu_count)
.map(|i| {
let mut input_ct_list = LweCiphertextList::new(
let blocks: usize = 256;
let elements = gpu_count * blocks;
let elements_per_stream = elements / gpu_count;
bench_group.throughput(Throughput::Elements(elements as u64));
bench_group.sample_size(50);
bench_group.bench_function(&bench_id, |b| {
let setup_encrypted_values = || {
let local_streams = cuda_local_streams_core();

let plaintext_list = PlaintextList::new(
Scalar::ZERO,
big_lwe_sk.lwe_dimension().to_lwe_size(),
LweCiphertextCount(elements_per_stream),
params.ciphertext_modulus.unwrap(),
PlaintextCount(elements_per_stream),
);
encrypt_lwe_ciphertext_list(
&big_lwe_sk,
&mut input_ct_list,
&plaintext_list,
params.lwe_noise_distribution.unwrap(),
&mut encryption_generator,
);
let input_ks_list = LweCiphertextList::from_container(
input_ct_list.into_container(),
big_lwe_sk.lwe_dimension().to_lwe_size(),
params.ciphertext_modulus.unwrap(),
);
CudaLweCiphertextList::from_lwe_ciphertext_list(
&input_ks_list,
&local_streams[i],
)
})
.collect::<Vec<_>>();

let output_cts = (0..gpu_count)
.map(|i| {
let output_ct_list = LweCiphertextList::new(
Scalar::ZERO,
lwe_sk.lwe_dimension().to_lwe_size(),
LweCiphertextCount(elements_per_stream),
params.ciphertext_modulus.unwrap(),
);
CudaLweCiphertextList::from_lwe_ciphertext_list(
&output_ct_list,
&local_streams[i],
)
})
.collect::<Vec<_>>();
let input_cts = (0..gpu_count)
.map(|i| {
let mut input_ct_list = LweCiphertextList::new(
Scalar::ZERO,
big_lwe_sk.lwe_dimension().to_lwe_size(),
LweCiphertextCount(elements_per_stream),
params.ciphertext_modulus.unwrap(),
);
encrypt_lwe_ciphertext_list(
&big_lwe_sk,
&mut input_ct_list,
&plaintext_list,
params.lwe_noise_distribution.unwrap(),
&mut encryption_generator,
);
let input_ks_list = LweCiphertextList::from_container(
input_ct_list.into_container(),
big_lwe_sk.lwe_dimension().to_lwe_size(),
params.ciphertext_modulus.unwrap(),
);
CudaLweCiphertextList::from_lwe_ciphertext_list(
&input_ks_list,
&local_streams[i],
)
})
.collect::<Vec<_>>();

let h_indexes = (0..(elements / gpu_count as u64))
.map(CastFrom::cast_from)
.collect::<Vec<_>>();
let cuda_indexes_vec = (0..gpu_count)
.map(|i| CudaIndexes::new(&h_indexes, &local_streams[i], 0))
.collect::<Vec<_>>();
local_streams.iter().for_each(|stream| stream.synchronize());
let output_cts = (0..gpu_count)
.map(|i| {
let output_ct_list = LweCiphertextList::new(
Scalar::ZERO,
lwe_sk.lwe_dimension().to_lwe_size(),
LweCiphertextCount(elements_per_stream),
params.ciphertext_modulus.unwrap(),
);
CudaLweCiphertextList::from_lwe_ciphertext_list(
&output_ct_list,
&local_streams[i],
)
})
.collect::<Vec<_>>();

(input_cts, output_cts, cuda_indexes_vec, local_streams)
};
let indexes_range: Vec<u64> = if uses_simple_indices {
(0..(elements / gpu_count) as u64).collect()
} else {
(0..(elements / gpu_count) as u64).rev().collect()
};
let h_indexes = indexes_range
.iter()
.map(|v| CastFrom::cast_from(*v))
.collect::<Vec<_>>();

b.iter_batched(
setup_encrypted_values,
|(input_cts, mut output_cts, cuda_indexes_vec, local_streams)| {
(0..gpu_count)
.into_par_iter()
.zip(input_cts.par_iter())
.zip(output_cts.par_iter_mut())
.zip(local_streams.par_iter())
.for_each(|(((i, input_ct), output_ct), local_stream)| {
cuda_keyswitch_lwe_ciphertext(
gpu_keys_vec[i].ksk.as_ref().unwrap(),
input_ct,
output_ct,
&cuda_indexes_vec[i].d_input,
&cuda_indexes_vec[i].d_output,
local_stream,
);
})
},
criterion::BatchSize::SmallInput,
)
});
let cuda_indexes_vec = (0..gpu_count)
.map(|i| CudaIndexes::new(&h_indexes, &local_streams[i], 0))
.collect::<Vec<_>>();
local_streams.iter().for_each(|stream| stream.synchronize());

(input_cts, output_cts, cuda_indexes_vec, local_streams)
};

b.iter_batched(
setup_encrypted_values,
|(
input_cts,
mut output_cts,
cuda_indexes_vec,
local_streams,
)| {
(0..gpu_count)
.into_par_iter()
.zip(input_cts.par_iter())
.zip(output_cts.par_iter_mut())
.zip(local_streams.par_iter())
.for_each(
|(((i, input_ct), output_ct), local_stream)| {
cuda_keyswitch_lwe_ciphertext(
gpu_keys_vec[i].ksk.as_ref().unwrap(),
input_ct,
output_ct,
&cuda_indexes_vec[i].d_input,
&cuda_indexes_vec[i].d_output,
uses_simple_indices,
local_stream,
uses_gemm_ks,
);
},
)
},
criterion::BatchSize::SmallInput,
)
});

let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
write_to_json(
&bench_id,
*params,
name,
"ks",
&OperatorType::Atomic,
bit_size,
vec![bit_size],
);
}
}
}
};

let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
write_to_json(
&bench_id,
*params,
name,
"ks",
&OperatorType::Atomic,
bit_size,
vec![bit_size],
);
}
}

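For reference, the two new loops expand every parameter set into four throughput benchmark IDs, one per combination of keyswitch flavour and index pattern:

{bench_name}::throughput::classical::simple_indices::{name}
{bench_name}::throughput::classical::complex_indices::{name}
{bench_name}::throughput::gemm::simple_indices::{name}
{bench_name}::throughput::gemm::complex_indices::{name}
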
@@ -630,7 +630,9 @@ mod cuda {
&mut output_ks_ct_gpu,
&cuda_indexes.d_input,
&cuda_indexes.d_output,
true,
&streams,
false,
);
cuda_programmable_bootstrap_lwe_ciphertext(
&output_ks_ct_gpu,
@@ -782,7 +784,9 @@ mod cuda {
output_ks_ct,
&cuda_indexes_vec[i].d_input,
&cuda_indexes_vec[i].d_output,
true,
local_stream,
false,
);
cuda_programmable_bootstrap_lwe_ciphertext(
output_ks_ct,
@@ -937,7 +941,9 @@ mod cuda {
&mut output_ks_ct_gpu,
&cuda_indexes.d_input,
&cuda_indexes.d_output,
true,
&streams,
false,
);
cuda_multi_bit_programmable_bootstrap_lwe_ciphertext(
&output_ks_ct_gpu,
@@ -1088,7 +1094,9 @@ mod cuda {
output_ks_ct,
&cuda_indexes_vec[i].d_input,
&cuda_indexes_vec[i].d_output,
true,
local_stream,
false,
);
cuda_multi_bit_programmable_bootstrap_lwe_ciphertext(
output_ks_ct,

@@ -1,20 +1,28 @@
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
use crate::core_crypto::gpu::vec::CudaVec;
use crate::core_crypto::gpu::{keyswitch_async, CudaStreams};
use crate::core_crypto::gpu::{
keyswitch_async, keyswitch_async_gemm, scratch_cuda_keyswitch_gemm_64, CudaStreams,
};
use crate::core_crypto::prelude::UnsignedInteger;
use std::cmp::min;
use tfhe_cuda_backend::bindings::cleanup_cuda_keyswitch_gemm_64;
use tfhe_cuda_backend::ffi;

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
/// be dropped until stream is synchronised
#[allow(clippy::too_many_arguments)]
pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
lwe_keyswitch_key: &CudaLweKeyswitchKey<Scalar>,
input_lwe_ciphertext: &CudaLweCiphertextList<Scalar>,
output_lwe_ciphertext: &mut CudaLweCiphertextList<Scalar>,
input_indexes: &CudaVec<Scalar>,
output_indexes: &CudaVec<Scalar>,
uses_trivial_indices: bool,
streams: &CudaStreams,
use_gemm_ks: bool,
) where
Scalar: UnsignedInteger,
{
@@ -70,28 +78,76 @@ pub unsafe fn cuda_keyswitch_lwe_ciphertext_async<Scalar>(
output_indexes.gpu_index(0).get(),
);

keyswitch_async(
streams,
&mut output_lwe_ciphertext.0.d_vec,
output_indexes,
&input_lwe_ciphertext.0.d_vec,
input_indexes,
lwe_keyswitch_key.input_key_lwe_size().to_lwe_dimension(),
lwe_keyswitch_key.output_key_lwe_size().to_lwe_dimension(),
&lwe_keyswitch_key.d_vec,
lwe_keyswitch_key.decomposition_base_log(),
lwe_keyswitch_key.decomposition_level_count(),
input_lwe_ciphertext.lwe_ciphertext_count().0 as u32,
let mut ks_tmp_buffer: *mut ffi::c_void = std::ptr::null_mut();

let num_lwes_to_ks = min(
input_indexes.len,
input_lwe_ciphertext.lwe_ciphertext_count().0,
);

assert_eq!(
input_indexes.len, output_indexes.len,
"The number of input and output indexes must be the same for LWE keyswitch"
);

if use_gemm_ks {
cuda_scratch_keyswitch_lwe_ciphertext_async::<Scalar>(
streams,
std::ptr::addr_of_mut!(ks_tmp_buffer),
lwe_keyswitch_key.input_key_lwe_size().to_lwe_dimension().0 as u32,
lwe_keyswitch_key.output_key_lwe_size().to_lwe_dimension().0 as u32,
num_lwes_to_ks as u32,
true,
);

keyswitch_async_gemm(
streams,
&mut output_lwe_ciphertext.0.d_vec,
output_indexes,
&input_lwe_ciphertext.0.d_vec,
input_indexes,
lwe_keyswitch_key.input_key_lwe_size().to_lwe_dimension(),
lwe_keyswitch_key.output_key_lwe_size().to_lwe_dimension(),
&lwe_keyswitch_key.d_vec,
lwe_keyswitch_key.decomposition_base_log(),
lwe_keyswitch_key.decomposition_level_count(),
num_lwes_to_ks as u32,
ks_tmp_buffer,
uses_trivial_indices,
);

cleanup_cuda_keyswitch_async::<Scalar>(
streams,
std::ptr::addr_of_mut!(ks_tmp_buffer),
true,
);
} else {
keyswitch_async(
streams,
&mut output_lwe_ciphertext.0.d_vec,
output_indexes,
&input_lwe_ciphertext.0.d_vec,
input_indexes,
lwe_keyswitch_key.input_key_lwe_size().to_lwe_dimension(),
lwe_keyswitch_key.output_key_lwe_size().to_lwe_dimension(),
&lwe_keyswitch_key.d_vec,
lwe_keyswitch_key.decomposition_base_log(),
lwe_keyswitch_key.decomposition_level_count(),
num_lwes_to_ks as u32,
);
}
}

#[allow(clippy::too_many_arguments)]
pub fn cuda_keyswitch_lwe_ciphertext<Scalar>(
lwe_keyswitch_key: &CudaLweKeyswitchKey<Scalar>,
input_lwe_ciphertext: &CudaLweCiphertextList<Scalar>,
output_lwe_ciphertext: &mut CudaLweCiphertextList<Scalar>,
input_indexes: &CudaVec<Scalar>,
output_indexes: &CudaVec<Scalar>,
uses_trivial_indices: bool,
streams: &CudaStreams,
use_gemm_ks: bool,
) where
Scalar: UnsignedInteger,
{
@@ -102,8 +158,52 @@ pub fn cuda_keyswitch_lwe_ciphertext<Scalar>(
output_lwe_ciphertext,
input_indexes,
output_indexes,
uses_trivial_indices,
streams,
use_gemm_ks,
);
}
streams.synchronize();
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
/// be dropped until stream is synchronized
pub unsafe fn cuda_scratch_keyswitch_lwe_ciphertext_async<Scalar>(
streams: &CudaStreams,
ks_tmp_buffer: *mut *mut ffi::c_void,
lwe_dimension_in: u32,
lwe_dimension_out: u32,
num_lwes: u32,
allocate_gpu_memory: bool,
) where
Scalar: UnsignedInteger,
{
scratch_cuda_keyswitch_gemm_64(
streams.ptr[0],
streams.gpu_indexes[0].get(),
ks_tmp_buffer,
lwe_dimension_in,
lwe_dimension_out,
num_lwes,
allocate_gpu_memory,
);
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
/// be dropped until stream is synchronized
pub unsafe fn cleanup_cuda_keyswitch_async<Scalar>(
streams: &CudaStreams,
ks_tmp_buffer: *mut *mut ffi::c_void,
allocate_gpu_memory: bool,
) {
cleanup_cuda_keyswitch_gemm_64(
streams.ptr[0],
streams.gpu_indexes[0].get(),
ks_tmp_buffer,
allocate_gpu_memory,
);
}

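As a rough usage sketch of the extended safe entry point: the key, ciphertext lists, index vectors and stream below are placeholders, prepared the way the GPU tests in the next hunk prepare theirs; only the trailing flag changes between the two calls:

// Classical per-sample keyswitch.
cuda_keyswitch_lwe_ciphertext(
    &d_ksk_big_to_small,
    &d_input_list,
    &mut d_output_list,
    &d_input_indexes,
    &d_output_indexes,
    /* uses_trivial_indices */ true,
    &stream,
    /* use_gemm_ks */ false,
);
// GEMM batch keyswitch over the same inputs; the scratch buffer is allocated
// and released internally by the async wrapper.
cuda_keyswitch_lwe_ciphertext(
    &d_ksk_big_to_small,
    &d_input_list,
    &mut d_output_list_gemm,
    &d_input_indexes,
    &d_output_indexes,
    /* uses_trivial_indices */ true,
    &stream,
    /* use_gemm_ks */ true,
);
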
@@ -5,6 +5,8 @@ use crate::core_crypto::gpu::vec::{CudaVec, GpuIndex};
use crate::core_crypto::gpu::{cuda_keyswitch_lwe_ciphertext, CudaStreams};
use crate::core_crypto::prelude::misc::check_encrypted_content_respects_mod;
use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::thread_rng;

fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(
params: ClassicTestParams<Scalar>,
@@ -19,6 +21,57 @@ fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(
let ks_decomp_base_log = params.ks_base_log;
let ks_decomp_level_count = params.ks_level;

base_lwe_encrypt_ks_decrypt_custom_mod(
lwe_dimension,
lwe_noise_distribution,
ciphertext_modulus,
message_modulus_log,
encoding_with_padding,
glwe_dimension,
polynomial_size,
ks_decomp_base_log,
ks_decomp_level_count,
);
}

fn lwe_encrypt_ks_decrypt_custom_mod_mb<Scalar: UnsignedTorus + CastFrom<usize>>(
params: &MultiBitTestParams<Scalar>,
) {
let lwe_dimension = params.input_lwe_dimension;
let lwe_noise_distribution = DynamicDistribution::new_gaussian_from_std_dev(StandardDev(0f64));
let ciphertext_modulus = params.ciphertext_modulus;
let message_modulus_log = params.message_modulus_log;
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
let glwe_dimension = params.glwe_dimension;
let polynomial_size = params.polynomial_size;
let ks_decomp_base_log = params.decomp_base_log;
let ks_decomp_level_count = params.decomp_level_count;

base_lwe_encrypt_ks_decrypt_custom_mod(
lwe_dimension,
lwe_noise_distribution,
ciphertext_modulus,
message_modulus_log,
encoding_with_padding,
glwe_dimension,
polynomial_size,
ks_decomp_base_log,
ks_decomp_level_count,
);
}

#[allow(clippy::too_many_arguments)]
fn base_lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(
lwe_dimension: LweDimension,
lwe_noise_distribution: DynamicDistribution<Scalar>,
ciphertext_modulus: CiphertextModulus<Scalar>,
message_modulus_log: MessageModulusLog,
encoding_with_padding: Scalar,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
ks_decomp_base_log: DecompositionBaseLog,
ks_decomp_level_count: DecompositionLevelCount,
) {
let stream = CudaStreams::new_single_gpu(GpuIndex::new(0));

let mut rsc = TestResources::new();
@@ -61,65 +114,144 @@ fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(

while msg != Scalar::ZERO {
msg = msg.wrapping_sub(Scalar::ONE);
for _ in 0..NB_TESTS {
let plaintext = Plaintext(msg * delta);
for test_idx in 0..NB_TESTS {
let num_blocks = test_idx * test_idx * 3 + 1;

let ct = allocate_and_encrypt_new_lwe_ciphertext(
&big_lwe_sk,
plaintext,
lwe_noise_distribution,
let plaintext_list = (0..num_blocks)
.map(|i| (Scalar::cast_from(i) % msg_modulus) * delta)
.collect_vec();

let plaintext_list = PlaintextList::from_container(plaintext_list);

let mut input_ct_list = LweCiphertextList::new(
Scalar::ZERO,
big_lwe_sk.lwe_dimension().to_lwe_size(),
LweCiphertextCount(num_blocks),
ciphertext_modulus,
);
encrypt_lwe_ciphertext_list(
&big_lwe_sk,
&mut input_ct_list,
&plaintext_list,
lwe_noise_distribution,
&mut rsc.encryption_random_generator,
);
let input_ct_list_gpu =
CudaLweCiphertextList::from_lwe_ciphertext_list(&input_ct_list, &stream);

let output_ct_list = LweCiphertextList::new(
Scalar::ZERO,
ksk_big_to_small.output_key_lwe_dimension().to_lwe_size(),
input_ct_list.lwe_ciphertext_count(),
ksk_big_to_small.ciphertext_modulus(),
);
let mut output_ct_list_gpu =
CudaLweCiphertextList::from_lwe_ciphertext_list(&output_ct_list, &stream);
let mut output_ct_list_gpu_gemm =
CudaLweCiphertextList::from_lwe_ciphertext_list(&output_ct_list, &stream);

assert!(check_encrypted_content_respects_mod(
&ct,
&input_ct_list,
ciphertext_modulus
));

let d_ct = CudaLweCiphertextList::from_lwe_ciphertext(&ct, &stream);
let mut d_output_ct = CudaLweCiphertextList::new(
ksk_big_to_small.output_key_lwe_dimension(),
LweCiphertextCount(1),
ciphertext_modulus,
&stream,
);
let num_blocks = d_ct.0.lwe_ciphertext_count.0;
let use_trivial_indexes = test_idx % 2 == 0;

let num_blocks_to_ks = if use_trivial_indexes {
if test_idx % 4 == 0 {
num_blocks
} else {
num_blocks / 2
}
} else {
num_blocks
};
let lwe_indexes_usize = (0..num_blocks).collect_vec();
let lwe_indexes = lwe_indexes_usize
let mut lwe_indexes = lwe_indexes_usize.clone();

let mut lwe_indexes_out = lwe_indexes.clone();

if !use_trivial_indexes {
lwe_indexes.shuffle(&mut thread_rng());
lwe_indexes_out.shuffle(&mut thread_rng());
}

let h_lwe_indexes: Vec<Scalar> = lwe_indexes
.iter()
.take(num_blocks_to_ks)
.map(|&x| <usize as CastInto<Scalar>>::cast_into(x))
.collect_vec();
let h_lwe_indexes_out: Vec<Scalar> = lwe_indexes_out
.iter()
.take(num_blocks_to_ks)
.map(|&x| <usize as CastInto<Scalar>>::cast_into(x))
.collect_vec();

let mut d_input_indexes =
unsafe { CudaVec::<Scalar>::new_async(num_blocks, &stream, 0) };
unsafe { CudaVec::<Scalar>::new_async(num_blocks_to_ks, &stream, 0) };
let mut d_output_indexes =
unsafe { CudaVec::<Scalar>::new_async(num_blocks, &stream, 0) };
unsafe { d_input_indexes.copy_from_cpu_async(&lwe_indexes, &stream, 0) };
unsafe { d_output_indexes.copy_from_cpu_async(&lwe_indexes, &stream, 0) };
unsafe { CudaVec::<Scalar>::new_async(num_blocks_to_ks, &stream, 0) };
unsafe { d_input_indexes.copy_from_cpu_async(&h_lwe_indexes, &stream, 0) };
unsafe { d_output_indexes.copy_from_cpu_async(&h_lwe_indexes_out, &stream, 0) };

cuda_keyswitch_lwe_ciphertext(
&d_ksk_big_to_small,
&d_ct,
&mut d_output_ct,
&input_ct_list_gpu,
&mut output_ct_list_gpu,
&d_input_indexes,
&d_output_indexes,
use_trivial_indexes,
&stream,
false,
);

let output_ct = d_output_ct.into_lwe_ciphertext(&stream);
cuda_keyswitch_lwe_ciphertext(
&d_ksk_big_to_small,
&input_ct_list_gpu,
&mut output_ct_list_gpu_gemm,
&d_input_indexes,
&d_output_indexes,
use_trivial_indexes,
&stream,
true,
);

assert!(check_encrypted_content_respects_mod(
&output_ct,
ciphertext_modulus
));
// Fill in the expected output: only the LWEs corresponding to output indices
// will be non-zero. The test checks that the others remain 0
let mut ref_vec = vec![Scalar::ZERO; num_blocks];
for i in 0..num_blocks_to_ks {
ref_vec[lwe_indexes_out[i]] =
round_decode(*plaintext_list.get(lwe_indexes[i]).0, delta);
}

let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &output_ct);
assert_eq!(output_ct_list_gpu.lwe_ciphertext_count().0, num_blocks);
// The output has `n_blocks` LWEs but only some are actually set - those
// that correspond to output indices. We loop over all LWEs in the output buffer
let output_ct_list_cpu = output_ct_list_gpu_gemm.to_lwe_ciphertext_list(&stream);
output_ct_list_gpu
.to_lwe_ciphertext_list(&stream)
.iter()
.zip(0..num_blocks)
.for_each(|(lwe_ct_out, i)| {
assert!(check_encrypted_content_respects_mod(
&lwe_ct_out,
ciphertext_modulus
));

let decoded = round_decode(decrypted.0, delta) % msg_modulus;
let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &lwe_ct_out);

assert_eq!(msg, decoded);
let lwe_ct_out_gemm = output_ct_list_cpu.get(i);
let decrypted_gemm = decrypt_lwe_ciphertext(&lwe_sk, &lwe_ct_out_gemm);

let decoded = round_decode(decrypted.0, delta) % msg_modulus;
let decoded_gemm = round_decode(decrypted_gemm.0, delta) % msg_modulus;

assert_eq!(ref_vec[i], decoded);
assert_eq!(ref_vec[i], decoded_gemm);
});
}
}
}

create_gpu_parameterized_test!(lwe_encrypt_ks_decrypt_custom_mod);
create_gpu_multi_bit_parameterized_test!(lwe_encrypt_ks_decrypt_custom_mod_mb);

@@ -16,6 +16,7 @@ use std::any::{Any, TypeId};
use std::ffi::c_void;
use tfhe_cuda_backend::bindings::*;
use tfhe_cuda_backend::cuda_bind::*;
use tfhe_cuda_backend::ffi;

pub struct CudaStreams {
pub ptr: Vec<*mut c_void>,
@@ -477,13 +478,53 @@ pub fn get_programmable_bootstrap_multi_bit_size_on_gpu(
size_tracker
}

/// Keyswitch on a vector of LWE ciphertexts
/// Keyswitch on a vector of LWE ciphertexts using the GEMM batch KS approach
///
/// # Safety
///
/// [CudaStreams::synchronize] __must__ be called as soon as synchronization is
/// required
#[allow(clippy::too_many_arguments)]
pub unsafe fn keyswitch_async_gemm<T: UnsignedInteger>(
streams: &CudaStreams,
lwe_array_out: &mut CudaVec<T>,
lwe_out_indexes: &CudaVec<T>,
lwe_array_in: &CudaVec<T>,
lwe_in_indexes: &CudaVec<T>,
input_lwe_dimension: LweDimension,
output_lwe_dimension: LweDimension,
keyswitch_key: &CudaVec<T>,
base_log: DecompositionBaseLog,
l_gadget: DecompositionLevelCount,
num_samples: u32,
ks_tmp_buffer: *const ffi::c_void,
uses_trivial_indices: bool,
) {
cuda_keyswitch_gemm_lwe_ciphertext_vector_64(
streams.ptr[0],
streams.gpu_indexes[0].get(),
lwe_array_out.as_mut_c_ptr(0),
lwe_out_indexes.as_c_ptr(0),
lwe_array_in.as_c_ptr(0),
lwe_in_indexes.as_c_ptr(0),
keyswitch_key.as_c_ptr(0),
input_lwe_dimension.0 as u32,
output_lwe_dimension.0 as u32,
base_log.0 as u32,
l_gadget.0 as u32,
num_samples,
ks_tmp_buffer,
uses_trivial_indices,
);
}
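
The classical keyswitch_async documented just below is the better fit for small batches (up to 128 LWEs on H100, 64 on L40, 16 on a 4090, per its doc comment), so a caller choosing between the two wrappers could gate on batch size. A purely hypothetical cut-off, not something this patch ships:

// Hypothetical selection heuristic derived from the doc comment below; the
// 128-LWE figure is the H100 number and would need tuning per GPU.
let use_gemm_ks = num_samples > 128;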

/// Keyswitch on a vector of LWE ciphertexts. Better for small batches of LWEs
/// (up to 128 LWEs on H100, up to 64 on L40, up to 16 on 4090)
/// # Safety
///
/// [CudaStreams::synchronize] __must__ be called as soon as synchronization is
/// required
#[allow(clippy::too_many_arguments)]
pub unsafe fn keyswitch_async<T: UnsignedInteger>(
streams: &CudaStreams,
lwe_array_out: &mut CudaVec<T>,
@@ -512,7 +553,6 @@ pub unsafe fn keyswitch_async<T: UnsignedInteger>(
num_samples,
);
}

/// Convert keyswitch key
///
/// # Safety