mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
chore(gpu): adjust internal tools
This commit is contained in:
@@ -45,7 +45,7 @@ protected:
|
||||
uint64_t *lwe_sk_in_array;
|
||||
uint64_t *lwe_sk_out_array;
|
||||
uint64_t *plaintexts;
|
||||
uint64_t *d_bsk;
|
||||
double2 *d_bsk;
|
||||
uint64_t *d_lut_pbs_identity;
|
||||
uint64_t *d_lut_pbs_indexes;
|
||||
uint64_t *d_lwe_ct_in_array;
|
||||
|
||||
@@ -27,7 +27,7 @@ void programmable_bootstrap_classical_teardown(
|
||||
uint64_t *d_lwe_output_indexes);
|
||||
void programmable_bootstrap_multibit_setup(
|
||||
cuda_stream_t *stream, Seed *seed, uint64_t **lwe_sk_in_array,
|
||||
uint64_t **lwe_sk_out_array, uint64_t **d_bsk_array, uint64_t **plaintexts,
|
||||
uint64_t **lwe_sk_out_array, double2 **d_bsk_array, uint64_t **plaintexts,
|
||||
uint64_t **d_lut_pbs_identity, uint64_t **d_lut_pbs_indexes,
|
||||
uint64_t **d_lwe_ct_in_array, uint64_t **d_lwe_input_indexes,
|
||||
uint64_t **d_lwe_ct_out_array, uint64_t **d_lwe_output_indexes,
|
||||
@@ -39,7 +39,7 @@ void programmable_bootstrap_multibit_setup(
|
||||
int chunk_size = 0);
|
||||
void programmable_bootstrap_multibit_teardown(
|
||||
cuda_stream_t *stream, uint64_t *lwe_sk_in_array,
|
||||
uint64_t *lwe_sk_out_array, uint64_t *d_bsk_array, uint64_t *plaintexts,
|
||||
uint64_t *lwe_sk_out_array, double2 *d_bsk_array, uint64_t *plaintexts,
|
||||
uint64_t *d_lut_pbs_identity, uint64_t *d_lut_pbs_indexes,
|
||||
uint64_t *d_lwe_ct_in_array, uint64_t *d_lwe_input_indexes,
|
||||
uint64_t *d_lwe_ct_out_array, uint64_t *d_lwe_output_indexes);
|
||||
|
||||
@@ -37,7 +37,7 @@ void generate_lwe_programmable_bootstrap_keys(
|
||||
const unsigned repetitions);
|
||||
|
||||
void generate_lwe_multi_bit_programmable_bootstrap_keys(
|
||||
cuda_stream_t *stream, uint64_t **d_bsk_array, uint64_t *lwe_sk_in_array,
|
||||
cuda_stream_t *stream, double2 **d_bsk_array, uint64_t *lwe_sk_in_array,
|
||||
uint64_t *lwe_sk_out_array, int lwe_dimension, int glwe_dimension,
|
||||
int polynomial_size, int pbs_level, int pbs_base_log, int grouping_factor,
|
||||
Seed *seed, DynamicDistribution noise_distribution,
|
||||
|
||||
@@ -130,7 +130,7 @@ void programmable_bootstrap_classical_teardown(
|
||||
|
||||
void programmable_bootstrap_multibit_setup(
|
||||
cuda_stream_t *stream, Seed *seed, uint64_t **lwe_sk_in_array,
|
||||
uint64_t **lwe_sk_out_array, uint64_t **d_bsk_array, uint64_t **plaintexts,
|
||||
uint64_t **lwe_sk_out_array, double2 **d_bsk_array, uint64_t **plaintexts,
|
||||
uint64_t **d_lut_pbs_identity, uint64_t **d_lut_pbs_indexes,
|
||||
uint64_t **d_lwe_ct_in_array, uint64_t **d_lwe_input_indexes,
|
||||
uint64_t **d_lwe_ct_out_array, uint64_t **d_lwe_output_indexes,
|
||||
@@ -236,7 +236,7 @@ void programmable_bootstrap_multibit_setup(
|
||||
|
||||
void programmable_bootstrap_multibit_teardown(
|
||||
cuda_stream_t *stream, uint64_t *lwe_sk_in_array,
|
||||
uint64_t *lwe_sk_out_array, uint64_t *d_bsk_array, uint64_t *plaintexts,
|
||||
uint64_t *lwe_sk_out_array, double2 *d_bsk_array, uint64_t *plaintexts,
|
||||
uint64_t *d_lut_pbs_identity, uint64_t *d_lut_pbs_indexes,
|
||||
uint64_t *d_lwe_ct_in_array, uint64_t *d_lwe_input_indexes,
|
||||
uint64_t *d_lwe_ct_out_array, uint64_t *d_lwe_output_indexes) {
|
||||
|
||||
@@ -44,7 +44,7 @@ protected:
|
||||
uint64_t *lwe_sk_in_array;
|
||||
uint64_t *lwe_sk_out_array;
|
||||
uint64_t *plaintexts;
|
||||
uint64_t *d_bsk_array;
|
||||
double2 *d_bsk_array;
|
||||
uint64_t *d_lut_pbs_identity;
|
||||
uint64_t *d_lut_pbs_indexes;
|
||||
uint64_t *d_lwe_ct_in_array;
|
||||
@@ -120,7 +120,7 @@ TEST_P(MultiBitProgrammableBootstrapTestPrimitives_u64,
|
||||
(1 << grouping_factor);
|
||||
|
||||
for (int r = 0; r < repetitions; r++) {
|
||||
uint64_t *d_bsk = d_bsk_array + (ptrdiff_t)(bsk_size * r);
|
||||
double2 *d_bsk = d_bsk_array + (ptrdiff_t)(bsk_size * r);
|
||||
uint64_t *lwe_sk_out =
|
||||
lwe_sk_out_array + (ptrdiff_t)(r * glwe_dimension * polynomial_size);
|
||||
for (int s = 0; s < samples; s++) {
|
||||
|
||||
@@ -176,7 +176,7 @@ void generate_lwe_programmable_bootstrap_keys(cuda_stream_t *stream,
|
||||
}
|
||||
|
||||
void generate_lwe_multi_bit_programmable_bootstrap_keys(
|
||||
cuda_stream_t *stream, uint64_t **d_bsk_array, uint64_t *lwe_sk_in_array,
|
||||
cuda_stream_t *stream, double2 **d_bsk_array, uint64_t *lwe_sk_in_array,
|
||||
uint64_t *lwe_sk_out_array, int lwe_dimension, int glwe_dimension,
|
||||
int polynomial_size, int grouping_factor, int pbs_level, int pbs_base_log,
|
||||
Seed *seed, DynamicDistribution noise_distribution,
|
||||
@@ -189,7 +189,7 @@ void generate_lwe_multi_bit_programmable_bootstrap_keys(
|
||||
uint64_t *bsk_array = (uint64_t *)malloc(bsk_array_size * sizeof(uint64_t));
|
||||
|
||||
*d_bsk_array =
|
||||
(uint64_t *)cuda_malloc_async(bsk_array_size * sizeof(uint64_t), stream);
|
||||
(double2 *)cuda_malloc_async(bsk_array_size * sizeof(double), stream);
|
||||
for (uint r = 0; r < repetitions; r++) {
|
||||
int shift_in = 0;
|
||||
int shift_out = 0;
|
||||
@@ -199,14 +199,14 @@ void generate_lwe_multi_bit_programmable_bootstrap_keys(
|
||||
lwe_sk_out_array + (ptrdiff_t)(shift_out), glwe_dimension,
|
||||
polynomial_size, bsk_array + (ptrdiff_t)(shift_bsk), pbs_base_log,
|
||||
pbs_level, grouping_factor, noise_distribution, 0, 0);
|
||||
uint64_t *d_bsk = *d_bsk_array + (ptrdiff_t)(shift_bsk);
|
||||
double2 *d_bsk = *d_bsk_array + (ptrdiff_t)(shift_bsk);
|
||||
uint64_t *bsk = bsk_array + (ptrdiff_t)(shift_bsk);
|
||||
cuda_convert_lwe_multi_bit_programmable_bootstrap_key_64(
|
||||
d_bsk, bsk, stream, lwe_dimension, glwe_dimension, pbs_level,
|
||||
polynomial_size, grouping_factor);
|
||||
shift_in += lwe_dimension;
|
||||
shift_out += glwe_dimension * polynomial_size;
|
||||
shift_bsk += bsk_size;
|
||||
shift_bsk += bsk_size / 2;
|
||||
}
|
||||
cuda_synchronize_stream(stream);
|
||||
free(bsk_array);
|
||||
|
||||
Reference in New Issue
Block a user