diff --git a/backends/concrete-cuda/implementation/test/test_keyswitch.cpp b/backends/concrete-cuda/implementation/test/test_keyswitch.cpp index b346eae9c..f349d9ea4 100644 --- a/backends/concrete-cuda/implementation/test/test_keyswitch.cpp +++ b/backends/concrete-cuda/implementation/test/test_keyswitch.cpp @@ -19,6 +19,7 @@ typedef struct { int ksk_level; int message_modulus; int carry_modulus; + int number_of_inputs; } KeyswitchTestParams; class KeyswitchTestPrimitives_u64 @@ -31,6 +32,7 @@ protected: int ksk_level; int message_modulus; int carry_modulus; + int number_of_inputs; int payload_modulus; uint64_t delta; Csprng *csprng; @@ -44,7 +46,6 @@ protected: uint64_t *d_lwe_in_ct; uint64_t *lwe_in_ct; uint64_t *lwe_out_ct; - int num_samples; public: // Test arithmetic functions @@ -55,11 +56,12 @@ public: // TestParams input_lwe_dimension = (int)GetParam().input_lwe_dimension; output_lwe_dimension = (int)GetParam().output_lwe_dimension; - noise_variance = (int)GetParam().noise_variance; + noise_variance = (double)GetParam().noise_variance; ksk_base_log = (int)GetParam().ksk_base_log; ksk_level = (int)GetParam().ksk_level; message_modulus = (int)GetParam().message_modulus; carry_modulus = (int)GetParam().carry_modulus; + number_of_inputs = (int)GetParam().number_of_inputs; payload_modulus = message_modulus * carry_modulus; // Value of the shift we multiply our messages by @@ -81,19 +83,21 @@ public: stream, gpu_index, &d_ksk_array, lwe_sk_in_array, lwe_sk_out_array, input_lwe_dimension, output_lwe_dimension, ksk_level, ksk_base_log, csprng, noise_variance, REPETITIONS); - plaintexts = - generate_plaintexts(payload_modulus, delta, 1, REPETITIONS, SAMPLES); + plaintexts = generate_plaintexts(payload_modulus, delta, number_of_inputs, + REPETITIONS, SAMPLES); d_lwe_out_ct = (uint64_t *)cuda_malloc_async( - (output_lwe_dimension + 1) * sizeof(uint64_t), stream, gpu_index); + number_of_inputs * (output_lwe_dimension + 1) * sizeof(uint64_t), + stream, gpu_index); d_lwe_in_ct = (uint64_t *)cuda_malloc_async( - (input_lwe_dimension + 1) * sizeof(uint64_t), stream, gpu_index); + number_of_inputs * (input_lwe_dimension + 1) * sizeof(uint64_t), stream, + gpu_index); - lwe_in_ct = - (uint64_t *)malloc((input_lwe_dimension + 1) * sizeof(uint64_t)); - lwe_out_ct = - (uint64_t *)malloc((output_lwe_dimension + 1) * sizeof(uint64_t)); + lwe_in_ct = (uint64_t *)malloc( + number_of_inputs * (input_lwe_dimension + 1) * sizeof(uint64_t)); + lwe_out_ct = (uint64_t *)malloc( + number_of_inputs * (output_lwe_dimension + 1) * sizeof(uint64_t)); cuda_synchronize_stream(v_stream); } @@ -119,46 +123,57 @@ public: TEST_P(KeyswitchTestPrimitives_u64, keyswitch) { void *v_stream = (void *)stream; for (uint r = 0; r < REPETITIONS; r++) { + uint64_t *lwe_in_sk = + lwe_sk_in_array + (ptrdiff_t)(r * input_lwe_dimension); + uint64_t *lwe_out_sk = + lwe_sk_out_array + (ptrdiff_t)(r * output_lwe_dimension); + int ksk_size = ksk_level * (output_lwe_dimension + 1) * input_lwe_dimension; + uint64_t *d_ksk = d_ksk_array + (ptrdiff_t)(ksk_size * r); for (uint s = 0; s < SAMPLES; s++) { - uint64_t plaintext = plaintexts[r * SAMPLES + s]; - uint64_t *lwe_in_sk = - lwe_sk_in_array + (ptrdiff_t)(r * input_lwe_dimension); - uint64_t *lwe_out_sk = - lwe_sk_out_array + (ptrdiff_t)(r * output_lwe_dimension); - int ksk_size = - ksk_level * (output_lwe_dimension + 1) * input_lwe_dimension; - uint64_t *d_ksk = d_ksk_array + (ptrdiff_t)(ksk_size * r); - concrete_cpu_encrypt_lwe_ciphertext_u64( - lwe_in_sk, lwe_in_ct, plaintext, input_lwe_dimension, noise_variance, - csprng, &CONCRETE_CSPRNG_VTABLE); + for (int i = 0; i < number_of_inputs; i++) { + uint64_t plaintext = plaintexts[r * SAMPLES * number_of_inputs + + s * number_of_inputs + i]; + concrete_cpu_encrypt_lwe_ciphertext_u64( + lwe_in_sk, lwe_in_ct + i * (input_lwe_dimension + 1), plaintext, + input_lwe_dimension, noise_variance, csprng, + &CONCRETE_CSPRNG_VTABLE); + } cuda_synchronize_stream(v_stream); cuda_memcpy_async_to_gpu(d_lwe_in_ct, lwe_in_ct, - (input_lwe_dimension + 1) * sizeof(uint64_t), + number_of_inputs * (input_lwe_dimension + 1) * + sizeof(uint64_t), stream, gpu_index); // Execute keyswitch cuda_keyswitch_lwe_ciphertext_vector_64( stream, gpu_index, (void *)d_lwe_out_ct, (void *)d_lwe_in_ct, (void *)d_ksk, input_lwe_dimension, output_lwe_dimension, - ksk_base_log, ksk_level, 1); + ksk_base_log, ksk_level, number_of_inputs); // Copy result back cuda_memcpy_async_to_cpu(lwe_out_ct, d_lwe_out_ct, - (output_lwe_dimension + 1) * sizeof(uint64_t), + number_of_inputs * (output_lwe_dimension + 1) * + sizeof(uint64_t), stream, gpu_index); - uint64_t decrypted = 0; - concrete_cpu_decrypt_lwe_ciphertext_u64(lwe_out_sk, lwe_out_ct, - output_lwe_dimension, &decrypted); - EXPECT_NE(decrypted, plaintext); - // let err = (decrypted >= plaintext) ? decrypted - plaintext : plaintext - // - decrypted; - // error_sample_vec.push(err); + for (int i = 0; i < number_of_inputs; i++) { + uint64_t plaintext = plaintexts[r * SAMPLES * number_of_inputs + + s * number_of_inputs + i]; + uint64_t decrypted = 0; + concrete_cpu_decrypt_lwe_ciphertext_u64( + lwe_out_sk, lwe_out_ct + i * (output_lwe_dimension + 1), + output_lwe_dimension, &decrypted); + EXPECT_NE(decrypted, plaintext); + // let err = (decrypted >= plaintext) ? decrypted - plaintext : + // plaintext + // - decrypted; + // error_sample_vec.push(err); - // The bit before the message - uint64_t rounding_bit = delta >> 1; - // Compute the rounding bit - uint64_t rounding = (decrypted & rounding_bit) << 1; - uint64_t decoded = (decrypted + rounding) / delta; - EXPECT_EQ(decoded, plaintext / delta); + // The bit before the message + uint64_t rounding_bit = delta >> 1; + // Compute the rounding bit + uint64_t rounding = (decrypted & rounding_bit) << 1; + uint64_t decoded = (decrypted + rounding) / delta; + EXPECT_EQ(decoded, plaintext / delta); + } } } } @@ -168,13 +183,19 @@ TEST_P(KeyswitchTestPrimitives_u64, keyswitch) { ::testing::internal::ParamGenerator ksk_params_u64 = ::testing::Values( // n, k*N, noise_variance, ks_base_log, ks_level, - // message_modulus, carry_modulus - (KeyswitchTestParams){567, 1280, 2.9802322387695312e-08, 3, 3, 2, 1}, - (KeyswitchTestParams){694, 1536, 2.9802322387695312e-08, 4, 3, 2, 1}, - (KeyswitchTestParams){769, 2048, 2.9802322387695312e-08, 4, 3, 2, 1}, - (KeyswitchTestParams){754, 2048, 2.9802322387695312e-08, 3, 5, 2, 1}, - (KeyswitchTestParams){847, 4096, 2.9802322387695312e-08, 4, 4, 2, 1}, - (KeyswitchTestParams){881, 8192, 2.9802322387695312e-08, 3, 6, 2, 1}); + // message_modulus, carry_modulus, number_of_inputs + (KeyswitchTestParams){567, 1280, 2.9802322387695312e-18, 3, 3, 2, 1, + 10}, + (KeyswitchTestParams){694, 1536, 2.9802322387695312e-18, 4, 3, 2, 1, + 10}, + (KeyswitchTestParams){769, 2048, 2.9802322387695312e-18, 4, 3, 2, 1, + 10}, + (KeyswitchTestParams){754, 2048, 2.9802322387695312e-18, 3, 5, 2, 1, + 10}, + (KeyswitchTestParams){847, 4096, 2.9802322387695312e-18, 4, 4, 2, 1, + 10}, + (KeyswitchTestParams){881, 8192, 2.9802322387695312e-18, 3, 6, 2, 1, + 10}); std::string printParamName(::testing::TestParamInfo p) { KeyswitchTestParams params = p.param;