chore(gpu): change the number of threads used in the keyswitch

Author: Agnes Leroy
Date: 2024-05-22 15:50:06 +02:00
Committed by: Pedro Alves
Parent: fadb48a86d
Commit: fedd1ca7b2
3 changed files with 34 additions and 142 deletions


@@ -20,9 +20,7 @@ private:
  uint32_t level_count;
  uint32_t base_log;
  uint32_t mask;
  uint32_t halfbg;
  uint32_t num_poly;
  T offset;
  int current_level;
  T mask_mod_b;
  T *state;
@@ -82,72 +80,12 @@ public:
    synchronize_threads_in_block();
  }

  // Decomposes a single polynomial
  __device__ void
  decompose_and_compress_next_polynomial_elements(double2 *result, int j) {
    if (j == 0)
      current_level -= 1;
    int tid = threadIdx.x;
    auto state_slice = state + j * params::degree;
    for (int i = 0; i < params::opt / 2; i++) {
      T res_re = state_slice[tid] & mask_mod_b;
      T res_im = state_slice[tid + params::degree / 2] & mask_mod_b;
      state_slice[tid] >>= base_log;
      state_slice[tid + params::degree / 2] >>= base_log;
      T carry_re = ((res_re - 1ll) | state_slice[tid]) & res_re;
      T carry_im =
          ((res_im - 1ll) | state_slice[tid + params::degree / 2]) & res_im;
      carry_re >>= (base_log - 1);
      carry_im >>= (base_log - 1);
      state_slice[tid] += carry_re;
      state_slice[tid + params::degree / 2] += carry_im;
      res_re -= carry_re << base_log;
      res_im -= carry_im << base_log;
      result[i].x = (int32_t)res_re;
      result[i].y = (int32_t)res_im;
      tid += params::degree / params::opt;
    }
    synchronize_threads_in_block();
  }

  __device__ void decompose_and_compress_level(double2 *result, int level) {
    for (int i = 0; i < level_count - level; i++)
      decompose_and_compress_next(result);
  }
};
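The method above produces one signed base-2^base_log digit per coefficient and packs the digits of coefficients tid and tid + degree/2 into the real and imaginary parts of a double2, presumably so they can feed the complex FFT directly. As a hedged illustration only, here is a serial host-side sketch of one such decomposition level over a whole polynomial; the names and the double2 stand-in are assumptions, not code from this repository.

#include <cstdint>
#include <vector>

struct double2_host { double x, y; }; // host stand-in for CUDA's double2

// Serial sketch of one decomposition level for a polynomial of size `degree`:
// each iteration extracts the current signed digit of coefficients c and
// c + degree / 2 and stores them as one complex FFT input sample.
void decompose_and_compress_level_host(std::vector<uint64_t> &state,
                                       std::vector<double2_host> &fft_in,
                                       uint32_t base_log, uint32_t degree) {
  const uint64_t mask_mod_b = (1ull << base_log) - 1ull;
  for (uint32_t c = 0; c < degree / 2; c++) {
    uint64_t res_re = state[c] & mask_mod_b;
    uint64_t res_im = state[c + degree / 2] & mask_mod_b;
    state[c] >>= base_log;
    state[c + degree / 2] >>= base_log;
    // Branchless carry: digits in the upper half of the base (ties decided by
    // the remaining state bits) become negative digits and push a carry into
    // the next level.
    uint64_t carry_re = ((res_re - 1ull) | state[c]) & res_re;
    uint64_t carry_im = ((res_im - 1ull) | state[c + degree / 2]) & res_im;
    carry_re >>= base_log - 1;
    carry_im >>= base_log - 1;
    state[c] += carry_re;
    state[c + degree / 2] += carry_im;
    res_re -= carry_re << base_log;
    res_im -= carry_im << base_log;
    // Signed digits become the real and imaginary parts of the FFT input.
    fft_in[c].x = (double)(int32_t)res_re;
    fft_in[c].y = (double)(int32_t)res_im;
  }
}

The device method does the same work in parallel: instead of the serial loop over c, each thread walks a strided subset of the coefficients (tid += params::degree / params::opt).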
template <typename T> class GadgetMatrixSingle {
private:
  uint32_t level_count;
  uint32_t base_log;
  uint32_t mask;
  uint32_t halfbg;
  T offset;

public:
  __device__ GadgetMatrixSingle(uint32_t base_log, uint32_t level_count)
      : base_log(base_log), level_count(level_count) {
    uint32_t bg = 1 << base_log;
    this->halfbg = bg / 2;
    this->mask = bg - 1;
    T temp = 0;
    for (int i = 0; i < this->level_count; i++) {
      temp += 1ULL << (sizeof(T) * 8 - (i + 1) * this->base_log);
    }
    this->offset = temp * this->halfbg;
  }

  __device__ T decompose_one_level_single(T element, uint32_t level) {
    T s = element + this->offset;
    uint32_t decal = (sizeof(T) * 8 - (level + 1) * this->base_log);
    T temp1 = (s >> decal) & this->mask;
    return (T)(temp1 - this->halfbg);
  }
};

template <typename Torus>
__device__ Torus decompose_one(Torus &state, Torus mask_mod_b, int base_log) {
  Torus res = state & mask_mod_b;
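GadgetMatrixSingle (shown above) gives random access to the digit at a given level: the precomputed offset adds half the base at every level so that a shift and a mask, followed by subtracting halfbg, yields a balanced digit in [-B/2, B/2) with B = 2^base_log. A small host-side sketch of that computation, assuming a 64-bit torus and purely illustrative values:

#include <cstdint>
#include <cstdio>

// Host sketch of GadgetMatrixSingle for Torus = uint64_t: the digit of a
// given level is read directly with a shift and a mask after adding a
// precomputed rounding offset. The element value below is arbitrary.
int main() {
  const uint32_t base_log = 4, level_count = 3;
  const uint64_t bg = 1ull << base_log;
  const uint64_t halfbg = bg / 2;
  const uint64_t mask = bg - 1;

  // offset = halfbg * sum_i 2^(64 - (i + 1) * base_log), as in the constructor.
  uint64_t offset = 0;
  for (uint32_t i = 0; i < level_count; i++)
    offset += 1ull << (64 - (i + 1) * base_log);
  offset *= halfbg;

  const uint64_t element = 0x923456789abcdef0ull;
  for (uint32_t level = 0; level < level_count; level++) {
    uint64_t s = element + offset;                   // wrapping add
    uint32_t decal = 64 - (level + 1) * base_log;    // position of the digit
    uint64_t digit = ((s >> decal) & mask) - halfbg; // signed digit, mod 2^64
    printf("level %u: digit = %lld\n", level, (long long)digit);
  }
  return 0;
}

The free function decompose_one that follows in the header (its body is cut off by the hunk) instead consumes a pre-rounded state iteratively, one digit per call starting from the low end of the kept bits; the keyswitch kernel below uses that iterative form.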

View File

@@ -35,63 +35,35 @@ template <typename Torus>
__global__ void
keyswitch(Torus *lwe_array_out, Torus *lwe_output_indexes, Torus *lwe_array_in,
          Torus *lwe_input_indexes, Torus *ksk, uint32_t lwe_dimension_in,
          uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
          int lwe_lower, int lwe_upper, int cutoff) {
          uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count) {
  int tid = threadIdx.x;
  extern __shared__ int8_t sharedmem[];
  if (tid <= lwe_dimension_out) {
    Torus *local_lwe_array_out = (Torus *)sharedmem;
    auto block_lwe_array_in = get_chunk(
        lwe_array_in, lwe_input_indexes[blockIdx.x], lwe_dimension_in + 1);
    auto block_lwe_array_out = get_chunk(
        lwe_array_out, lwe_output_indexes[blockIdx.x], lwe_dimension_out + 1);
    local_lwe_array_out[tid] = 0;
  Torus *local_lwe_array_out = (Torus *)sharedmem;
    if (tid == lwe_dimension_out) {
      local_lwe_array_out[lwe_dimension_out] =
          block_lwe_array_in[lwe_dimension_in];
    }
  auto block_lwe_array_in = get_chunk(
      lwe_array_in, lwe_input_indexes[blockIdx.x], lwe_dimension_in + 1);
  auto block_lwe_array_out = get_chunk(
      lwe_array_out, lwe_output_indexes[blockIdx.x], lwe_dimension_out + 1);
  auto gadget = GadgetMatrixSingle<Torus>(base_log, level_count);
  int lwe_part_per_thd;
  if (tid < cutoff) {
    lwe_part_per_thd = lwe_upper;
  } else {
    lwe_part_per_thd = lwe_lower;
  }
  __syncthreads();
  for (int k = 0; k < lwe_part_per_thd; k++) {
    int idx = tid + k * blockDim.x;
    local_lwe_array_out[idx] = 0;
  }
  __syncthreads();
  if (tid == 0) {
    local_lwe_array_out[lwe_dimension_out] =
        block_lwe_array_in[lwe_dimension_in];
  }
  for (int i = 0; i < lwe_dimension_in; i++) {
    __syncthreads();
    Torus a_i =
        round_to_closest_multiple(block_lwe_array_in[i], base_log, level_count);
    Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
    Torus mask_mod_b = (1ll << base_log) - 1ll;
    for (int j = 0; j < level_count; j++) {
      auto ksk_block = get_ith_block(ksk, i, j, lwe_dimension_out, level_count);
      Torus decomposed = decompose_one<Torus>(state, mask_mod_b, base_log);
      for (int k = 0; k < lwe_part_per_thd; k++) {
        int idx = tid + k * blockDim.x;
        local_lwe_array_out[idx] -= (Torus)ksk_block[idx] * decomposed;
    for (int i = 0; i < lwe_dimension_in; i++) {
      Torus a_i = round_to_closest_multiple(block_lwe_array_in[i], base_log,
                                            level_count);
      Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
      Torus mask_mod_b = (1ll << base_log) - 1ll;
      for (int j = 0; j < level_count; j++) {
        auto ksk_block =
            get_ith_block(ksk, i, j, lwe_dimension_out, level_count);
        Torus decomposed = decompose_one<Torus>(state, mask_mod_b, base_log);
        local_lwe_array_out[tid] -= (Torus)ksk_block[tid] * decomposed;
      }
    }
  }
  for (int k = 0; k < lwe_part_per_thd; k++) {
    int idx = tid + k * blockDim.x;
    block_lwe_array_out[idx] = local_lwe_array_out[idx];
    block_lwe_array_out[tid] = local_lwe_array_out[tid];
  }
}
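In the reworked kernel every thread owns at most one output coefficient: thread tid zeroes local_lwe_array_out[tid], the thread with tid == lwe_dimension_out carries over the input body, and each thread then accumulates -ksk_block[tid] * decomposed over all input coefficients and decomposition levels before writing its single coefficient back. Below is a hedged CPU reference of that per-coefficient accumulation; the flat key layout, the inlined rounding, and the simplified tie handling in the signed decomposition are assumptions rather than code taken from the repository.

#include <cstdint>
#include <vector>

// Illustrative CPU reference of what one CUDA thread computes for output
// coefficient `tid`. The key-switching key is assumed flattened as
// [lwe_dimension_in][level_count][lwe_dimension_out + 1]; this layout is a
// guess at what get_ith_block indexes, not taken from the diff. Assumes
// base_log * level_count < 64, as in the parameter sets used by the tests.
uint64_t keyswitch_one_coefficient(const std::vector<uint64_t> &lwe_in,
                                   const std::vector<uint64_t> &ksk,
                                   uint32_t lwe_dimension_in,
                                   uint32_t lwe_dimension_out,
                                   uint32_t base_log, uint32_t level_count,
                                   uint32_t tid) {
  // The body of the input ciphertext is copied only into the last coefficient.
  uint64_t acc = (tid == lwe_dimension_out) ? lwe_in[lwe_dimension_in] : 0;
  const uint32_t shift = 64 - base_log * level_count;
  const uint64_t half_step = 1ull << (shift - 1);
  const uint64_t mask_mod_b = (1ull << base_log) - 1ull;
  const uint64_t half_base = 1ull << (base_log - 1);

  for (uint32_t i = 0; i < lwe_dimension_in; i++) {
    // Keep only the base_log * level_count most significant bits of a_i,
    // rounding to the closest multiple of 2^shift (wrapping arithmetic).
    uint64_t state = (lwe_in[i] + half_step) >> shift;
    for (uint32_t j = 0; j < level_count; j++) {
      // Simple signed digit in [-B/2, B/2]; ties are not resolved exactly as
      // in the branchless version used on the device.
      uint64_t digit = state & mask_mod_b;
      state >>= base_log;
      if (digit >= half_base) {
        digit -= (1ull << base_log); // wraps: represents a negative digit
        state += 1;                  // carry into the next level
      }
      const uint64_t *ksk_block =
          ksk.data() + ((size_t)i * level_count + j) * (lwe_dimension_out + 1);
      acc -= ksk_block[tid] * digit; // all arithmetic wraps mod 2^64
    }
  }
  return acc;
}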
@@ -104,37 +76,19 @@ __host__ void cuda_keyswitch_lwe_ciphertext_vector(
    uint32_t base_log, uint32_t level_count, uint32_t num_samples) {
  cudaSetDevice(gpu_index);
  constexpr int ideal_threads = 128;
  constexpr int ideal_threads = 1024;
  if (lwe_dimension_out + 1 > ideal_threads)
    PANIC("Cuda error (keyswitch): lwe dimension size out should be greater "
          "or equal to the number of threads per block")
  int lwe_size = lwe_dimension_out + 1;
  int lwe_lower, lwe_upper, cutoff;
  if (lwe_size % ideal_threads == 0) {
    lwe_lower = lwe_size / ideal_threads;
    lwe_upper = lwe_size / ideal_threads;
    cutoff = 0;
  } else {
    int y = ceil((double)lwe_size / (double)ideal_threads) * ideal_threads -
            lwe_size;
    cutoff = ideal_threads - y;
    lwe_lower = lwe_size / ideal_threads;
    lwe_upper = (int)ceil((double)lwe_size / (double)ideal_threads);
  }
  int lwe_size_after = lwe_size * num_samples;
  int shared_mem = sizeof(Torus) * lwe_size;
  cuda_memset_async(lwe_array_out, 0, sizeof(Torus) * lwe_size_after, stream,
                    gpu_index);
  check_cuda_error(cudaGetLastError());
  dim3 grid(num_samples, 1, 1);
  dim3 threads(ideal_threads, 1, 1);
  keyswitch<Torus><<<grid, threads, shared_mem, stream>>>(
      lwe_array_out, lwe_output_indexes, lwe_array_in, lwe_input_indexes, ksk,
      lwe_dimension_in, lwe_dimension_out, base_log, level_count, lwe_lower,
      lwe_upper, cutoff);
      lwe_dimension_in, lwe_dimension_out, base_log, level_count);
  check_cuda_error(cudaGetLastError());
}
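For readers puzzling over the deleted splitting logic: with 128 threads per block, lwe_lower, lwe_upper and cutoff divided the lwe_size output coefficients so that the first cutoff threads each handled lwe_upper of them and the remaining threads handled lwe_lower. The standalone sketch below reproduces that computation with example numbers (567 + 1, matching the first test parameter set) and checks that the split covers every coefficient exactly once; after this commit the question disappears, since each of the 1024 threads owns at most one coefficient.

#include <cassert>
#include <cmath>
#include <cstdio>

// Illustration of the removed work-splitting computation; values are examples.
int main() {
  const int ideal_threads = 128; // old thread count per block
  const int lwe_size = 567 + 1;  // output LWE dimension + 1
  int lwe_lower, lwe_upper, cutoff;
  if (lwe_size % ideal_threads == 0) {
    lwe_lower = lwe_upper = lwe_size / ideal_threads;
    cutoff = 0;
  } else {
    int y = (int)ceil((double)lwe_size / (double)ideal_threads) * ideal_threads -
            lwe_size;
    cutoff = ideal_threads - y;           // threads handling one extra element
    lwe_lower = lwe_size / ideal_threads; // 4 for these numbers
    lwe_upper = (int)ceil((double)lwe_size / (double)ideal_threads); // 5
  }
  // Every output coefficient is owned by exactly one thread.
  assert(cutoff * lwe_upper + (ideal_threads - cutoff) * lwe_lower == lwe_size);
  printf("lwe_lower=%d lwe_upper=%d cutoff=%d\n", lwe_lower, lwe_upper, cutoff);
  return 0;
}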


@@ -132,22 +132,22 @@ TEST_P(KeyswitchTestPrimitives_u64, keyswitch) {
// n, k*N, noise_distribution, ks_base_log, ks_level,
// message_modulus, carry_modulus, number_of_inputs
(KeyswitchTestParams){
567, 1280, new_gaussian_from_std_dev(sqrt(2.9802322387695312e-18)),
1280, 567, new_gaussian_from_std_dev(sqrt(2.9802322387695312e-18)),
3, 3, 2, 1, 10},
(KeyswitchTestParams){
694, 1536, new_gaussian_from_std_dev(sqrt(2.9802322387695312e-18)),
1536, 694, new_gaussian_from_std_dev(sqrt(2.9802322387695312e-18)),
4, 3, 2, 1, 10},
(KeyswitchTestParams){
769, 2048, new_gaussian_from_std_dev(sqrt(2.9802322387695312e-18)),
2048, 769, new_gaussian_from_std_dev(sqrt(2.9802322387695312e-18)),
4, 3, 2, 1, 10},
(KeyswitchTestParams){
754, 2048, new_gaussian_from_std_dev(sqrt(2.9802322387695312e-18)),
2048, 754, new_gaussian_from_std_dev(sqrt(2.9802322387695312e-18)),
3, 5, 2, 1, 10},
(KeyswitchTestParams){742, 2048,
(KeyswitchTestParams){2048, 742,
new_gaussian_from_std_dev(sqrt(4.9982771e-11)), 3,
5, 4, 1, 10},
(KeyswitchTestParams){
847, 4096, new_gaussian_from_std_dev(sqrt(2.9802322387695312e-18)),
4096, 847, new_gaussian_from_std_dev(sqrt(2.9802322387695312e-18)),
4, 4, 2, 1, 10});
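A plausible reading of why each parameter set swaps its first two numbers: these keyswitch tests go from the large LWE dimension (k*N) down to the small one (n), and with the new 1024-thread blocks it is the output dimension plus one that must fit in a single block. The snippet below only checks that constraint against the dimensions listed above; treating the first field as the input dimension and the second as the output dimension is an assumption, not something the diff states.

#include <cstdio>

// Illustrative check that the keyswitch output dimensions used by the tests
// fit in one 1024-thread block (the input dimensions would not).
int main() {
  const int ideal_threads = 1024;
  const int dims[][2] = {{1280, 567}, {1536, 694}, {2048, 769},
                         {2048, 754}, {2048, 742}, {4096, 847}};
  for (const auto &d : dims) {
    const int lwe_dimension_in = d[0], lwe_dimension_out = d[1];
    printf("in=%4d (in+1 <= 1024: %s)  out=%4d (out+1 <= 1024: %s)\n",
           lwe_dimension_in,
           lwe_dimension_in + 1 <= ideal_threads ? "yes" : "no",
           lwe_dimension_out,
           lwe_dimension_out + 1 <= ideal_threads ? "yes" : "no");
  }
  return 0;
}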
std::string printParamName(::testing::TestParamInfo<KeyswitchTestParams> p) {