mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 03:25:05 -05:00
test(concrete_cuda): enhance the keyswitch test
It now tests execution on a vector of inputs. The noise is reduced so the test is not as flaky as before.
This commit is contained in:
@@ -19,6 +19,7 @@ typedef struct {
|
||||
int ksk_level;
|
||||
int message_modulus;
|
||||
int carry_modulus;
|
||||
int number_of_inputs;
|
||||
} KeyswitchTestParams;
|
||||
|
||||
class KeyswitchTestPrimitives_u64
|
||||
@@ -31,6 +32,7 @@ protected:
|
||||
int ksk_level;
|
||||
int message_modulus;
|
||||
int carry_modulus;
|
||||
int number_of_inputs;
|
||||
int payload_modulus;
|
||||
uint64_t delta;
|
||||
Csprng *csprng;
|
||||
@@ -44,7 +46,6 @@ protected:
|
||||
uint64_t *d_lwe_in_ct;
|
||||
uint64_t *lwe_in_ct;
|
||||
uint64_t *lwe_out_ct;
|
||||
int num_samples;
|
||||
|
||||
public:
|
||||
// Test arithmetic functions
|
||||
@@ -55,11 +56,12 @@ public:
|
||||
// TestParams
|
||||
input_lwe_dimension = (int)GetParam().input_lwe_dimension;
|
||||
output_lwe_dimension = (int)GetParam().output_lwe_dimension;
|
||||
noise_variance = (int)GetParam().noise_variance;
|
||||
noise_variance = (double)GetParam().noise_variance;
|
||||
ksk_base_log = (int)GetParam().ksk_base_log;
|
||||
ksk_level = (int)GetParam().ksk_level;
|
||||
message_modulus = (int)GetParam().message_modulus;
|
||||
carry_modulus = (int)GetParam().carry_modulus;
|
||||
number_of_inputs = (int)GetParam().number_of_inputs;
|
||||
|
||||
payload_modulus = message_modulus * carry_modulus;
|
||||
// Value of the shift we multiply our messages by
|
||||
@@ -81,19 +83,21 @@ public:
|
||||
stream, gpu_index, &d_ksk_array, lwe_sk_in_array, lwe_sk_out_array,
|
||||
input_lwe_dimension, output_lwe_dimension, ksk_level, ksk_base_log,
|
||||
csprng, noise_variance, REPETITIONS);
|
||||
plaintexts =
|
||||
generate_plaintexts(payload_modulus, delta, 1, REPETITIONS, SAMPLES);
|
||||
plaintexts = generate_plaintexts(payload_modulus, delta, number_of_inputs,
|
||||
REPETITIONS, SAMPLES);
|
||||
|
||||
d_lwe_out_ct = (uint64_t *)cuda_malloc_async(
|
||||
(output_lwe_dimension + 1) * sizeof(uint64_t), stream, gpu_index);
|
||||
number_of_inputs * (output_lwe_dimension + 1) * sizeof(uint64_t),
|
||||
stream, gpu_index);
|
||||
|
||||
d_lwe_in_ct = (uint64_t *)cuda_malloc_async(
|
||||
(input_lwe_dimension + 1) * sizeof(uint64_t), stream, gpu_index);
|
||||
number_of_inputs * (input_lwe_dimension + 1) * sizeof(uint64_t), stream,
|
||||
gpu_index);
|
||||
|
||||
lwe_in_ct =
|
||||
(uint64_t *)malloc((input_lwe_dimension + 1) * sizeof(uint64_t));
|
||||
lwe_out_ct =
|
||||
(uint64_t *)malloc((output_lwe_dimension + 1) * sizeof(uint64_t));
|
||||
lwe_in_ct = (uint64_t *)malloc(
|
||||
number_of_inputs * (input_lwe_dimension + 1) * sizeof(uint64_t));
|
||||
lwe_out_ct = (uint64_t *)malloc(
|
||||
number_of_inputs * (output_lwe_dimension + 1) * sizeof(uint64_t));
|
||||
|
||||
cuda_synchronize_stream(v_stream);
|
||||
}
|
||||
@@ -119,46 +123,57 @@ public:
|
||||
TEST_P(KeyswitchTestPrimitives_u64, keyswitch) {
|
||||
void *v_stream = (void *)stream;
|
||||
for (uint r = 0; r < REPETITIONS; r++) {
|
||||
uint64_t *lwe_in_sk =
|
||||
lwe_sk_in_array + (ptrdiff_t)(r * input_lwe_dimension);
|
||||
uint64_t *lwe_out_sk =
|
||||
lwe_sk_out_array + (ptrdiff_t)(r * output_lwe_dimension);
|
||||
int ksk_size = ksk_level * (output_lwe_dimension + 1) * input_lwe_dimension;
|
||||
uint64_t *d_ksk = d_ksk_array + (ptrdiff_t)(ksk_size * r);
|
||||
for (uint s = 0; s < SAMPLES; s++) {
|
||||
uint64_t plaintext = plaintexts[r * SAMPLES + s];
|
||||
uint64_t *lwe_in_sk =
|
||||
lwe_sk_in_array + (ptrdiff_t)(r * input_lwe_dimension);
|
||||
uint64_t *lwe_out_sk =
|
||||
lwe_sk_out_array + (ptrdiff_t)(r * output_lwe_dimension);
|
||||
int ksk_size =
|
||||
ksk_level * (output_lwe_dimension + 1) * input_lwe_dimension;
|
||||
uint64_t *d_ksk = d_ksk_array + (ptrdiff_t)(ksk_size * r);
|
||||
concrete_cpu_encrypt_lwe_ciphertext_u64(
|
||||
lwe_in_sk, lwe_in_ct, plaintext, input_lwe_dimension, noise_variance,
|
||||
csprng, &CONCRETE_CSPRNG_VTABLE);
|
||||
for (int i = 0; i < number_of_inputs; i++) {
|
||||
uint64_t plaintext = plaintexts[r * SAMPLES * number_of_inputs +
|
||||
s * number_of_inputs + i];
|
||||
concrete_cpu_encrypt_lwe_ciphertext_u64(
|
||||
lwe_in_sk, lwe_in_ct + i * (input_lwe_dimension + 1), plaintext,
|
||||
input_lwe_dimension, noise_variance, csprng,
|
||||
&CONCRETE_CSPRNG_VTABLE);
|
||||
}
|
||||
cuda_synchronize_stream(v_stream);
|
||||
cuda_memcpy_async_to_gpu(d_lwe_in_ct, lwe_in_ct,
|
||||
(input_lwe_dimension + 1) * sizeof(uint64_t),
|
||||
number_of_inputs * (input_lwe_dimension + 1) *
|
||||
sizeof(uint64_t),
|
||||
stream, gpu_index);
|
||||
// Execute keyswitch
|
||||
cuda_keyswitch_lwe_ciphertext_vector_64(
|
||||
stream, gpu_index, (void *)d_lwe_out_ct, (void *)d_lwe_in_ct,
|
||||
(void *)d_ksk, input_lwe_dimension, output_lwe_dimension,
|
||||
ksk_base_log, ksk_level, 1);
|
||||
ksk_base_log, ksk_level, number_of_inputs);
|
||||
|
||||
// Copy result back
|
||||
cuda_memcpy_async_to_cpu(lwe_out_ct, d_lwe_out_ct,
|
||||
(output_lwe_dimension + 1) * sizeof(uint64_t),
|
||||
number_of_inputs * (output_lwe_dimension + 1) *
|
||||
sizeof(uint64_t),
|
||||
stream, gpu_index);
|
||||
uint64_t decrypted = 0;
|
||||
concrete_cpu_decrypt_lwe_ciphertext_u64(lwe_out_sk, lwe_out_ct,
|
||||
output_lwe_dimension, &decrypted);
|
||||
EXPECT_NE(decrypted, plaintext);
|
||||
// let err = (decrypted >= plaintext) ? decrypted - plaintext : plaintext
|
||||
// - decrypted;
|
||||
// error_sample_vec.push(err);
|
||||
for (int i = 0; i < number_of_inputs; i++) {
|
||||
uint64_t plaintext = plaintexts[r * SAMPLES * number_of_inputs +
|
||||
s * number_of_inputs + i];
|
||||
uint64_t decrypted = 0;
|
||||
concrete_cpu_decrypt_lwe_ciphertext_u64(
|
||||
lwe_out_sk, lwe_out_ct + i * (output_lwe_dimension + 1),
|
||||
output_lwe_dimension, &decrypted);
|
||||
EXPECT_NE(decrypted, plaintext);
|
||||
// let err = (decrypted >= plaintext) ? decrypted - plaintext :
|
||||
// plaintext
|
||||
// - decrypted;
|
||||
// error_sample_vec.push(err);
|
||||
|
||||
// The bit before the message
|
||||
uint64_t rounding_bit = delta >> 1;
|
||||
// Compute the rounding bit
|
||||
uint64_t rounding = (decrypted & rounding_bit) << 1;
|
||||
uint64_t decoded = (decrypted + rounding) / delta;
|
||||
EXPECT_EQ(decoded, plaintext / delta);
|
||||
// The bit before the message
|
||||
uint64_t rounding_bit = delta >> 1;
|
||||
// Compute the rounding bit
|
||||
uint64_t rounding = (decrypted & rounding_bit) << 1;
|
||||
uint64_t decoded = (decrypted + rounding) / delta;
|
||||
EXPECT_EQ(decoded, plaintext / delta);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -168,13 +183,19 @@ TEST_P(KeyswitchTestPrimitives_u64, keyswitch) {
|
||||
::testing::internal::ParamGenerator<KeyswitchTestParams> ksk_params_u64 =
|
||||
::testing::Values(
|
||||
// n, k*N, noise_variance, ks_base_log, ks_level,
|
||||
// message_modulus, carry_modulus
|
||||
(KeyswitchTestParams){567, 1280, 2.9802322387695312e-08, 3, 3, 2, 1},
|
||||
(KeyswitchTestParams){694, 1536, 2.9802322387695312e-08, 4, 3, 2, 1},
|
||||
(KeyswitchTestParams){769, 2048, 2.9802322387695312e-08, 4, 3, 2, 1},
|
||||
(KeyswitchTestParams){754, 2048, 2.9802322387695312e-08, 3, 5, 2, 1},
|
||||
(KeyswitchTestParams){847, 4096, 2.9802322387695312e-08, 4, 4, 2, 1},
|
||||
(KeyswitchTestParams){881, 8192, 2.9802322387695312e-08, 3, 6, 2, 1});
|
||||
// message_modulus, carry_modulus, number_of_inputs
|
||||
(KeyswitchTestParams){567, 1280, 2.9802322387695312e-18, 3, 3, 2, 1,
|
||||
10},
|
||||
(KeyswitchTestParams){694, 1536, 2.9802322387695312e-18, 4, 3, 2, 1,
|
||||
10},
|
||||
(KeyswitchTestParams){769, 2048, 2.9802322387695312e-18, 4, 3, 2, 1,
|
||||
10},
|
||||
(KeyswitchTestParams){754, 2048, 2.9802322387695312e-18, 3, 5, 2, 1,
|
||||
10},
|
||||
(KeyswitchTestParams){847, 4096, 2.9802322387695312e-18, 4, 4, 2, 1,
|
||||
10},
|
||||
(KeyswitchTestParams){881, 8192, 2.9802322387695312e-18, 3, 6, 2, 1,
|
||||
10});
|
||||
|
||||
std::string printParamName(::testing::TestParamInfo<KeyswitchTestParams> p) {
|
||||
KeyswitchTestParams params = p.param;
|
||||
|
||||
Reference in New Issue
Block a user