Mirror of https://github.com/zama-ai/tfhe-rs.git, synced 2026-01-09 22:57:59 -05:00
feat(gpu): make PBS and ks execution parallel over available GPUs
Only GPUs with peer access to GPU 0 can be used for this at the moment. Peer-to-peer copy is used when different GPUs are passed to memcpy_gpu_to_gpu. A GPU offset is passed as a new parameter to the PBS and keyswitch to adjust the input/output indexes used per GPU. The bsk and ksk are copied to all GPUs. The CI now tests and runs benchmarks on p3.8xlarge AWS instances.
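A minimal sketch of how the "peer access to GPU 0" restriction described above could be checked with the standard CUDA runtime API; this is an illustration only, not the PR's helper code, and the function name usable_gpu_count is hypothetical:

// Hypothetical sketch: count devices that can do peer-to-peer with GPU 0.
#include <cuda_runtime.h>
#include <cstdio>

int usable_gpu_count() {
  int device_count = 0;
  cudaGetDeviceCount(&device_count);
  int usable = device_count > 0 ? 1 : 0; // GPU 0 itself is always usable
  for (int i = 1; i < device_count; i++) {
    int can_access = 0;
    // Can GPU i directly access GPU 0's memory?
    cudaDeviceCanAccessPeer(&can_access, i, 0);
    if (can_access)
      usable++;
  }
  return usable;
}

int main() {
  printf("GPUs usable for multi-GPU PBS/KS: %d\n", usable_gpu_count());
  return 0;
}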
@@ -3,6 +3,7 @@
 #include "device.h"
 #include "gadget.cuh"
+#include "helper_multi_gpu.h"
 #include "polynomial/functions.cuh"
 #include "polynomial/polynomial_math.cuh"
 #include "torus.cuh"
@@ -37,23 +38,26 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
 // threads in y are used to paralelize the lwe_dimension_in loop.
 // shared memory is used to store intermediate results of the reduction.
 template <typename Torus>
-__global__ void
-keyswitch(Torus *lwe_array_out, Torus *lwe_output_indexes, Torus *lwe_array_in,
-          Torus *lwe_input_indexes, Torus *ksk, uint32_t lwe_dimension_in,
-          uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count) {
+__global__ void keyswitch(Torus *lwe_array_out, Torus *lwe_output_indexes,
+                          Torus *lwe_array_in, Torus *lwe_input_indexes,
+                          Torus *ksk, uint32_t lwe_dimension_in,
+                          uint32_t lwe_dimension_out, uint32_t base_log,
+                          uint32_t level_count, int gpu_offset) {
   const int tid = threadIdx.x + blockIdx.x * blockDim.x;
   const int shmem_index = threadIdx.x + threadIdx.y * blockDim.x;

   extern __shared__ int8_t sharedmem[];
   Torus *lwe_acc_out = (Torus *)sharedmem;
-  auto block_lwe_array_out = get_chunk(
-      lwe_array_out, lwe_output_indexes[blockIdx.y], lwe_dimension_out + 1);
+  auto block_lwe_array_out =
+      get_chunk(lwe_array_out, lwe_output_indexes[blockIdx.y + gpu_offset],
+                lwe_dimension_out + 1);

   if (tid <= lwe_dimension_out) {

     Torus local_lwe_out = 0;
-    auto block_lwe_array_in = get_chunk(
-        lwe_array_in, lwe_input_indexes[blockIdx.y], lwe_dimension_in + 1);
+    auto block_lwe_array_in =
+        get_chunk(lwe_array_in, lwe_input_indexes[blockIdx.y + gpu_offset],
+                  lwe_dimension_in + 1);

     if (tid == lwe_dimension_out && threadIdx.y == 0) {
       local_lwe_out = block_lwe_array_in[lwe_dimension_in];
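For intuition on the new gpu_offset argument in the hunk above: the index arrays describe the whole batch, while each GPU launches the kernel over only its own slice (gridDim.y equals its local sample count), so blockIdx.y alone would always address the first slice; adding gpu_offset maps the local block index back to the global sample. Below is a small standalone host-side sketch of that mapping, assuming the even split that get_num_inputs_on_gpu / get_gpu_offset presumably implement:

// Hypothetical illustration of the index arithmetic behind gpu_offset.
// GPU i owns samples [offset, offset + n_i); a block with blockIdx.y == b
// addresses global sample (b + gpu_offset) in the lwe_*_indexes arrays.
#include <cstdint>
#include <cstdio>

int main() {
  uint32_t num_samples = 10, gpu_count = 4;
  uint32_t rem = num_samples % gpu_count;
  for (uint32_t i = 0; i < gpu_count; i++) {
    uint32_t n_i = num_samples / gpu_count + (i < rem ? 1 : 0);
    uint32_t offset = i * (num_samples / gpu_count) + (i < rem ? i : rem);
    printf("GPU %u: blockIdx.y in [0, %u) -> samples [%u, %u)\n", i, n_i,
           offset, offset + n_i);
  }
  return 0;
}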
@@ -99,7 +103,8 @@ __host__ void cuda_keyswitch_lwe_ciphertext_vector(
     cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
     Torus *lwe_output_indexes, Torus *lwe_array_in, Torus *lwe_input_indexes,
     Torus *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
-    uint32_t base_log, uint32_t level_count, uint32_t num_samples) {
+    uint32_t base_log, uint32_t level_count, uint32_t num_samples,
+    uint32_t gpu_offset = 0) {

   cudaSetDevice(gpu_index);

@@ -115,8 +120,42 @@ __host__ void cuda_keyswitch_lwe_ciphertext_vector(

   keyswitch<Torus><<<grid, threads, shared_mem, stream>>>(
       lwe_array_out, lwe_output_indexes, lwe_array_in, lwe_input_indexes, ksk,
-      lwe_dimension_in, lwe_dimension_out, base_log, level_count);
+      lwe_dimension_in, lwe_dimension_out, base_log, level_count, gpu_offset);
   check_cuda_error(cudaGetLastError());
 }

+template <typename Torus>
+void execute_keyswitch(cudaStream_t *streams, uint32_t *gpu_indexes,
+                       uint32_t gpu_count, Torus *lwe_array_out,
+                       Torus *lwe_output_indexes, Torus *lwe_array_in,
+                       Torus *lwe_input_indexes, Torus **ksks,
+                       uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
+                       uint32_t base_log, uint32_t level_count,
+                       uint32_t num_samples, bool sync_streams = true) {
+
+  /// If the number of radix blocks is lower than the number of GPUs, not all
+  /// GPUs will be active and there will be 1 input per GPU
+  auto active_gpu_count = get_active_gpu_count(num_samples, gpu_count);
+  int num_samples_on_gpu_0 = get_num_inputs_on_gpu(num_samples, 0, gpu_count);
+  if (sync_streams)
+    cuda_synchronize_stream(streams[0], gpu_indexes[0]);
+#pragma omp parallel for num_threads(active_gpu_count)
+  for (uint i = 0; i < active_gpu_count; i++) {
+    int num_samples_on_gpu = get_num_inputs_on_gpu(num_samples, i, gpu_count);
+    int gpu_offset = get_gpu_offset(num_samples, i, gpu_count);
+
+    // Compute Keyswitch
+    cuda_keyswitch_lwe_ciphertext_vector<Torus>(
+        streams[i], gpu_indexes[i], lwe_array_out, lwe_output_indexes,
+        lwe_array_in, lwe_input_indexes, ksks[i], lwe_dimension_in,
+        lwe_dimension_out, base_log, level_count, num_samples_on_gpu,
+        gpu_offset);
+  }
+
+  if (sync_streams)
+    for (uint i = 0; i < active_gpu_count; i++) {
+      cuda_synchronize_stream(streams[i], gpu_indexes[i]);
+    }
+}
+
 #endif
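To show how the new execute_keyswitch entry point is meant to be driven, here is a hedged host-side sketch: it creates one stream per device with the plain CUDA runtime API (the library has its own stream and device helpers, so treat this as an approximation), assumes ksks[] already holds a copy of the key-switching key on each GPU, and lets execute_keyswitch fan the batch out; keyswitch_on_all_gpus is a hypothetical name:

// Hypothetical driver for execute_keyswitch (illustration only).
// Assumes lwe arrays, index arrays, and per-GPU ksks are already allocated
// and populated; the crypto parameters come from the caller's context.
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

template <typename Torus>
void keyswitch_on_all_gpus(Torus *lwe_array_out, Torus *lwe_output_indexes,
                           Torus *lwe_array_in, Torus *lwe_input_indexes,
                           Torus **ksks, uint32_t lwe_dimension_in,
                           uint32_t lwe_dimension_out, uint32_t base_log,
                           uint32_t level_count, uint32_t num_samples) {
  int device_count = 0;
  cudaGetDeviceCount(&device_count);

  std::vector<cudaStream_t> streams(device_count);
  std::vector<uint32_t> gpu_indexes(device_count);
  for (int i = 0; i < device_count; i++) {
    gpu_indexes[i] = i;
    cudaSetDevice(i);
    cudaStreamCreate(&streams[i]); // one stream per device
  }

  // Splits num_samples across the GPUs, launching the keyswitch on each
  // slice and synchronizing all streams before returning (sync_streams).
  execute_keyswitch<Torus>(streams.data(), gpu_indexes.data(), device_count,
                           lwe_array_out, lwe_output_indexes, lwe_array_in,
                           lwe_input_indexes, ksks, lwe_dimension_in,
                           lwe_dimension_out, base_log, level_count,
                           num_samples, /*sync_streams=*/true);

  for (int i = 0; i < device_count; i++) {
    cudaSetDevice(i);
    cudaStreamDestroy(streams[i]);
  }
}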