diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer.h index 1fdcaac26..bc2b742bb 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer.h @@ -395,14 +395,14 @@ uint64_t scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, uint32_t num_blocks_in_radix, uint32_t max_num_radix_in_vec, uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type, - bool allocate_gpu_memory, bool allocate_ms_array); + bool reduce_degrees_for_single_carry_propagation, bool allocate_gpu_memory, + bool allocate_ms_array); void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, CudaRadixCiphertextFFI *radix_lwe_out, - CudaRadixCiphertextFFI *radix_lwe_vec, - bool reduce_degrees_for_single_carry_propagation, int8_t *mem_ptr, - void *const *bsks, void *const *ksks, + CudaRadixCiphertextFFI *radix_lwe_vec, int8_t *mem_ptr, void *const *bsks, + void *const *ksks, CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key); void cleanup_cuda_integer_radix_partial_sum_ciphertexts_vec( diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h index 62d5623a9..875c45078 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h @@ -10,6 +10,7 @@ #include "utils/helper_multi_gpu.cuh" #include #include +#include class NoiseLevel { public: @@ -82,6 +83,165 @@ void generate_many_lut_device_accumulator( uint32_t message_modulus, uint32_t carry_modulus, std::vector> &f, bool gpu_memory_allocated); +template struct radix_columns { + std::vector> columns; + std::vector columns_counter; + std::vector> new_columns; + std::vector new_columns_counter; + + uint32_t num_blocks; + uint32_t num_radix_in_vec; + uint32_t chunk_size; + radix_columns(const uint64_t *const input_degrees, uint32_t num_blocks, + uint32_t num_radix_in_vec, uint32_t chunk_size, + bool &needs_processing) + : num_blocks(num_blocks), num_radix_in_vec(num_radix_in_vec), + chunk_size(chunk_size) { + needs_processing = false; + columns.resize(num_blocks); + columns_counter.resize(num_blocks, 0); + new_columns.resize(num_blocks); + new_columns_counter.resize(num_blocks, 0); + for (uint32_t i = 0; i < num_blocks; ++i) { + new_columns[i].resize(num_radix_in_vec); + } + for (uint32_t i = 0; i < num_radix_in_vec; ++i) { + for (uint32_t j = 0; j < num_blocks; ++j) { + if (input_degrees[i * num_blocks + j]) { + columns[j].push_back(i * num_blocks + j); + columns_counter[j]++; + } + } + } + for (uint32_t i = 0; i < num_blocks; ++i) { + if (columns_counter[i] > chunk_size) { + needs_processing = true; + break; + } + } + } + + void next_accumulation(Torus *h_indexes_in, Torus *h_indexes_out, + Torus *h_lut_indexes, uint32_t &total_ciphertexts, + uint32_t &message_ciphertexts, + bool &needs_processing) { + message_ciphertexts = 0; + total_ciphertexts = 0; + needs_processing = false; + + uint32_t pbs_count = 0; + for (uint32_t c_id = 0; c_id < num_blocks; ++c_id) { + const uint32_t column_len = columns_counter[c_id]; + new_columns_counter[c_id] = 0; + uint32_t ct_count = 0; + // add message cts into new columns + for (uint32_t i = 0; i + chunk_size <= column_len; i += chunk_size) 
{ + const Torus in_index = columns[c_id][i]; + new_columns[c_id][ct_count] = in_index; + if (h_indexes_in != nullptr) + h_indexes_in[pbs_count] = in_index; + if (h_indexes_out != nullptr) + h_indexes_out[pbs_count] = in_index; + if (h_lut_indexes != nullptr) + h_lut_indexes[pbs_count] = 0; + ++pbs_count; + ++ct_count; + ++message_ciphertexts; + } + new_columns_counter[c_id] = ct_count; + } + + for (uint32_t c_id = 0; c_id < num_blocks; ++c_id) { + const uint32_t column_len = columns_counter[c_id]; + uint32_t ct_count = new_columns_counter[c_id]; + // add carry cts into new columns + if (c_id > 0) { + const uint32_t prev_c_id = c_id - 1; + const uint32_t prev_column_len = columns_counter[prev_c_id]; + for (uint32_t i = 0; i + chunk_size <= prev_column_len; + i += chunk_size) { + const Torus in_index = columns[prev_c_id][i]; + const Torus out_index = columns[prev_c_id][i + 1]; + new_columns[c_id][ct_count] = out_index; + if (h_indexes_in != nullptr) + h_indexes_in[pbs_count] = in_index; + if (h_indexes_out != nullptr) + h_indexes_out[pbs_count] = out_index; + if (h_lut_indexes != nullptr) + h_lut_indexes[pbs_count] = 1; + ++pbs_count; + ++ct_count; + } + } + // add remaining cts into new columns + const uint32_t start_index = column_len - column_len % chunk_size; + for (uint32_t i = start_index; i < column_len; ++i) { + new_columns[c_id][ct_count] = columns[c_id][i]; + ++ct_count; + } + new_columns_counter[c_id] = ct_count; + if (ct_count > chunk_size) { + needs_processing = true; + } + } + total_ciphertexts = pbs_count; + swap(columns, new_columns); + swap(columns_counter, new_columns_counter); + } + + void final_calculation(Torus *h_indexes_in, Torus *h_indexes_out, + Torus *h_lut_indexes) { + for (uint32_t idx = 0; idx < 2 * num_blocks; ++idx) { + if (h_indexes_in != nullptr) + h_indexes_in[idx] = idx % num_blocks; + if (h_indexes_out != nullptr) + h_indexes_out[idx] = idx + idx / num_blocks; + if (h_lut_indexes != nullptr) + h_lut_indexes[idx] = idx / num_blocks; + } + } +}; + +inline void calculate_final_degrees(uint64_t *const out_degrees, + const uint64_t *const input_degrees, + uint32_t num_blocks, + uint32_t num_radix_in_vec, + uint32_t chunk_size, + uint64_t message_modulus) { + + auto get_degree = [message_modulus](uint64_t degree) -> uint64_t { + return std::min(message_modulus - 1, degree); + }; + std::vector> columns(num_blocks); + for (uint32_t i = 0; i < num_radix_in_vec; ++i) { + for (uint32_t j = 0; j < num_blocks; ++j) { + if (input_degrees[i * num_blocks + j]) + columns[j].push(input_degrees[i * num_blocks + j]); + } + } + + for (uint32_t i = 0; i < num_blocks; ++i) { + auto &col = columns[i]; + while (col.size() > 1) { + uint32_t cur_degree = 0; + uint32_t mn = std::min(chunk_size, (uint32_t)col.size()); + for (int j = 0; j < mn; ++j) { + cur_degree += col.front(); + col.pop(); + } + const uint64_t new_degree = get_degree(cur_degree); + col.push(new_degree); + if ((i + 1) < num_blocks) { + columns[i + 1].push(new_degree); + } + } + } + + for (int i = 0; i < num_blocks; i++) { + out_degrees[i] = (columns[i].empty()) ? 
0 : columns[i].front(); + } +} + struct int_radix_params { PBS_TYPE pbs_type; uint32_t glwe_dimension; @@ -1325,8 +1485,8 @@ template struct int_sum_ciphertexts_vec_memory { uint32_t num_blocks_in_radix; uint32_t max_num_radix_in_vec; uint32_t chunk_size; - uint64_t *size_tracker; bool gpu_memory_allocated; + bool reduce_degrees_for_single_carry_propagation; // temporary buffers CudaRadixCiphertextFFI *current_blocks; @@ -1349,7 +1509,8 @@ template struct int_sum_ciphertexts_vec_memory { bool allocated_luts_message_carry; void setup_index_buffers(cudaStream_t const *streams, - uint32_t const *gpu_indexes) { + uint32_t const *gpu_indexes, + uint64_t *size_tracker) { d_degrees = (uint64_t *)cuda_malloc_with_size_tracking_async( max_total_blocks_in_vec * sizeof(uint64_t), streams[0], gpu_indexes[0], @@ -1394,18 +1555,29 @@ template struct int_sum_ciphertexts_vec_memory { } void setup_lookup_tables(cudaStream_t const *streams, - uint32_t const *gpu_indexes, uint32_t gpu_count) { + uint32_t const *gpu_indexes, uint32_t gpu_count, + uint32_t num_radix_in_vec, + const uint64_t *const degrees) { uint32_t message_modulus = params.message_modulus; + bool _needs_processing = false; + radix_columns current_columns(degrees, num_blocks_in_radix, + num_radix_in_vec, chunk_size, + _needs_processing); + uint32_t total_ciphertexts = 0; + uint32_t total_messages = 0; + current_columns.next_accumulation(nullptr, nullptr, nullptr, + total_ciphertexts, total_messages, + _needs_processing); if (!mem_reuse) { - uint32_t pbs_count = std::max(2 * (max_total_blocks_in_vec / chunk_size), - 2 * num_blocks_in_radix); - if (max_total_blocks_in_vec > 0) { - luts_message_carry = new int_radix_lut( - streams, gpu_indexes, gpu_count, params, 2, pbs_count, - gpu_memory_allocated, size_tracker); - } else { - allocated_luts_message_carry = false; + uint32_t pbs_count = std::max(total_ciphertexts, 2 * num_blocks_in_radix); + if (total_ciphertexts > 0 || + reduce_degrees_for_single_carry_propagation) { + uint64_t size_tracker = 0; + luts_message_carry = + new int_radix_lut(streams, gpu_indexes, gpu_count, params, 2, + pbs_count, true, &size_tracker); + allocated_luts_message_carry = true; } } if (allocated_luts_message_carry) { @@ -1436,25 +1608,35 @@ template struct int_sum_ciphertexts_vec_memory { luts_message_carry->broadcast_lut(streams, gpu_indexes, 0); } } - int_sum_ciphertexts_vec_memory(cudaStream_t const *streams, - uint32_t const *gpu_indexes, - uint32_t gpu_count, int_radix_params params, - uint32_t num_blocks_in_radix, - uint32_t max_num_radix_in_vec, - bool allocate_gpu_memory, - uint64_t *size_tracker) { + int_sum_ciphertexts_vec_memory( + cudaStream_t const *streams, uint32_t const *gpu_indexes, + uint32_t gpu_count, int_radix_params params, uint32_t num_blocks_in_radix, + uint32_t max_num_radix_in_vec, + bool reduce_degrees_for_single_carry_propagation, + bool allocate_gpu_memory, uint64_t *size_tracker) { this->params = params; this->mem_reuse = false; this->max_total_blocks_in_vec = num_blocks_in_radix * max_num_radix_in_vec; this->num_blocks_in_radix = num_blocks_in_radix; this->max_num_radix_in_vec = max_num_radix_in_vec; this->gpu_memory_allocated = allocate_gpu_memory; - this->size_tracker = size_tracker; this->chunk_size = (params.message_modulus * params.carry_modulus - 1) / (params.message_modulus - 1); - this->allocated_luts_message_carry = true; - setup_index_buffers(streams, gpu_indexes); - setup_lookup_tables(streams, gpu_indexes, gpu_count); + this->allocated_luts_message_carry = false; + 
this->reduce_degrees_for_single_carry_propagation = + reduce_degrees_for_single_carry_propagation; + setup_index_buffers(streams, gpu_indexes, size_tracker); + // because we setup_lut in host function for sum_ciphertexts to save memory + // the size_tracker is topped up here to have a max bound on the used memory + uint32_t max_pbs_count = std::max( + 2 * (max_total_blocks_in_vec / chunk_size), 2 * num_blocks_in_radix); + if (max_pbs_count > 0) { + int_radix_lut *luts_message_carry_dry_run = + new int_radix_lut(streams, gpu_indexes, gpu_count, params, 2, + max_pbs_count, false, size_tracker); + luts_message_carry_dry_run->release(streams, gpu_indexes, gpu_count); + delete luts_message_carry_dry_run; + } // create and allocate intermediate buffers current_blocks = new CudaRadixCiphertextFFI; @@ -1472,23 +1654,25 @@ template struct int_sum_ciphertexts_vec_memory { uint32_t gpu_count, int_radix_params params, uint32_t num_blocks_in_radix, uint32_t max_num_radix_in_vec, CudaRadixCiphertextFFI *current_blocks, CudaRadixCiphertextFFI *small_lwe_vector, - int_radix_lut *reused_lut, bool allocate_gpu_memory, - uint64_t *size_tracker) { + int_radix_lut *reused_lut, + bool reduce_degrees_for_single_carry_propagation, + bool allocate_gpu_memory, uint64_t *size_tracker) { this->mem_reuse = true; this->params = params; this->max_total_blocks_in_vec = num_blocks_in_radix * max_num_radix_in_vec; this->num_blocks_in_radix = num_blocks_in_radix; this->max_num_radix_in_vec = max_num_radix_in_vec; this->gpu_memory_allocated = allocate_gpu_memory; - this->size_tracker = size_tracker; this->chunk_size = (params.message_modulus * params.carry_modulus - 1) / (params.message_modulus - 1); this->allocated_luts_message_carry = true; + this->reduce_degrees_for_single_carry_propagation = + reduce_degrees_for_single_carry_propagation; this->current_blocks = current_blocks; this->small_lwe_vector = small_lwe_vector; this->luts_message_carry = reused_lut; - setup_index_buffers(streams, gpu_indexes); + setup_index_buffers(streams, gpu_indexes, size_tracker); } void release(cudaStream_t const *streams, uint32_t const *gpu_indexes, @@ -1518,7 +1702,6 @@ template struct int_sum_ciphertexts_vec_memory { luts_message_carry->release(streams, gpu_indexes, gpu_count); delete luts_message_carry; } - delete current_blocks; delete small_lwe_vector; } @@ -2898,7 +3081,7 @@ template struct int_mul_memory { sum_ciphertexts_mem = new int_sum_ciphertexts_vec_memory( streams, gpu_indexes, gpu_count, params, num_radix_blocks, 2 * num_radix_blocks, vector_result_sb, small_lwe_vector, luts_array, - allocate_gpu_memory, size_tracker); + true, allocate_gpu_memory, size_tracker); uint32_t uses_carry = 0; uint32_t requested_flag = outputFlag::FLAG_NONE; sc_prop_mem = new int_sc_prop_memory( @@ -4694,10 +4877,11 @@ template struct int_scalar_mul_buffer { //// The idea is that with these we can create all other shift that are /// in / range (0..total_bits) for free (block rotation) preshifted_buffer = new CudaRadixCiphertextFFI; + uint64_t anticipated_drop_mem = 0; create_zero_radix_ciphertext_async( streams[0], gpu_indexes[0], preshifted_buffer, - msg_bits * num_radix_blocks, params.big_lwe_dimension, size_tracker, - allocate_gpu_memory); + msg_bits * num_radix_blocks, params.big_lwe_dimension, + &anticipated_drop_mem, allocate_gpu_memory); all_shifted_buffer = new CudaRadixCiphertextFFI; create_zero_radix_ciphertext_async( @@ -4708,22 +4892,28 @@ template struct int_scalar_mul_buffer { if (num_ciphertext_bits * num_radix_blocks >= 
num_radix_blocks + 2) logical_scalar_shift_buffer = new int_logical_scalar_shift_buffer( streams, gpu_indexes, gpu_count, LEFT_SHIFT, params, num_radix_blocks, - allocate_gpu_memory, all_shifted_buffer, size_tracker); + allocate_gpu_memory, all_shifted_buffer, &anticipated_drop_mem); else logical_scalar_shift_buffer = new int_logical_scalar_shift_buffer( streams, gpu_indexes, gpu_count, LEFT_SHIFT, params, num_radix_blocks, - allocate_gpu_memory, size_tracker); + allocate_gpu_memory, &anticipated_drop_mem); + uint64_t last_step_mem = 0; if (num_ciphertext_bits > 0) { sum_ciphertexts_vec_mem = new int_sum_ciphertexts_vec_memory( streams, gpu_indexes, gpu_count, params, num_radix_blocks, - num_ciphertext_bits, allocate_gpu_memory, size_tracker); + num_ciphertext_bits, true, allocate_gpu_memory, &last_step_mem); } uint32_t uses_carry = 0; uint32_t requested_flag = outputFlag::FLAG_NONE; sc_prop_mem = new int_sc_prop_memory( streams, gpu_indexes, gpu_count, params, num_radix_blocks, - requested_flag, uses_carry, allocate_gpu_memory, size_tracker); + requested_flag, uses_carry, allocate_gpu_memory, &last_step_mem); + if (anticipated_buffer_drop) { + *size_tracker += std::max(anticipated_drop_mem, last_step_mem); + } else { + *size_tracker += anticipated_drop_mem + last_step_mem; + } } void release(cudaStream_t const *streams, uint32_t const *gpu_indexes, diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu index 5a09f2a71..b664ac338 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu @@ -210,7 +210,8 @@ uint64_t scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, uint32_t num_blocks_in_radix, uint32_t max_num_radix_in_vec, uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type, - bool allocate_gpu_memory, bool allocate_ms_array) { + bool reduce_degrees_for_single_carry_propagation, bool allocate_gpu_memory, + bool allocate_ms_array) { int_radix_params params(pbs_type, glwe_dimension, polynomial_size, glwe_dimension * polynomial_size, lwe_dimension, @@ -220,15 +221,15 @@ uint64_t scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( return scratch_cuda_integer_partial_sum_ciphertexts_vec_kb( (cudaStream_t *)(streams), gpu_indexes, gpu_count, (int_sum_ciphertexts_vec_memory **)mem_ptr, num_blocks_in_radix, - max_num_radix_in_vec, params, allocate_gpu_memory); + max_num_radix_in_vec, reduce_degrees_for_single_carry_propagation, params, + allocate_gpu_memory); } void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, CudaRadixCiphertextFFI *radix_lwe_out, - CudaRadixCiphertextFFI *radix_lwe_vec, - bool reduce_degrees_for_single_carry_propagation, int8_t *mem_ptr, - void *const *bsks, void *const *ksks, + CudaRadixCiphertextFFI *radix_lwe_vec, int8_t *mem_ptr, void *const *bsks, + void *const *ksks, CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key) { auto mem = (int_sum_ciphertexts_vec_memory *)mem_ptr; @@ -239,8 +240,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( case 512: host_integer_partial_sum_ciphertexts_vec_kb>( (cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out, - radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks, - (uint64_t **)(ksks), ms_noise_reduction_key, mem, + radix_lwe_vec, 
bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem, radix_lwe_out->num_radix_blocks, radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks); break; @@ -248,8 +248,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( host_integer_partial_sum_ciphertexts_vec_kb>( (cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out, - radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks, - (uint64_t **)(ksks), ms_noise_reduction_key, mem, + radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem, radix_lwe_out->num_radix_blocks, radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks); break; @@ -257,8 +256,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( host_integer_partial_sum_ciphertexts_vec_kb>( (cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out, - radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks, - (uint64_t **)(ksks), ms_noise_reduction_key, mem, + radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem, radix_lwe_out->num_radix_blocks, radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks); break; @@ -266,8 +264,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( host_integer_partial_sum_ciphertexts_vec_kb>( (cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out, - radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks, - (uint64_t **)(ksks), ms_noise_reduction_key, mem, + radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem, radix_lwe_out->num_radix_blocks, radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks); break; @@ -275,8 +272,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( host_integer_partial_sum_ciphertexts_vec_kb>( (cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out, - radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks, - (uint64_t **)(ksks), ms_noise_reduction_key, mem, + radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem, radix_lwe_out->num_radix_blocks, radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks); break; @@ -284,8 +280,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( host_integer_partial_sum_ciphertexts_vec_kb>( (cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out, - radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks, - (uint64_t **)(ksks), ms_noise_reduction_key, mem, + radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem, radix_lwe_out->num_radix_blocks, radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks); break; diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh index 9e0b3ad75..ef3864119 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include @@ -273,164 +272,19 @@ __global__ void fill_radix_from_lsb_msb(Torus *result_blocks, Torus *lsb_blocks, } } -template struct radix_columns { - std::vector> columns; - std::vector columns_counter; - std::vector> new_columns; - std::vector new_columns_counter; - - size_t num_blocks; - size_t num_radix_in_vec; - size_t chunk_size; - radix_columns(const uint64_t *const input_degrees, size_t num_blocks, - size_t num_radix_in_vec, size_t chunk_size, - bool &needs_processing) - : num_blocks(num_blocks), 
num_radix_in_vec(num_radix_in_vec), - chunk_size(chunk_size) { - needs_processing = false; - columns.resize(num_blocks); - columns_counter.resize(num_blocks, 0); - new_columns.resize(num_blocks); - new_columns_counter.resize(num_blocks, 0); - for (size_t i = 0; i < num_blocks; ++i) { - new_columns[i].resize(num_radix_in_vec); - } - for (size_t i = 0; i < num_radix_in_vec; ++i) { - for (size_t j = 0; j < num_blocks; ++j) { - if (input_degrees[i * num_blocks + j]) { - columns[j].push_back(i * num_blocks + j); - columns_counter[j]++; - } - } - } - for (size_t i = 0; i < num_blocks; ++i) { - if (columns_counter[i] > chunk_size) { - needs_processing = true; - break; - } - } - } - - void next_accumulation(Torus *h_indexes_in, Torus *h_indexes_out, - Torus *h_lut_indexes, size_t &total_ciphertexts, - size_t &message_ciphertexts, bool &needs_processing) { - message_ciphertexts = 0; - total_ciphertexts = 0; - needs_processing = false; - - size_t pbs_count = 0; - for (size_t c_id = 0; c_id < num_blocks; ++c_id) { - const size_t column_len = columns_counter[c_id]; - new_columns_counter[c_id] = 0; - size_t ct_count = 0; - // add message cts into new columns - for (size_t i = 0; i + chunk_size <= column_len; i += chunk_size) { - const Torus in_index = columns[c_id][i]; - new_columns[c_id][ct_count] = in_index; - h_indexes_in[pbs_count] = in_index; - h_indexes_out[pbs_count] = in_index; - h_lut_indexes[pbs_count] = 0; - ++pbs_count; - ++ct_count; - ++message_ciphertexts; - } - new_columns_counter[c_id] = ct_count; - } - - for (size_t c_id = 0; c_id < num_blocks; ++c_id) { - const size_t column_len = columns_counter[c_id]; - size_t ct_count = new_columns_counter[c_id]; - // add carry cts into new columns - if (c_id > 0) { - const size_t prev_c_id = c_id - 1; - const size_t prev_column_len = columns_counter[prev_c_id]; - for (size_t i = 0; i + chunk_size <= prev_column_len; i += chunk_size) { - const Torus in_index = columns[prev_c_id][i]; - const Torus out_index = columns[prev_c_id][i + 1]; - new_columns[c_id][ct_count] = out_index; - h_indexes_in[pbs_count] = in_index; - h_indexes_out[pbs_count] = out_index; - h_lut_indexes[pbs_count] = 1; - ++pbs_count; - ++ct_count; - } - } - // add remaining cts into new columns - const size_t start_index = column_len - column_len % chunk_size; - for (size_t i = start_index; i < column_len; ++i) { - new_columns[c_id][ct_count] = columns[c_id][i]; - ++ct_count; - } - new_columns_counter[c_id] = ct_count; - if (ct_count > chunk_size) { - needs_processing = true; - } - } - total_ciphertexts = pbs_count; - swap(columns, new_columns); - swap(columns_counter, new_columns_counter); - } - - void final_calculation(Torus *h_indexes_in, Torus *h_indexes_out, - Torus *h_lut_indexes) { - for (size_t idx = 0; idx < 2 * num_blocks; ++idx) { - h_indexes_in[idx] = idx % num_blocks; - h_indexes_out[idx] = idx + idx / num_blocks; - h_lut_indexes[idx] = idx / num_blocks; - } - } -}; - -inline void calculate_final_degrees(uint64_t *const out_degrees, - const uint64_t *const input_degrees, - size_t num_blocks, size_t num_radix_in_vec, - size_t chunk_size, - uint64_t message_modulus) { - - auto get_degree = [message_modulus](uint64_t degree) -> uint64_t { - return std::min(message_modulus - 1, degree); - }; - std::vector> columns(num_blocks); - for (size_t i = 0; i < num_radix_in_vec; ++i) { - for (size_t j = 0; j < num_blocks; ++j) { - if (input_degrees[i * num_blocks + j]) - columns[j].push(input_degrees[i * num_blocks + j]); - } - } - - for (size_t i = 0; i < num_blocks; ++i) { - auto 
&col = columns[i]; - while (col.size() > 1) { - uint32_t cur_degree = 0; - size_t mn = std::min(chunk_size, col.size()); - for (int j = 0; j < mn; ++j) { - cur_degree += col.front(); - col.pop(); - } - const uint64_t new_degree = get_degree(cur_degree); - col.push(new_degree); - if ((i + 1) < num_blocks) { - columns[i + 1].push(new_degree); - } - } - } - - for (int i = 0; i < num_blocks; i++) { - out_degrees[i] = (columns[i].empty()) ? 0 : columns[i].front(); - } -} - template __host__ uint64_t scratch_cuda_integer_partial_sum_ciphertexts_vec_kb( cudaStream_t const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, int_sum_ciphertexts_vec_memory **mem_ptr, uint32_t num_blocks_in_radix, uint32_t max_num_radix_in_vec, - int_radix_params params, bool allocate_gpu_memory) { + bool reduce_degrees_for_single_carry_propagation, int_radix_params params, + bool allocate_gpu_memory) { uint64_t size_tracker = 0; *mem_ptr = new int_sum_ciphertexts_vec_memory( streams, gpu_indexes, gpu_count, params, num_blocks_in_radix, - max_num_radix_in_vec, allocate_gpu_memory, &size_tracker); + max_num_radix_in_vec, reduce_degrees_for_single_carry_propagation, + allocate_gpu_memory, &size_tracker); return size_tracker; } @@ -438,9 +292,7 @@ template __host__ void host_integer_partial_sum_ciphertexts_vec_kb( cudaStream_t const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, CudaRadixCiphertextFFI *radix_lwe_out, - CudaRadixCiphertextFFI *terms, - bool reduce_degrees_for_single_carry_propagation, void *const *bsks, - uint64_t *const *ksks, + CudaRadixCiphertextFFI *terms, void *const *bsks, uint64_t *const *ksks, CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key, int_sum_ciphertexts_vec_memory *mem_ptr, uint32_t num_radix_blocks, uint32_t num_radix_in_vec) { @@ -465,10 +317,6 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb( auto d_columns_counter = mem_ptr->d_columns_counter; auto d_new_columns = mem_ptr->d_new_columns; auto d_new_columns_counter = mem_ptr->d_new_columns_counter; - auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in; - auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out; - - auto luts_message_carry = mem_ptr->luts_message_carry; auto glwe_dimension = mem_ptr->params.glwe_dimension; auto polynomial_size = mem_ptr->params.polynomial_size; @@ -483,8 +331,9 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb( uint32_t num_many_lut = 1; uint32_t lut_stride = 0; - if (terms->num_radix_blocks == 0) + if (terms->num_radix_blocks == 0) { return; + } if (num_radix_in_vec == 1) { copy_radix_ciphertext_slice_async(streams[0], gpu_indexes[0], radix_lwe_out, 0, num_radix_blocks, @@ -501,10 +350,6 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb( return; } - if (mem_ptr->mem_reuse) { - mem_ptr->setup_lookup_tables(streams, gpu_indexes, gpu_count); - } - if (current_blocks != terms) { copy_radix_ciphertext_async(streams[0], gpu_indexes[0], current_blocks, terms); @@ -523,11 +368,17 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb( radix_columns current_columns(current_blocks->degrees, num_radix_blocks, num_radix_in_vec, chunk_size, needs_processing); - int number_of_threads = min(256, params::degree); + int number_of_threads = std::min(256, params::degree); int part_count = (big_lwe_size + number_of_threads - 1) / number_of_threads; const dim3 number_of_blocks_2d(num_radix_blocks, part_count, 1); + mem_ptr->setup_lookup_tables(streams, gpu_indexes, gpu_count, + num_radix_in_vec, 
current_blocks->degrees); + while (needs_processing) { + auto luts_message_carry = mem_ptr->luts_message_carry; + auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in; + auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out; calculate_chunks <<>>( (Torus *)(current_blocks->ptr), d_columns, d_columns_counter, @@ -538,8 +389,8 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb( d_pbs_indexes_out, luts_message_carry->get_lut_indexes(0, 0), d_columns, d_columns_counter, chunk_size); - size_t total_ciphertexts; - size_t total_messages; + uint32_t total_ciphertexts; + uint32_t total_messages; current_columns.next_accumulation(luts_message_carry->h_lwe_indexes_in, luts_message_carry->h_lwe_indexes_out, luts_message_carry->h_lut_indexes, @@ -570,9 +421,8 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb( luts_message_carry->broadcast_lut(streams, gpu_indexes, 0); integer_radix_apply_univariate_lookup_table_kb( - streams, gpu_indexes, active_gpu_count, current_blocks, - current_blocks, bsks, ksks, ms_noise_reduction_key, - luts_message_carry, total_ciphertexts); + streams, gpu_indexes, gpu_count, current_blocks, current_blocks, bsks, + ksks, ms_noise_reduction_key, luts_message_carry, total_ciphertexts); } cuda_set_device(gpu_indexes[0]); std::swap(d_columns, d_new_columns); @@ -584,7 +434,10 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb( (Torus *)(radix_lwe_out->ptr), (Torus *)(current_blocks->ptr), d_columns, d_columns_counter, chunk_size, big_lwe_size); - if (reduce_degrees_for_single_carry_propagation) { + if (mem_ptr->reduce_degrees_for_single_carry_propagation) { + auto luts_message_carry = mem_ptr->luts_message_carry; + auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in; + auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out; prepare_final_pbs_indexes <<<1, 2 * num_radix_blocks, 0, streams[0]>>>( d_pbs_indexes_in, d_pbs_indexes_out, @@ -772,9 +625,9 @@ __host__ void host_integer_mult_radix_kb( terms_degree_msb[i] = (b_id > r_id) ? 
message_modulus - 2 : 0; } host_integer_partial_sum_ciphertexts_vec_kb( - streams, gpu_indexes, gpu_count, radix_lwe_out, vector_result_sb, true, - bsks, ksks, ms_noise_reduction_key, mem_ptr->sum_ciphertexts_mem, - num_blocks, 2 * num_blocks); + streams, gpu_indexes, gpu_count, radix_lwe_out, vector_result_sb, bsks, + ksks, ms_noise_reduction_key, mem_ptr->sum_ciphertexts_mem, num_blocks, + 2 * num_blocks); auto scp_mem_ptr = mem_ptr->sc_prop_mem; uint32_t requested_flag = outputFlag::FLAG_NONE; diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh index efdef173a..20ec1fa74 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh @@ -117,8 +117,8 @@ __host__ void host_integer_scalar_mul_radix( lwe_array, 0, num_radix_blocks); } else { host_integer_partial_sum_ciphertexts_vec_kb( - streams, gpu_indexes, gpu_count, lwe_array, all_shifted_buffer, true, - bsks, ksks, ms_noise_reduction_key, mem->sum_ciphertexts_vec_mem, + streams, gpu_indexes, gpu_count, lwe_array, all_shifted_buffer, bsks, + ksks, ms_noise_reduction_key, mem->sum_ciphertexts_vec_mem, num_radix_blocks, j); auto scp_mem_ptr = mem->sc_prop_mem; diff --git a/backends/tfhe-cuda-backend/src/bindings.rs b/backends/tfhe-cuda-backend/src/bindings.rs index d0dfaa924..050f5eba8 100644 --- a/backends/tfhe-cuda-backend/src/bindings.rs +++ b/backends/tfhe-cuda-backend/src/bindings.rs @@ -1007,6 +1007,7 @@ unsafe extern "C" { message_modulus: u32, carry_modulus: u32, pbs_type: PBS_TYPE, + reduce_degrees_for_single_carry_propagation: bool, allocate_gpu_memory: bool, allocate_ms_array: bool, ) -> u64; @@ -1018,7 +1019,6 @@ unsafe extern "C" { gpu_count: u32, radix_lwe_out: *mut CudaRadixCiphertextFFI, radix_lwe_vec: *mut CudaRadixCiphertextFFI, - reduce_degrees_for_single_carry_propagation: bool, mem_ptr: *mut i8, bsks: *const *mut ffi::c_void, ksks: *const *mut ffi::c_void, diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index 2e4e21465..3715f470c 100644 --- a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -4729,6 +4729,7 @@ pub unsafe fn unchecked_partial_sum_ciphertexts_integer_radix_kb_assign_async< message_modulus.0 as u32, carry_modulus.0 as u32, pbs_type as u32, + reduce_degrees_for_single_carry_propagation, true, allocate_ms_noise_array, ); @@ -4738,7 +4739,6 @@ pub unsafe fn unchecked_partial_sum_ciphertexts_integer_radix_kb_assign_async< streams.len() as u32, &raw mut cuda_ffi_result, &raw mut cuda_ffi_radix_list, - reduce_degrees_for_single_carry_propagation, mem_ptr, bootstrapping_key.ptr.as_ptr(), keyswitch_key.ptr.as_ptr(),
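
Note (not part of the patch): the helper `calculate_final_degrees` moved into integer_utilities.h above folds each block column in chunks of `chunk_size`, clamping the accumulated degree to `message_modulus - 1` and propagating the clamped value into the next column. The standalone sketch below mirrors that logic so the resulting per-block degree bound can be checked on the CPU; the parameter values (message_modulus = 4, carry_modulus = 4, 3 radix ciphertexts of 2 blocks, all input degrees 3) are illustrative assumptions, not taken from the patch.

// Standalone sketch of the column-folding degree computation from the patch.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <queue>
#include <vector>

int main() {
  const uint64_t message_modulus = 4;
  const uint64_t carry_modulus = 4;
  // Same formula as int_sum_ciphertexts_vec_memory: how many blocks of degree
  // (message_modulus - 1) can be summed before the carry space overflows.
  const uint32_t chunk_size =
      (message_modulus * carry_modulus - 1) / (message_modulus - 1); // = 5 here

  const uint32_t num_blocks = 2;
  const uint32_t num_radix_in_vec = 3;
  std::vector<uint64_t> input_degrees(num_blocks * num_radix_in_vec, 3);

  auto clamp = [&](uint64_t d) { return std::min(message_modulus - 1, d); };

  // One queue of pending degrees per block column, as in calculate_final_degrees.
  std::vector<std::queue<uint64_t>> columns(num_blocks);
  for (uint32_t i = 0; i < num_radix_in_vec; ++i)
    for (uint32_t j = 0; j < num_blocks; ++j)
      if (input_degrees[i * num_blocks + j])
        columns[j].push(input_degrees[i * num_blocks + j]);

  // Repeatedly fold up to chunk_size entries of a column: the clamped degree
  // stays in the column and is also pushed into the next column (the carry),
  // exactly as the patch does.
  for (uint32_t i = 0; i < num_blocks; ++i) {
    auto &col = columns[i];
    while (col.size() > 1) {
      uint64_t acc = 0;
      const uint32_t take = std::min<uint32_t>(chunk_size, (uint32_t)col.size());
      for (uint32_t j = 0; j < take; ++j) {
        acc += col.front();
        col.pop();
      }
      const uint64_t folded = clamp(acc);
      col.push(folded);
      if (i + 1 < num_blocks)
        columns[i + 1].push(folded);
    }
  }

  for (uint32_t i = 0; i < num_blocks; ++i)
    std::cout << "block " << i << " final degree: "
              << (columns[i].empty() ? 0 : columns[i].front()) << "\n";
  return 0;
}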