fix(gpu): return to 64 regs in multi-bit pbs
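
Scope of the change, as reconstructed from the hunks below: in the CUDA multi-bit PBS implementation, the lwe_chunk_size kernel parameter and the derived chunk_size are narrowed from uint64_t to uint32_t in both the CG and TBC paths, with the host side clamping via std::min before casting; the TBC 2_2 accumulate kernel additionally loses its double2 *join_buffer parameter, and its launch site is updated to match. Judging by the title, the goal appears to be reducing per-thread register pressure so these kernels compile back to 64 registers.

Why narrowing helps, as a rough illustration rather than the actual kernels: on NVIDIA GPUs a 64-bit integer occupies a pair of 32-bit registers, and 64-bit index arithmetic expands into longer instruction sequences, so 32-bit loop bounds and indices are a standard way to shave registers. A minimal sketch with hypothetical kernels wide and narrow:

#include <cstdint>

// Hypothetical CUDA kernels, not from tfhe-rs: the 64-bit counter
// needs a register pair, the 32-bit counter a single register.
__global__ void __launch_bounds__(256) wide(uint64_t n, float *out) {
  for (uint64_t i = threadIdx.x; i < n; i += blockDim.x)
    out[i] += 1.0f; // i and the bound n each occupy two registers
}

__global__ void __launch_bounds__(256) narrow(uint32_t n, float *out) {
  for (uint32_t i = threadIdx.x; i < n; i += blockDim.x)
    out[i] += 1.0f; // i and n each fit in one register
}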
@@ -30,7 +30,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     Torus *global_accumulator, uint32_t lwe_dimension,
     uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
     uint32_t level_count, uint32_t grouping_factor, uint32_t lwe_offset,
-    uint64_t lwe_chunk_size, uint64_t keybundle_size_per_input,
+    uint32_t lwe_chunk_size, uint64_t keybundle_size_per_input,
     int8_t *device_mem, uint64_t device_memory_size_per_block,
     uint32_t num_many_lut, uint32_t lut_stride) {
@@ -321,8 +321,9 @@ __host__ void execute_cg_external_product_loop(
       lwe_chunk_size * level_count * (glwe_dimension + 1) *
       (glwe_dimension + 1) * (polynomial_size / 2);

-  uint64_t chunk_size = std::min(
-      lwe_chunk_size, (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset);
+  uint32_t chunk_size = (uint32_t)(std::min(
+      lwe_chunk_size,
+      (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset));

   auto d_mem = buffer->d_mem_acc_cg;
   auto keybundle_fft = buffer->keybundle_fft;
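
The narrowing cast in this hunk cannot truncate: lwe_dimension is a uint32_t (see the signature hunk above), so (lwe_dimension / grouping_factor) - lwe_offset fits in 32 bits, and the std::min only pulls lwe_chunk_size down toward that bound. A self-contained sketch of the same clamp; clamp_chunk is a hypothetical helper, not a function in the codebase:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Clamp a 64-bit chunk size to the 32-bit number of remaining groups,
// then narrow: the min is bounded by a 32-bit-derived value.
uint32_t clamp_chunk(uint64_t lwe_chunk_size, uint32_t lwe_dimension,
                     uint32_t grouping_factor, uint32_t lwe_offset) {
  uint64_t remaining =
      (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset;
  uint64_t chunk = std::min(lwe_chunk_size, remaining);
  assert(chunk <= UINT32_MAX); // holds whenever lwe_offset <= remaining groups
  return (uint32_t)chunk;
}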
@@ -373,7 +373,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
     const double2 *__restrict__ keybundle_array, Torus *global_accumulator,
     double2 *join_buffer, uint32_t glwe_dimension, uint32_t polynomial_size,
-    uint32_t level_count, uint32_t iteration, uint64_t lwe_chunk_size,
+    uint32_t level_count, uint32_t iteration, uint32_t lwe_chunk_size,
     int8_t *device_mem, uint64_t device_memory_size_per_block,
     uint32_t num_many_lut, uint32_t lut_stride) {
   // We use shared memory for the polynomials that are used often during the
@@ -790,7 +790,7 @@ execute_step_two(cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
     uint32_t lut_stride) {
   cuda_set_device(gpu_index);

-  auto lwe_chunk_size = buffer->lwe_chunk_size;
+  uint32_t lwe_chunk_size = (uint32_t)(buffer->lwe_chunk_size);
   uint64_t full_sm_accumulate_step_two =
       get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two<Torus>(
           polynomial_size);
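
The auto binding above deduced the type of buffer->lwe_chunk_size (presumably still 64-bit on the buffer struct) and forwarded it to kernels that now expect uint32_t; spelling the type out with an explicit cast keeps the narrowing visible at the launch boundary.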
@@ -30,7 +30,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     Torus *global_accumulator, uint32_t lwe_dimension,
     uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
     uint32_t level_count, uint32_t grouping_factor, uint32_t lwe_offset,
-    uint64_t lwe_chunk_size, uint64_t keybundle_size_per_input,
+    uint32_t lwe_chunk_size, uint64_t keybundle_size_per_input,
     int8_t *device_mem, uint64_t device_memory_size_per_block,
     bool support_dsm, uint32_t num_many_lut, uint32_t lut_stride) {
@@ -205,10 +205,10 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     const Torus *__restrict__ lut_vector_indexes,
     const Torus *__restrict__ lwe_array_in,
     const Torus *__restrict__ lwe_input_indexes,
-    const double2 *__restrict__ keybundle_array, double2 *join_buffer,
-    Torus *global_accumulator, uint32_t lwe_dimension, uint32_t lwe_offset,
-    uint64_t lwe_chunk_size, uint64_t keybundle_size_per_input,
-    uint32_t num_many_lut, uint32_t lut_stride) {
+    const double2 *__restrict__ keybundle_array, Torus *global_accumulator,
+    uint32_t lwe_dimension, uint32_t lwe_offset, uint32_t lwe_chunk_size,
+    uint64_t keybundle_size_per_input, uint32_t num_many_lut,
+    uint32_t lut_stride) {

   constexpr uint32_t level_count = 1;
   constexpr uint32_t grouping_factor = 4;
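
Besides the type change, this hunk drops the double2 *join_buffer parameter from the 2_2 accumulate kernel (which, per the trailing context, hard-codes level_count = 1 and grouping_factor = 4) and reorders the remaining arguments; the cudaLaunchKernelEx call site in the final hunk removes the matching buffer_fft argument.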
@@ -548,8 +548,9 @@ __host__ void execute_tbc_external_product_loop(
       lwe_chunk_size * level_count * (glwe_dimension + 1) *
      (glwe_dimension + 1) * (polynomial_size / 2);

-  uint64_t chunk_size = std::min(
-      lwe_chunk_size, (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset);
+  uint32_t chunk_size = (uint32_t)(std::min(
+      lwe_chunk_size,
+      (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset));

   auto d_mem = buffer->d_mem_acc_tbc;
   auto keybundle_fft = buffer->keybundle_fft;
@@ -624,9 +625,9 @@ __host__ void execute_tbc_external_product_loop(
           device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
               Torus, params, FULLSM>,
           lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
-          lwe_array_in, lwe_input_indexes, keybundle_fft, buffer_fft,
-          global_accumulator, lwe_dimension, lwe_offset, chunk_size,
-          keybundle_size_per_input, num_many_lut, lut_stride));
+          lwe_array_in, lwe_input_indexes, keybundle_fft, global_accumulator,
+          lwe_dimension, lwe_offset, chunk_size, keybundle_size_per_input,
+          num_many_lut, lut_stride));
   } else {
     check_cuda_error(cudaLaunchKernelEx(
         &config,