Mirror of https://github.com/zama-ai/tfhe-rs.git (synced 2026-01-09 14:47:56 -05:00)
fix(gpu): fix default multi-bit PBS for multi-device execution of integer ops
This commit is contained in:
@@ -114,6 +114,8 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {
|
||||
uint32_t polynomial_size, uint32_t level_count,
|
||||
uint32_t input_lwe_ciphertext_count, uint32_t lwe_chunk_size,
|
||||
PBS_VARIANT pbs_variant, bool allocate_gpu_memory) {
|
||||
cudaSetDevice(gpu_index);
|
||||
|
||||
this->pbs_variant = pbs_variant;
|
||||
this->lwe_chunk_size = lwe_chunk_size;
|
||||
auto max_shared_memory = cuda_get_max_shared_memory(gpu_index);
|
||||
|
||||
@@ -215,6 +215,8 @@ __host__ void scratch_cg_multi_bit_programmable_bootstrap(
|
||||
uint32_t polynomial_size, uint32_t level_count,
|
||||
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) {
|
||||
|
||||
cudaSetDevice(gpu_index);
|
||||
|
||||
uint64_t full_sm_keybundle =
|
||||
get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle<Torus>(
|
||||
polynomial_size);
|
||||
@@ -296,6 +298,7 @@ __host__ void execute_cg_external_product_loop(
|
||||
uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t lwe_offset, uint32_t num_many_lut,
|
||||
uint32_t lut_stride) {
|
||||
cudaSetDevice(gpu_index);
|
||||
|
||||
uint64_t full_sm =
|
||||
get_buffer_size_full_sm_cg_multibit_programmable_bootstrap<Torus>(
|
||||
@@ -310,7 +313,6 @@ __host__ void execute_cg_external_product_loop(
|
||||
|
||||
auto lwe_chunk_size = buffer->lwe_chunk_size;
|
||||
int max_shared_memory = cuda_get_max_shared_memory(gpu_index);
|
||||
cudaSetDevice(gpu_index);
|
||||
|
||||
uint32_t keybundle_size_per_input =
|
||||
lwe_chunk_size * level_count * (glwe_dimension + 1) *
|
||||
|
||||
@@ -388,6 +388,8 @@ __host__ void scratch_multi_bit_programmable_bootstrap(
|
||||
uint32_t polynomial_size, uint32_t level_count,
|
||||
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) {
|
||||
|
||||
cudaSetDevice(gpu_index);
|
||||
|
||||
int max_shared_memory = cuda_get_max_shared_memory(gpu_index);
|
||||
uint64_t full_sm_keybundle =
|
||||
get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle<Torus>(
|
||||
@@ -494,6 +496,7 @@ __host__ void execute_compute_keybundle(
|
||||
pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t grouping_factor, uint32_t level_count, uint32_t lwe_offset) {
|
||||
cudaSetDevice(gpu_index);
|
||||
|
||||
auto lwe_chunk_size = buffer->lwe_chunk_size;
|
||||
uint32_t chunk_size =
|
||||
@@ -507,7 +510,6 @@ __host__ void execute_compute_keybundle(
|
||||
get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle<Torus>(
|
||||
polynomial_size);
|
||||
int max_shared_memory = cuda_get_max_shared_memory(gpu_index);
|
||||
cudaSetDevice(gpu_index);
|
||||
|
||||
auto d_mem = buffer->d_mem_keybundle;
|
||||
auto keybundle_fft = buffer->keybundle_fft;
|
||||
@@ -543,6 +545,7 @@ execute_step_one(cudaStream_t stream, uint32_t gpu_index,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t j, uint32_t lwe_offset) {
|
||||
cudaSetDevice(gpu_index);
|
||||
|
||||
uint64_t full_sm_accumulate_step_one =
|
||||
get_buffer_size_full_sm_multibit_programmable_bootstrap_step_one<Torus>(
|
||||
@@ -551,7 +554,6 @@ execute_step_one(cudaStream_t stream, uint32_t gpu_index,
|
||||
get_buffer_size_partial_sm_multibit_programmable_bootstrap_step_one<
|
||||
Torus>(polynomial_size);
|
||||
int max_shared_memory = cuda_get_max_shared_memory(gpu_index);
|
||||
cudaSetDevice(gpu_index);
|
||||
|
||||
//
|
||||
auto d_mem = buffer->d_mem_acc_step_one;
|
||||
@@ -599,13 +601,13 @@ execute_step_two(cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
|
||||
uint32_t polynomial_size, int32_t grouping_factor,
|
||||
uint32_t level_count, uint32_t j, uint32_t lwe_offset,
|
||||
uint32_t num_many_lut, uint32_t lut_stride) {
|
||||
cudaSetDevice(gpu_index);
|
||||
|
||||
auto lwe_chunk_size = buffer->lwe_chunk_size;
|
||||
uint64_t full_sm_accumulate_step_two =
|
||||
get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two<Torus>(
|
||||
polynomial_size);
|
||||
int max_shared_memory = cuda_get_max_shared_memory(gpu_index);
|
||||
cudaSetDevice(gpu_index);
|
||||
|
||||
auto d_mem = buffer->d_mem_acc_step_two;
|
||||
auto keybundle_fft = buffer->keybundle_fft;
|
||||
|
||||
Reference in New Issue
Block a user