mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor(cuda): change lut_vector_indexes type to Torus
This commit is contained in:
@@ -180,7 +180,7 @@ void cuda_extract_bits_64(
|
||||
(uint64_t *)lwe_array_in_shifted_buffer,
|
||||
(uint64_t *)lwe_array_out_ks_buffer,
|
||||
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
|
||||
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
|
||||
glwe_dimension, base_log_bsk, level_count_bsk, base_log_ksk,
|
||||
level_count_ksk, number_of_samples, max_shared_memory);
|
||||
@@ -192,7 +192,7 @@ void cuda_extract_bits_64(
|
||||
(uint64_t *)lwe_array_in_shifted_buffer,
|
||||
(uint64_t *)lwe_array_out_ks_buffer,
|
||||
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
|
||||
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
|
||||
glwe_dimension, base_log_bsk, level_count_bsk, base_log_ksk,
|
||||
level_count_ksk, number_of_samples, max_shared_memory);
|
||||
@@ -204,7 +204,7 @@ void cuda_extract_bits_64(
|
||||
(uint64_t *)lwe_array_in_shifted_buffer,
|
||||
(uint64_t *)lwe_array_out_ks_buffer,
|
||||
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
|
||||
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
|
||||
glwe_dimension, base_log_bsk, level_count_bsk, base_log_ksk,
|
||||
level_count_ksk, number_of_samples, max_shared_memory);
|
||||
@@ -216,7 +216,7 @@ void cuda_extract_bits_64(
|
||||
(uint64_t *)lwe_array_in_shifted_buffer,
|
||||
(uint64_t *)lwe_array_out_ks_buffer,
|
||||
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
|
||||
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
|
||||
glwe_dimension, base_log_bsk, level_count_bsk, base_log_ksk,
|
||||
level_count_ksk, number_of_samples, max_shared_memory);
|
||||
@@ -228,7 +228,7 @@ void cuda_extract_bits_64(
|
||||
(uint64_t *)lwe_array_in_shifted_buffer,
|
||||
(uint64_t *)lwe_array_out_ks_buffer,
|
||||
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
|
||||
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
|
||||
glwe_dimension, base_log_bsk, level_count_bsk, base_log_ksk,
|
||||
level_count_ksk, number_of_samples, max_shared_memory);
|
||||
|
||||
@@ -138,7 +138,7 @@ __host__ void host_extract_bits(
|
||||
Torus *lwe_array_in, Torus *lwe_array_in_buffer,
|
||||
Torus *lwe_array_in_shifted_buffer, Torus *lwe_array_out_ks_buffer,
|
||||
Torus *lwe_array_out_pbs_buffer, Torus *lut_pbs,
|
||||
uint32_t *lut_vector_indexes, Torus *ksk, double2 *fourier_bsk,
|
||||
Torus *lut_vector_indexes, Torus *ksk, double2 *fourier_bsk,
|
||||
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
|
||||
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
|
||||
@@ -149,7 +149,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
|
||||
case 512:
|
||||
host_bootstrap_amortized<uint64_t, Degree<512>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, glwe_dimension, lwe_dimension,
|
||||
polynomial_size, base_log, level_count, num_samples, num_lut_vectors,
|
||||
lwe_idx, max_shared_memory);
|
||||
@@ -157,7 +157,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
|
||||
case 1024:
|
||||
host_bootstrap_amortized<uint64_t, Degree<1024>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, glwe_dimension, lwe_dimension,
|
||||
polynomial_size, base_log, level_count, num_samples, num_lut_vectors,
|
||||
lwe_idx, max_shared_memory);
|
||||
@@ -165,7 +165,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
|
||||
case 2048:
|
||||
host_bootstrap_amortized<uint64_t, Degree<2048>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, glwe_dimension, lwe_dimension,
|
||||
polynomial_size, base_log, level_count, num_samples, num_lut_vectors,
|
||||
lwe_idx, max_shared_memory);
|
||||
@@ -173,7 +173,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
|
||||
case 4096:
|
||||
host_bootstrap_amortized<uint64_t, Degree<4096>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, glwe_dimension, lwe_dimension,
|
||||
polynomial_size, base_log, level_count, num_samples, num_lut_vectors,
|
||||
lwe_idx, max_shared_memory);
|
||||
@@ -181,7 +181,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
|
||||
case 8192:
|
||||
host_bootstrap_amortized<uint64_t, Degree<8192>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, glwe_dimension, lwe_dimension,
|
||||
polynomial_size, base_log, level_count, num_samples, num_lut_vectors,
|
||||
lwe_idx, max_shared_memory);
|
||||
|
||||
@@ -51,7 +51,7 @@ template <typename Torus, class params, sharedMemDegree SMD>
|
||||
* is not FULLSM
|
||||
*/
|
||||
__global__ void device_bootstrap_amortized(
|
||||
Torus *lwe_array_out, Torus *lut_vector, uint32_t *lut_vector_indexes,
|
||||
Torus *lwe_array_out, Torus *lut_vector, Torus *lut_vector_indexes,
|
||||
Torus *lwe_array_in, double2 *bootstrapping_key, char *device_mem,
|
||||
uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t lwe_idx,
|
||||
@@ -216,9 +216,9 @@ __global__ void device_bootstrap_amortized(
|
||||
template <typename Torus, class params>
|
||||
__host__ void host_bootstrap_amortized(
|
||||
void *v_stream, uint32_t gpu_index, Torus *lwe_array_out, Torus *lut_vector,
|
||||
uint32_t *lut_vector_indexes, Torus *lwe_array_in,
|
||||
double2 *bootstrapping_key, uint32_t glwe_dimension, uint32_t lwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
|
||||
Torus *lut_vector_indexes, Torus *lwe_array_in, double2 *bootstrapping_key,
|
||||
uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count,
|
||||
uint32_t input_lwe_ciphertext_count, uint32_t num_lut_vectors,
|
||||
uint32_t lwe_idx, uint32_t max_shared_memory) {
|
||||
|
||||
|
||||
@@ -183,7 +183,7 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
|
||||
case 512:
|
||||
host_bootstrap_low_latency<uint64_t, Degree<512>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, glwe_dimension, lwe_dimension,
|
||||
polynomial_size, base_log, level_count, num_samples, num_lut_vectors,
|
||||
max_shared_memory);
|
||||
@@ -191,7 +191,7 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
|
||||
case 1024:
|
||||
host_bootstrap_low_latency<uint64_t, Degree<1024>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, glwe_dimension, lwe_dimension,
|
||||
polynomial_size, base_log, level_count, num_samples, num_lut_vectors,
|
||||
max_shared_memory);
|
||||
@@ -199,7 +199,7 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
|
||||
case 2048:
|
||||
host_bootstrap_low_latency<uint64_t, Degree<2048>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, glwe_dimension, lwe_dimension,
|
||||
polynomial_size, base_log, level_count, num_samples, num_lut_vectors,
|
||||
max_shared_memory);
|
||||
@@ -207,7 +207,7 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
|
||||
case 4096:
|
||||
host_bootstrap_low_latency<uint64_t, Degree<4096>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, glwe_dimension, lwe_dimension,
|
||||
polynomial_size, base_log, level_count, num_samples, num_lut_vectors,
|
||||
max_shared_memory);
|
||||
@@ -215,7 +215,7 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
|
||||
case 8192:
|
||||
host_bootstrap_low_latency<uint64_t, Degree<8192>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, glwe_dimension, lwe_dimension,
|
||||
polynomial_size, base_log, level_count, num_samples, num_lut_vectors,
|
||||
max_shared_memory);
|
||||
|
||||
@@ -248,7 +248,7 @@ __global__ void device_bootstrap_low_latency(
|
||||
template <typename Torus, class params>
|
||||
__host__ void host_bootstrap_low_latency(
|
||||
void *v_stream, uint32_t gpu_index, Torus *lwe_array_out, Torus *lut_vector,
|
||||
uint32_t *lut_vector_indexes, Torus *lwe_array_in,
|
||||
Torus *lut_vector_indexes, Torus *lwe_array_in,
|
||||
double2 *bootstrapping_key, uint32_t glwe_dimension, uint32_t lwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
|
||||
uint32_t input_lwe_ciphertext_count, uint32_t num_lut_vectors,
|
||||
|
||||
@@ -158,7 +158,7 @@ void cuda_circuit_bootstrap_64(
|
||||
v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
|
||||
(double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
|
||||
(uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
|
||||
(uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
|
||||
glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
|
||||
base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
|
||||
@@ -169,7 +169,7 @@ void cuda_circuit_bootstrap_64(
|
||||
v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
|
||||
(double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
|
||||
(uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
|
||||
(uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
|
||||
glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
|
||||
base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
|
||||
@@ -180,7 +180,7 @@ void cuda_circuit_bootstrap_64(
|
||||
v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
|
||||
(double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
|
||||
(uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
|
||||
(uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
|
||||
glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
|
||||
base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
|
||||
@@ -191,7 +191,7 @@ void cuda_circuit_bootstrap_64(
|
||||
v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
|
||||
(double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
|
||||
(uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
|
||||
(uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
|
||||
glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
|
||||
base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
|
||||
@@ -202,7 +202,7 @@ void cuda_circuit_bootstrap_64(
|
||||
v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
|
||||
(double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
|
||||
(uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
|
||||
(uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
|
||||
(uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
|
||||
glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
|
||||
base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
|
||||
|
||||
@@ -106,7 +106,7 @@ __host__ void host_circuit_bootstrap(
|
||||
void *v_stream, uint32_t gpu_index, Torus *ggsw_out, Torus *lwe_array_in,
|
||||
double2 *fourier_bsk, Torus *fp_ksk_array,
|
||||
Torus *lwe_array_in_shifted_buffer, Torus *lut_vector,
|
||||
uint32_t *lut_vector_indexes, Torus *lwe_array_out_pbs_buffer,
|
||||
Torus *lut_vector_indexes, Torus *lwe_array_out_pbs_buffer,
|
||||
Torus *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
|
||||
uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
|
||||
uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
|
||||
|
||||
@@ -67,16 +67,16 @@ __host__ void host_circuit_bootstrap_vertical_packing(
|
||||
level_count_cbs * (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
|
||||
stream, gpu_index);
|
||||
// indexes of lut vectors for cbs
|
||||
uint32_t *h_lut_vector_indexes =
|
||||
(uint32_t *)malloc(number_of_inputs * level_count_cbs * sizeof(uint32_t));
|
||||
Torus *h_lut_vector_indexes =
|
||||
(Torus *)malloc(number_of_inputs * level_count_cbs * sizeof(Torus));
|
||||
for (uint index = 0; index < level_count_cbs * number_of_inputs; index++) {
|
||||
h_lut_vector_indexes[index] = index % level_count_cbs;
|
||||
}
|
||||
uint32_t *lut_vector_indexes = (uint32_t *)cuda_malloc_async(
|
||||
number_of_inputs * level_count_cbs * sizeof(uint32_t), stream, gpu_index);
|
||||
Torus *lut_vector_indexes = (Torus *)cuda_malloc_async(
|
||||
number_of_inputs * level_count_cbs * sizeof(Torus), stream, gpu_index);
|
||||
cuda_memcpy_async_to_gpu(
|
||||
lut_vector_indexes, h_lut_vector_indexes,
|
||||
number_of_inputs * level_count_cbs * sizeof(uint32_t), stream, gpu_index);
|
||||
number_of_inputs * level_count_cbs * sizeof(Torus), stream, gpu_index);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
uint32_t bits = sizeof(Torus) * 8;
|
||||
@@ -145,12 +145,12 @@ __host__ void host_wop_pbs(
|
||||
|
||||
// let mut h_lut_vector_indexes = vec![0 as u32; 1];
|
||||
// indexes of lut vectors for bit extract
|
||||
uint32_t *h_lut_vector_indexes = (uint32_t *)malloc(sizeof(uint32_t));
|
||||
Torus *h_lut_vector_indexes = (Torus *)malloc(sizeof(Torus));
|
||||
h_lut_vector_indexes[0] = 0;
|
||||
uint32_t *lut_vector_indexes =
|
||||
(uint32_t *)cuda_malloc_async(sizeof(uint32_t), stream, gpu_index);
|
||||
Torus *lut_vector_indexes =
|
||||
(Torus *)cuda_malloc_async(sizeof(Torus), stream, gpu_index);
|
||||
cuda_memcpy_async_to_gpu(lut_vector_indexes, h_lut_vector_indexes,
|
||||
sizeof(uint32_t), stream, gpu_index);
|
||||
sizeof(Torus), stream, gpu_index);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
Torus *lut_pbs = (Torus *)cuda_malloc_async(
|
||||
(2 * polynomial_size) * sizeof(Torus), stream, gpu_index);
|
||||
|
||||
Reference in New Issue
Block a user