Compare commits

...

1 Commit

Author        SHA1         Message                          Date
Pedro Alves   5df67f7666   refactor(gpu): mono-kernel TBC   2024-08-16 15:45:09 +00:00
2 changed files with 122 additions and 106 deletions
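
The change below folds the standalone keybundle kernel into the TBC (thread block clusters) bootstrap kernel, so a single launch both computes the keybundles and accumulates them, hence "mono-kernel". Mechanically it follows the usual kernel-fusion recipe: hoist the body of a __global__ kernel into a __device__ helper, keep the original kernel as a thin wrapper that derives the helper's indices from the grid, and call the same helper inline from the consumer kernel. A minimal sketch of that recipe, with hypothetical names rather than the code in this diff:

    // Producer logic extracted into a __device__ helper.
    __device__ void compute_step(const float *in, float *out, int idx) {
      out[idx] = 2.0f * in[idx]; // stand-in for the real per-block work
    }

    // The original entry point survives as a thin wrapper that maps
    // grid coordinates to helper arguments.
    __global__ void step_kernel(const float *in, float *out) {
      compute_step(in, out, blockIdx.x * blockDim.x + threadIdx.x);
    }

    // The consumer calls the helper inline instead of reading the
    // producer's output back from global memory.
    __global__ void fused_kernel(const float *in, float *out) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      compute_step(in, out, idx);
      __syncthreads();
      out[idx] += 1.0f;
    }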

File 1 of 2

@@ -32,43 +32,13 @@ __device__ Torus calculates_monomial_degree(const Torus *lwe_array_group,
 }
 
 template <typename Torus, class params, sharedMemDegree SMD>
-__global__ void device_multi_bit_programmable_bootstrap_keybundle(
-    const Torus *__restrict__ lwe_array_in,
-    const Torus *__restrict__ lwe_input_indexes, double2 *keybundle_array,
+__device__ void compute_multi_bit_programmable_bootstrap_keybundle(
+    const Torus *__restrict__ lwe_in,
+    double2 *__restrict__ keybundle,
     const Torus *__restrict__ bootstrapping_key, uint32_t lwe_dimension,
     uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
-    uint32_t level_count, uint32_t lwe_offset, uint32_t lwe_chunk_size,
-    uint32_t keybundle_size_per_input, int8_t *device_mem,
-    uint64_t device_memory_size_per_block) {
-  extern __shared__ int8_t sharedmem[];
-  int8_t *selected_memory = sharedmem;
-
-  if constexpr (SMD == FULLSM) {
-    selected_memory = sharedmem;
-  } else {
-    int block_index = blockIdx.x + blockIdx.y * gridDim.x +
-                      blockIdx.z * gridDim.x * gridDim.y;
-    selected_memory = &device_mem[block_index * device_memory_size_per_block];
-  }
-
-  // Ids
-  uint32_t level_id = blockIdx.z;
-  uint32_t glwe_id = blockIdx.y / (glwe_dimension + 1);
-  uint32_t poly_id = blockIdx.y % (glwe_dimension + 1);
-  uint32_t lwe_iteration = (blockIdx.x % lwe_chunk_size + lwe_offset);
-  uint32_t input_idx = blockIdx.x / lwe_chunk_size;
-
-  if (lwe_iteration < (lwe_dimension / grouping_factor)) {
-    //
-    Torus *accumulator = (Torus *)selected_memory;
-
-    const Torus *block_lwe_array_in =
-        &lwe_array_in[lwe_input_indexes[input_idx] * (lwe_dimension + 1)];
-
-    double2 *keybundle = keybundle_array +
-                         // select the input
-                         input_idx * keybundle_size_per_input;
+    uint32_t level_count, uint32_t lwe_chunk_size, uint32_t level_id,
+    uint32_t glwe_id, uint32_t poly_id, uint32_t chunk_id,
+    uint32_t lwe_iteration, Torus *accumulator) {
 
   ////////////////////////////////////////////////////////////
   // Computes all keybundles
@@ -96,7 +66,7 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
 
       // Calculates the monomial degree
       const Torus *lwe_array_group =
-          block_lwe_array_in + rev_lwe_iteration * grouping_factor;
+          lwe_in + rev_lwe_iteration * grouping_factor;
       uint32_t monomial_degree = calculates_monomial_degree<Torus, params>(
           lwe_array_group, g, grouping_factor);
@@ -124,7 +94,7 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
     synchronize_threads_in_block();
 
     // Move from local memory back to shared memory but as complex
     tid = threadIdx.x;
-    double2 *fft = (double2 *)selected_memory;
+    double2 *fft = (double2 *)accumulator;
 #pragma unroll
     for (int i = 0; i < params::opt / 2; i++) {
       fft[tid] = temp[i];
@@ -135,12 +105,69 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
     // lwe iteration
     auto keybundle_out = get_ith_mask_kth_block(
-        keybundle, blockIdx.x % lwe_chunk_size, glwe_id, level_id,
+        keybundle, chunk_id, glwe_id, level_id,
         polynomial_size, glwe_dimension, level_count);
     auto keybundle_poly = keybundle_out + poly_id * params::degree / 2;
 
     copy_polynomial<double2, params::opt / 2, params::degree / params::opt>(
         fft, keybundle_poly);
-  }
 }
+
+template <typename Torus, class params, sharedMemDegree SMD>
+__global__ void device_multi_bit_programmable_bootstrap_keybundle(
+    const Torus *__restrict__ lwe_array_in,
+    const Torus *__restrict__ lwe_input_indexes, double2 *keybundle_array,
+    const Torus *__restrict__ bootstrapping_key, uint32_t lwe_dimension,
+    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
+    uint32_t level_count, uint32_t lwe_offset, uint32_t lwe_chunk_size,
+    uint32_t keybundle_size_per_input, int8_t *device_mem,
+    uint64_t device_memory_size_per_block) {
+  extern __shared__ int8_t sharedmem[];
+  int8_t *selected_memory = sharedmem;
+
+  if constexpr (SMD == FULLSM) {
+    selected_memory = sharedmem;
+  } else {
+    int block_index = blockIdx.x + blockIdx.y * gridDim.x +
+                      blockIdx.z * gridDim.x * gridDim.y;
+    selected_memory = &device_mem[block_index * device_memory_size_per_block];
+  }
+
+  // Ids
+  uint32_t level_id = blockIdx.z;
+  uint32_t glwe_id = blockIdx.y / (glwe_dimension + 1);
+  uint32_t poly_id = blockIdx.y % (glwe_dimension + 1);
+  uint32_t lwe_iteration = (blockIdx.x % lwe_chunk_size + lwe_offset);
+  uint32_t input_idx = blockIdx.x / lwe_chunk_size;
+  uint32_t chunk_id = blockIdx.x % lwe_chunk_size;
+
+  if (lwe_iteration < (lwe_dimension / grouping_factor)) {
+    //
+    Torus *accumulator = (Torus *)selected_memory;
+
+    const Torus *block_lwe_array_in =
+        &lwe_array_in[lwe_input_indexes[input_idx] * (lwe_dimension + 1)];
+
+    double2 *keybundle = keybundle_array +
+                         // select the input
+                         input_idx * keybundle_size_per_input;
+
+    compute_multi_bit_programmable_bootstrap_keybundle<Torus, params, SMD>(
+        block_lwe_array_in, keybundle, bootstrapping_key, lwe_dimension,
+        glwe_dimension, polynomial_size, grouping_factor, level_count,
+        lwe_chunk_size, level_id, glwe_id, poly_id, chunk_id, lwe_iteration,
+        accumulator);
+  }
+}
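
The wrapper kept above preserves the original launch interface: all grid-index arithmetic (input_idx, chunk_id, glwe_id, poly_id, level_id) now lives in the wrapper, and the extracted helper takes those indices as plain arguments. The grid shape this implies can be read off that arithmetic; a sketch of a matching launch configuration, inferred from the index decomposition rather than shown in the diff:

    // blockIdx.x encodes (input_idx, chunk_id), blockIdx.y encodes
    // (glwe_id, poly_id), and blockIdx.z is level_id.
    dim3 keybundle_grid(num_inputs * lwe_chunk_size,
                        (glwe_dimension + 1) * (glwe_dimension + 1),
                        level_count);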

File 2 of 2

@@ -19,17 +19,18 @@
 template <typename Torus, class params, sharedMemDegree SMD>
 __global__ void __launch_bounds__(params::degree / params::opt)
-    device_multi_bit_programmable_bootstrap_tbc_accumulate(
+    device_multi_bit_programmable_bootstrap_tbc(
         Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
         const Torus *__restrict__ lut_vector,
         const Torus *__restrict__ lut_vector_indexes,
         const Torus *__restrict__ lwe_array_in,
         const Torus *__restrict__ lwe_input_indexes,
-        const double2 *__restrict__ keybundle_array, double2 *join_buffer,
+        const Torus *__restrict__ bootstrapping_key,
+        double2 *__restrict__ keybundle_array, double2 *join_buffer,
         Torus *global_accumulator, uint32_t lwe_dimension,
         uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
-        uint32_t level_count, uint32_t grouping_factor, uint32_t lwe_offset,
-        uint32_t lwe_chunk_size, uint32_t keybundle_size_per_input,
+        uint32_t level_count, uint32_t grouping_factor,
+        uint32_t keybundle_size_per_input,
         int8_t *device_mem, uint64_t device_memory_size_per_block,
         bool support_dsm) {
@@ -53,7 +54,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     selected_memory = &device_mem[block_index * device_memory_size_per_block];
   }
 
-  Torus *accumulator = (Torus *)selected_memory;
+  Torus *keybundle_accumulator = (Torus *)selected_memory;
+  Torus *accumulator = keybundle_accumulator + polynomial_size;
   double2 *accumulator_fft =
       (double2 *)accumulator +
       (ptrdiff_t)(sizeof(Torus) * polynomial_size / sizeof(double2));
@@ -66,7 +68,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
   // The third dimension of the block is used to determine on which ciphertext
   // this block is operating, in the case of batch bootstraps
-  const Torus *block_lwe_array_in =
+  const Torus *__restrict__ block_lwe_array_in =
       &lwe_array_in[lwe_input_indexes[blockIdx.z] * (lwe_dimension + 1)];
 
   const Torus *block_lut_vector =
@@ -81,11 +83,10 @@ __global__ void __launch_bounds__(params::degree / params::opt)
       global_accumulator +
       (blockIdx.y + blockIdx.z * (glwe_dimension + 1)) * params::degree;
 
-  const double2 *keybundle = keybundle_array +
+  double2 *__restrict__ keybundle = keybundle_array +
                              // select the input
                              blockIdx.z * keybundle_size_per_input;
 
-  if (lwe_offset == 0) {
     // Put "b" in [0, 2N[
     Torus b_hat = 0;
     modulus_switch(block_lwe_array_in[lwe_dimension], b_hat,
@@ -95,13 +96,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)
                                   params::degree / params::opt>(
         accumulator, &block_lut_vector[blockIdx.y * params::degree], b_hat,
         false);
-  } else {
-    // Load the accumulator calculated in previous iterations
-    copy_polynomial<Torus, params::opt, params::degree / params::opt>(
-        global_slice, accumulator);
-  }
 
-  for (int i = 0; (i + lwe_offset) < lwe_dimension && i < lwe_chunk_size; i++) {
+  for (int i = 0; i < lwe_dimension / grouping_factor; i++) {
     // Perform a rounding to increase the accuracy of the
     // bootstrapped ciphertext
     round_to_closest_multiple_inplace<Torus, params::opt,
@@ -119,15 +115,26 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     // don't modify the same memory space at the same time
     synchronize_threads_in_block();
 
+    // Computes keybundle
+    for (int poly_id = 0; poly_id < glwe_dimension + 1; poly_id++) {
+      compute_multi_bit_programmable_bootstrap_keybundle<Torus, params, SMD>(
+          block_lwe_array_in, keybundle, bootstrapping_key, lwe_dimension,
+          glwe_dimension, polynomial_size, grouping_factor, level_count,
+          (uint32_t)1, (uint32_t)blockIdx.x, (uint32_t)blockIdx.y,
+          (uint32_t)poly_id, (uint32_t)0, (uint32_t)i, keybundle_accumulator);
+      cluster.sync();
+      synchronize_threads_in_block();
+    }
+
     // Perform G^-1(ACC) * GGSW -> GLWE
     mul_ggsw_glwe<Torus, cluster_group, params>(
         accumulator, accumulator_fft, block_join_buffer, keybundle,
-        polynomial_size, glwe_dimension, level_count, i, cluster, support_dsm);
+        polynomial_size, glwe_dimension, level_count, 0, cluster, support_dsm);
 
     synchronize_threads_in_block();
   }
 
-  if (lwe_offset + lwe_chunk_size >= (lwe_dimension / grouping_factor)) {
     auto block_lwe_array_out =
         &lwe_array_out[lwe_output_indexes[blockIdx.z] *
                        (glwe_dimension * polynomial_size + 1) +
@@ -141,11 +148,6 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     } else if (blockIdx.x == 0 && blockIdx.y == glwe_dimension) {
       sample_extract_body<Torus, params>(block_lwe_array_out, accumulator, 0);
     }
-  } else {
-    // Load the accumulator calculated in previous iterations
-    copy_polynomial<Torus, params::opt, params::degree / params::opt>(
-        accumulator, global_slice);
-  }
 }
 
 template <typename Torus>
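
Net effect on execution shape: the old path alternated two launches per chunk of the LWE mask, staging GGSW FFTs through global memory, while the mono-kernel regenerates each keybundle inside the blind-rotation loop, with lwe_chunk_size pinned to 1 and chunk_id to 0, which is why mul_ggsw_glwe is now always called with iteration index 0. A minimal sketch of the two shapes, with stand-in kernels rather than the real ones:

    // Old shape: produce and consume each chunk through a global buffer,
    // one kernel launch per stage.
    __global__ void produce(float *stage) { stage[blockIdx.x] = 1.0f; }
    __global__ void consume(const float *stage, float *acc) {
      acc[blockIdx.x] += stage[blockIdx.x];
    }

    // New shape: one kernel stages each step's data and consumes it
    // immediately, so the chunk size is effectively 1.
    __global__ void mono(float *acc, int steps) {
      __shared__ float stage;
      for (int i = 0; i < steps; i++) {
        if (threadIdx.x == 0) stage = 1.0f; // produce
        __syncthreads();
        if (threadIdx.x == 0) acc[blockIdx.x] += stage; // consume
        __syncthreads();
      }
    }

The cost is serialization: each block now computes its slice of the keybundle before every external-product step, fenced by cluster.sync(), instead of overlapping a dedicated keybundle kernel across the whole grid.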
@@ -157,12 +159,13 @@ uint64_t get_buffer_size_sm_dsm_plus_tbc_multibit_programmable_bootstrap(
 
 template <typename Torus>
 uint64_t get_buffer_size_partial_sm_tbc_multibit_programmable_bootstrap(
     uint32_t polynomial_size) {
   return sizeof(Torus) * polynomial_size; // accumulator
 }
 
 template <typename Torus>
 uint64_t get_buffer_size_full_sm_tbc_multibit_programmable_bootstrap(
     uint32_t polynomial_size) {
-  return sizeof(Torus) * polynomial_size * 2; // accumulator
+  return sizeof(Torus) * polynomial_size * 2 + // accumulator
+         sizeof(Torus) * polynomial_size;      // keybundle accumulator
 }
 
 template <typename Torus, typename params>
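
The FULLSM variant now reserves a third polynomial-sized scratch region. A worked example of the budget under assumed parameters (Torus = uint64_t, polynomial_size = 2048; illustrative only, not taken from the diff): 2 * 2048 * 8 = 32 KiB for the accumulator region plus 2048 * 8 = 16 KiB for the keybundle accumulator, 48 KiB in total, which sits exactly at the 48 KiB default shared-memory limit; that is why the scratch code in the next hunk raises cudaFuncAttributeMaxDynamicSharedMemorySize before launching.

    #include <cstdint>
    #include <cstdio>

    int main() {
      using Torus = uint64_t;
      uint32_t polynomial_size = 2048; // assumed, for illustration
      size_t full_sm = sizeof(Torus) * polynomial_size * 2 // accumulators
                       + sizeof(Torus) * polynomial_size;  // keybundle acc.
      printf("%zu bytes = %zu KiB\n", full_sm, full_sm / 1024); // 49152 = 48
      return 0;
    }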
@@ -217,35 +220,35 @@ __host__ void scratch_tbc_multi_bit_programmable_bootstrap(
   if (max_shared_memory <
       partial_sm_tbc_accumulate + minimum_sm_tbc_accumulate) {
     check_cuda_error(cudaFuncSetAttribute(
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                NOSM>,
         cudaFuncAttributeMaxDynamicSharedMemorySize,
         minimum_sm_tbc_accumulate));
     cudaFuncSetCacheConfig(
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                NOSM>,
         cudaFuncCachePreferShared);
     check_cuda_error(cudaGetLastError());
   } else if (max_shared_memory <
              full_sm_tbc_accumulate + minimum_sm_tbc_accumulate) {
     check_cuda_error(cudaFuncSetAttribute(
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                PARTIALSM>,
         cudaFuncAttributeMaxDynamicSharedMemorySize,
         partial_sm_tbc_accumulate + minimum_sm_tbc_accumulate));
     cudaFuncSetCacheConfig(
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                PARTIALSM>,
         cudaFuncCachePreferShared);
     check_cuda_error(cudaGetLastError());
   } else {
     check_cuda_error(cudaFuncSetAttribute(
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                FULLSM>,
         cudaFuncAttributeMaxDynamicSharedMemorySize,
         full_sm_tbc_accumulate + minimum_sm_tbc_accumulate));
     cudaFuncSetCacheConfig(
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                FULLSM>,
         cudaFuncCachePreferShared);
     check_cuda_error(cudaGetLastError());
@@ -260,14 +263,13 @@ __host__ void scratch_tbc_multi_bit_programmable_bootstrap(
 }
 
 template <typename Torus, class params>
-__host__ void execute_tbc_external_product_loop(
+__host__ void execute_tbc(
     cudaStream_t stream, uint32_t gpu_index, Torus *lut_vector,
     Torus *lut_vector_indexes, Torus *lwe_array_in, Torus *lwe_input_indexes,
-    Torus *lwe_array_out, Torus *lwe_output_indexes,
+    Torus *lwe_array_out, Torus *lwe_output_indexes, Torus *bootstrapping_key,
     pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
     uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
-    uint32_t grouping_factor, uint32_t base_log, uint32_t level_count,
-    uint32_t lwe_chunk_size, int lwe_offset) {
+    uint32_t grouping_factor, uint32_t base_log, uint32_t level_count) {
 
   auto supports_dsm =
       supports_distributed_shared_memory_on_multibit_programmable_bootstrap<
@@ -289,12 +291,9 @@ __host__ void execute_tbc_external_product_loop(
   cudaSetDevice(gpu_index);
 
   uint32_t keybundle_size_per_input =
-      lwe_chunk_size * level_count * (glwe_dimension + 1) *
+      level_count * (glwe_dimension + 1) *
       (glwe_dimension + 1) * (polynomial_size / 2);
 
-  uint32_t chunk_size =
-      std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
-
   auto d_mem = buffer->d_mem_acc_tbc;
   auto keybundle_fft = buffer->keybundle_fft;
   auto global_accumulator = buffer->global_accumulator;
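
With the lwe_chunk_size factor dropped, the per-input staging buffer holds a single chunk's worth of keybundle FFTs. A worked size under assumed parameters (level_count = 2, glwe_dimension = 1, polynomial_size = 2048; illustrative only):

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint32_t level_count = 2, glwe_dimension = 1, polynomial_size = 2048;
      uint32_t elems = level_count * (glwe_dimension + 1) *
                       (glwe_dimension + 1) * (polynomial_size / 2);
      // 8192 double2 elements at 16 bytes each: 128 KiB per input.
      printf("%u double2 = %u KiB per input\n", elems, elems * 16 / 1024);
      return 0;
    }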
@@ -323,35 +322,35 @@ __host__ void execute_tbc_external_product_loop(
     config.dynamicSmemBytes = minimum_dm;
     check_cuda_error(cudaLaunchKernelEx(
         &config,
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                NOSM>,
         lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
-        lwe_array_in, lwe_input_indexes, keybundle_fft, buffer_fft,
+        lwe_array_in, lwe_input_indexes, bootstrapping_key, keybundle_fft, buffer_fft,
         global_accumulator, lwe_dimension, glwe_dimension, polynomial_size,
-        base_log, level_count, grouping_factor, lwe_offset, chunk_size,
+        base_log, level_count, grouping_factor,
         keybundle_size_per_input, d_mem, full_dm, supports_dsm));
   } else if (max_shared_memory < full_dm + minimum_dm) {
     config.dynamicSmemBytes = partial_dm + minimum_dm;
     check_cuda_error(cudaLaunchKernelEx(
         &config,
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                PARTIALSM>,
         lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
-        lwe_array_in, lwe_input_indexes, keybundle_fft, buffer_fft,
+        lwe_array_in, lwe_input_indexes, bootstrapping_key, keybundle_fft, buffer_fft,
         global_accumulator, lwe_dimension, glwe_dimension, polynomial_size,
-        base_log, level_count, grouping_factor, lwe_offset, chunk_size,
-        keybundle_size_per_input, d_mem, partial_dm, supports_dsm));
+        base_log, level_count, grouping_factor,
+        keybundle_size_per_input, d_mem, full_dm, supports_dsm));
   } else {
     config.dynamicSmemBytes = full_dm + minimum_dm;
     check_cuda_error(cudaLaunchKernelEx(
         &config,
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                FULLSM>,
         lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
-        lwe_array_in, lwe_input_indexes, keybundle_fft, buffer_fft,
+        lwe_array_in, lwe_input_indexes, bootstrapping_key, keybundle_fft, buffer_fft,
         global_accumulator, lwe_dimension, glwe_dimension, polynomial_size,
-        base_log, level_count, grouping_factor, lwe_offset, chunk_size,
-        keybundle_size_per_input, d_mem, 0, supports_dsm));
+        base_log, level_count, grouping_factor,
+        keybundle_size_per_input, d_mem, full_dm, supports_dsm));
   }
 }
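
All three branches go through cudaLaunchKernelEx, the launch path that can carry a thread-block-cluster configuration, which the classic <<<...>>> syntax cannot express. A generic sketch of such a launch; the kernel, grid shape, and cluster size here are placeholders, not the values this code uses:

    __global__ void kernel(float *out) { out[blockIdx.x] = 1.0f; }

    void launch(float *out) {
      cudaLaunchConfig_t config = {};
      config.gridDim = dim3(8);
      config.blockDim = dim3(128);
      config.dynamicSmemBytes = 0;

      cudaLaunchAttribute attr[1];
      attr[0].id = cudaLaunchAttributeClusterDimension;
      attr[0].val.clusterDim.x = 2; // blocks per cluster along x
      attr[0].val.clusterDim.y = 1;
      attr[0].val.clusterDim.z = 1;
      config.attrs = attr;
      config.numAttrs = 1;

      cudaLaunchKernelEx(&config, kernel, out);
    }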
@@ -368,22 +367,12 @@ __host__ void host_tbc_multi_bit_programmable_bootstrap(
   auto lwe_chunk_size = get_lwe_chunk_size<Torus, params>(
       gpu_index, num_samples, polynomial_size);
 
-  for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor);
-       lwe_offset += lwe_chunk_size) {
-
-    // Compute a keybundle
-    execute_compute_keybundle<Torus, params>(
-        stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key,
-        buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size,
-        grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset);
-
-    // Accumulate
-    execute_tbc_external_product_loop<Torus, params>(
+  execute_tbc<Torus, params>(
         stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
-        lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer,
+        lwe_input_indexes, lwe_array_out, lwe_output_indexes, bootstrapping_key,
+        buffer,
         num_samples, lwe_dimension, glwe_dimension, polynomial_size,
-        grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset);
-  }
+        grouping_factor, base_log, level_count);
 }
 
 template <typename Torus>
@@ -446,33 +435,33 @@ __host__ bool supports_thread_block_clusters_on_multibit_programmable_bootstrap(
   if (max_shared_memory <
       partial_sm_tbc_accumulate + minimum_sm_tbc_accumulate) {
     check_cuda_error(cudaFuncSetAttribute(
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                NOSM>,
         cudaFuncAttributeNonPortableClusterSizeAllowed, false));
     check_cuda_error(cudaOccupancyMaxPotentialClusterSize(
         &cluster_size,
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                NOSM>,
         &config));
   } else if (max_shared_memory <
              full_sm_tbc_accumulate + minimum_sm_tbc_accumulate) {
     check_cuda_error(cudaFuncSetAttribute(
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                PARTIALSM>,
         cudaFuncAttributeNonPortableClusterSizeAllowed, false));
     check_cuda_error(cudaOccupancyMaxPotentialClusterSize(
         &cluster_size,
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                PARTIALSM>,
         &config));
   } else {
     check_cuda_error(cudaFuncSetAttribute(
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                FULLSM>,
         cudaFuncAttributeNonPortableClusterSizeAllowed, false));
     check_cuda_error(cudaOccupancyMaxPotentialClusterSize(
         &cluster_size,
-        device_multi_bit_programmable_bootstrap_tbc_accumulate<Torus, params,
+        device_multi_bit_programmable_bootstrap_tbc<Torus, params,
                                                                FULLSM>,
         &config));
   }