From 4ff95e3a424028046f6cc135cc792c9f30826ebc Mon Sep 17 00:00:00 2001
From: Enzo Di Maria
Date: Mon, 3 Nov 2025 15:52:43 +0100
Subject: [PATCH] feat(gpu): AES 256

---
 .../tfhe-cuda-backend/cuda/include/aes/aes.h  |  23 ++
 .../cuda/include/aes/aes_utilities.h          |  63 +++
 .../tfhe-cuda-backend/cuda/src/aes/aes256.cu  |  55 +++
 .../tfhe-cuda-backend/cuda/src/aes/aes256.cuh | 355 +++++++++++++++++
 backends/tfhe-cuda-backend/src/bindings.rs    |  48 +++
 tfhe-benchmark/benches/integer/aes.rs         | 183 +++++----
 tfhe-benchmark/benches/integer/aes256.rs      | 143 +++++++
 tfhe-benchmark/benches/integer/bench.rs       |   3 +
 tfhe/src/integer/gpu/mod.rs                   | 226 +++++++++++
 tfhe/src/integer/gpu/server_key/radix/aes.rs  |  53 +++
 .../integer/gpu/server_key/radix/aes256.rs    | 364 ++++++++++++++++++
 tfhe/src/integer/gpu/server_key/radix/mod.rs  |   1 +
 .../server_key/radix/tests_unsigned/mod.rs    |   1 +
 .../radix/tests_unsigned/test_aes.rs          |  16 +-
 .../radix/tests_unsigned/test_aes256.rs       |  59 +++
 .../radix_parallel/tests_cases_unsigned.rs    |   5 +
 .../radix_parallel/tests_unsigned/mod.rs      |   1 +
 .../radix_parallel/tests_unsigned/test_aes.rs |   2 +-
 .../tests_unsigned/test_aes256.rs             | 242 ++++++++++++
 19 files changed, 1740 insertions(+), 103 deletions(-)
 create mode 100644 backends/tfhe-cuda-backend/cuda/src/aes/aes256.cu
 create mode 100644 backends/tfhe-cuda-backend/cuda/src/aes/aes256.cuh
 create mode 100644 tfhe-benchmark/benches/integer/aes256.rs
 create mode 100644 tfhe/src/integer/gpu/server_key/radix/aes256.rs
 create mode 100644 tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_aes256.rs
 create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_aes256.rs

diff --git a/backends/tfhe-cuda-backend/cuda/include/aes/aes.h b/backends/tfhe-cuda-backend/cuda/include/aes/aes.h
index a491e73fc..9415ce4c4 100644
--- a/backends/tfhe-cuda-backend/cuda/include/aes/aes.h
+++ b/backends/tfhe-cuda-backend/cuda/include/aes/aes.h
@@ -39,6 +39,29 @@ void cuda_integer_key_expansion_64(CudaStreamsFFI streams,
 void cleanup_cuda_integer_key_expansion_64(CudaStreamsFFI streams,
                                            int8_t **mem_ptr_void);
+
+void cuda_integer_aes_ctr_256_encrypt_64(
+    CudaStreamsFFI streams, CudaRadixCiphertextFFI *output,
+    CudaRadixCiphertextFFI const *iv, CudaRadixCiphertextFFI const *round_keys,
+    const uint64_t *counter_bits_le_all_blocks, uint32_t num_aes_inputs,
+    int8_t *mem_ptr, void *const *bsks, void *const *ksks);
+
+uint64_t scratch_cuda_integer_key_expansion_256_64(
+    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
+    uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
+    uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
+    uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
+    PBS_TYPE pbs_type, bool allocate_gpu_memory,
+    PBS_MS_REDUCTION_T noise_reduction_type);
+
+void cuda_integer_key_expansion_256_64(CudaStreamsFFI streams,
+                                       CudaRadixCiphertextFFI *expanded_keys,
+                                       CudaRadixCiphertextFFI const *key,
+                                       int8_t *mem_ptr, void *const *bsks,
+                                       void *const *ksks);
+
+void cleanup_cuda_integer_key_expansion_256_64(CudaStreamsFFI streams,
+                                               int8_t **mem_ptr_void);
 }
 #endif

diff --git a/backends/tfhe-cuda-backend/cuda/include/aes/aes_utilities.h b/backends/tfhe-cuda-backend/cuda/include/aes/aes_utilities.h
index 61f1d0afc..85cae5081 100644
--- a/backends/tfhe-cuda-backend/cuda/include/aes/aes_utilities.h
+++ b/backends/tfhe-cuda-backend/cuda/include/aes/aes_utilities.h
@@ -442,4 +442,67 @@ template <typename Torus> struct int_key_expansion_buffer {
   }
 };
+
+template <typename Torus> struct int_key_expansion_256_buffer {
+  int_radix_params params;
+  bool allocate_gpu_memory;
+
+  CudaRadixCiphertextFFI *words_buffer;
+
+  CudaRadixCiphertextFFI *tmp_word_buffer;
+  CudaRadixCiphertextFFI *tmp_rotated_word_buffer;
+
+  int_aes_encrypt_buffer<Torus> *aes_encrypt_buffer;
+
+  int_key_expansion_256_buffer(CudaStreams streams,
+                               const int_radix_params &params,
+                               bool allocate_gpu_memory,
+                               uint64_t &size_tracker) {
+    this->params = params;
+    this->allocate_gpu_memory = allocate_gpu_memory;
+
+    constexpr uint32_t TOTAL_WORDS = 60;
+    constexpr uint32_t BITS_PER_WORD = 32;
+    constexpr uint32_t TOTAL_BITS = TOTAL_WORDS * BITS_PER_WORD;
+
+    this->words_buffer = new CudaRadixCiphertextFFI;
+    create_zero_radix_ciphertext_async<Torus>(
+        streams.stream(0), streams.gpu_index(0), this->words_buffer,
+        TOTAL_BITS, params.big_lwe_dimension, size_tracker,
+        allocate_gpu_memory);
+
+    this->tmp_word_buffer = new CudaRadixCiphertextFFI;
+    create_zero_radix_ciphertext_async<Torus>(
+        streams.stream(0), streams.gpu_index(0), this->tmp_word_buffer,
+        BITS_PER_WORD, params.big_lwe_dimension, size_tracker,
+        allocate_gpu_memory);
+
+    this->tmp_rotated_word_buffer = new CudaRadixCiphertextFFI;
+    create_zero_radix_ciphertext_async<Torus>(
+        streams.stream(0), streams.gpu_index(0), this->tmp_rotated_word_buffer,
+        BITS_PER_WORD, params.big_lwe_dimension, size_tracker,
+        allocate_gpu_memory);
+
+    this->aes_encrypt_buffer = new int_aes_encrypt_buffer<Torus>(
+        streams, params, allocate_gpu_memory, 1, 4, size_tracker);
+  }
+
+  void release(CudaStreams streams) {
+    release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
+                                   this->words_buffer, allocate_gpu_memory);
+    delete this->words_buffer;
+
+    release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
+                                   this->tmp_word_buffer, allocate_gpu_memory);
+    delete this->tmp_word_buffer;
+
+    release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
+                                   this->tmp_rotated_word_buffer,
+                                   allocate_gpu_memory);
+    delete this->tmp_rotated_word_buffer;
+
+    this->aes_encrypt_buffer->release(streams);
+    delete this->aes_encrypt_buffer;
+    cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
+  }
+};
+
 #endif

diff --git a/backends/tfhe-cuda-backend/cuda/src/aes/aes256.cu b/backends/tfhe-cuda-backend/cuda/src/aes/aes256.cu
new file mode 100644
index 000000000..e79c2a9e5
--- /dev/null
+++ b/backends/tfhe-cuda-backend/cuda/src/aes/aes256.cu
@@ -0,0 +1,55 @@
+#include "../../include/aes/aes.h"
+#include "aes256.cuh"
+
+void cuda_integer_aes_ctr_256_encrypt_64(
+    CudaStreamsFFI streams, CudaRadixCiphertextFFI *output,
+    CudaRadixCiphertextFFI const *iv, CudaRadixCiphertextFFI const *round_keys,
+    const uint64_t *counter_bits_le_all_blocks, uint32_t num_aes_inputs,
+    int8_t *mem_ptr, void *const *bsks, void *const *ksks) {
+
+  host_integer_aes_ctr_256_encrypt<uint64_t>(
+      CudaStreams(streams), output, iv, round_keys, counter_bits_le_all_blocks,
+      num_aes_inputs, (int_aes_encrypt_buffer<uint64_t> *)mem_ptr, bsks,
+      (uint64_t **)ksks);
+}
+
+uint64_t scratch_cuda_integer_key_expansion_256_64(
+    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
+    uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
+    uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
+    uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
+    PBS_TYPE pbs_type, bool allocate_gpu_memory,
+    PBS_MS_REDUCTION_T noise_reduction_type) {
+
+  int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
+                          glwe_dimension * polynomial_size, lwe_dimension,
+                          ks_level, ks_base_log, pbs_level, pbs_base_log,
+                          grouping_factor, message_modulus, carry_modulus,
+                          noise_reduction_type);
+
+  return scratch_cuda_integer_key_expansion_256<uint64_t>(
+      CudaStreams(streams), (int_key_expansion_256_buffer<uint64_t> **)mem_ptr,
+      params, allocate_gpu_memory);
+}
+
+void cuda_integer_key_expansion_256_64(CudaStreamsFFI streams,
+                                       CudaRadixCiphertextFFI *expanded_keys,
+                                       CudaRadixCiphertextFFI const *key,
+                                       int8_t *mem_ptr, void *const *bsks,
+                                       void *const *ksks) {
+
+  host_integer_key_expansion_256<uint64_t>(
+      CudaStreams(streams), expanded_keys, key,
+      (int_key_expansion_256_buffer<uint64_t> *)mem_ptr, bsks,
+      (uint64_t **)ksks);
+}
+
+void cleanup_cuda_integer_key_expansion_256_64(CudaStreamsFFI streams,
+                                               int8_t **mem_ptr_void) {
+  int_key_expansion_256_buffer<uint64_t> *mem_ptr =
+      (int_key_expansion_256_buffer<uint64_t> *)(*mem_ptr_void);
+
+  mem_ptr->release(CudaStreams(streams));
+  delete mem_ptr;
+  *mem_ptr_void = nullptr;
+}

diff --git a/backends/tfhe-cuda-backend/cuda/src/aes/aes256.cuh b/backends/tfhe-cuda-backend/cuda/src/aes/aes256.cuh
new file mode 100644
index 000000000..acb2770da
--- /dev/null
+++ b/backends/tfhe-cuda-backend/cuda/src/aes/aes256.cuh
@@ -0,0 +1,355 @@
+#pragma once
+
+#include "../../include/aes/aes_utilities.h"
+#include "../integer/integer.cuh"
+#include "../integer/radix_ciphertext.cuh"
+#include "../integer/scalar_addition.cuh"
+#include "../linearalgebra/addition.cuh"
+#include "aes.cuh"
+
+/**
+ * The main AES encryption function. It orchestrates the full 14-round AES-256
+ * encryption process on the bitsliced state.
+ *
+ * The process is broken down into three phases:
+ *
+ * 1. Initial Round (Round 0):
+ *    - AddRoundKey, which is an XOR
+ *
+ * 2. Main Rounds (Rounds 1-13):
+ *    This sequence is repeated 13 times.
+ *    - SubBytes
+ *    - ShiftRows
+ *    - MixColumns
+ *    - AddRoundKey
+ *
+ * 3.
Final Round (Round 14): + * - SubBytes + * - ShiftRows + * - AddRoundKey + * + */ +template +__host__ void vectorized_aes_256_encrypt_inplace( + CudaStreams streams, CudaRadixCiphertextFFI *all_states_bitsliced, + CudaRadixCiphertextFFI const *round_keys, uint32_t num_aes_inputs, + int_aes_encrypt_buffer *mem, void *const *bsks, Torus *const *ksks) { + + constexpr uint32_t BITS_PER_BYTE = 8; + constexpr uint32_t STATE_BYTES = 16; + constexpr uint32_t STATE_BITS = STATE_BYTES * BITS_PER_BYTE; + constexpr uint32_t ROUNDS = 14; + + CudaRadixCiphertextFFI *jit_transposed_key = + mem->main_workspaces->initial_states_and_jit_key_workspace; + + CudaRadixCiphertextFFI round_0_key_slice; + as_radix_ciphertext_slice( + &round_0_key_slice, (CudaRadixCiphertextFFI *)round_keys, 0, STATE_BITS); + for (uint32_t block = 0; block < num_aes_inputs; ++block) { + CudaRadixCiphertextFFI tile_slice; + as_radix_ciphertext_slice( + &tile_slice, mem->main_workspaces->tmp_tiled_key_buffer, + block * STATE_BITS, (block + 1) * STATE_BITS); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &tile_slice, &round_0_key_slice); + } + transpose_blocks_to_bitsliced( + streams.stream(0), streams.gpu_index(0), jit_transposed_key, + mem->main_workspaces->tmp_tiled_key_buffer, num_aes_inputs, STATE_BITS); + + aes_xor(streams, mem, all_states_bitsliced, all_states_bitsliced, + jit_transposed_key); + + aes_flush_inplace(streams, all_states_bitsliced, mem, bsks, ksks); + + for (uint32_t round = 1; round <= ROUNDS; ++round) { + CudaRadixCiphertextFFI s_bits[STATE_BITS]; + for (uint32_t i = 0; i < STATE_BITS; i++) { + as_radix_ciphertext_slice(&s_bits[i], all_states_bitsliced, + i * num_aes_inputs, + (i + 1) * num_aes_inputs); + } + + uint32_t sbox_parallelism = mem->sbox_parallel_instances; + switch (sbox_parallelism) { + case 1: + for (uint32_t i = 0; i < STATE_BYTES; ++i) { + CudaRadixCiphertextFFI *sbox_inputs[] = {&s_bits[i * BITS_PER_BYTE]}; + vectorized_sbox_n_bytes(streams, sbox_inputs, 1, num_aes_inputs, + mem, bsks, ksks); + } + break; + case 2: + for (uint32_t i = 0; i < STATE_BYTES; i += 2) { + CudaRadixCiphertextFFI *sbox_inputs[] = { + &s_bits[i * BITS_PER_BYTE], &s_bits[(i + 1) * BITS_PER_BYTE]}; + vectorized_sbox_n_bytes(streams, sbox_inputs, 2, num_aes_inputs, + mem, bsks, ksks); + } + break; + case 4: + for (uint32_t i = 0; i < STATE_BYTES; i += 4) { + CudaRadixCiphertextFFI *sbox_inputs[] = { + &s_bits[i * BITS_PER_BYTE], &s_bits[(i + 1) * BITS_PER_BYTE], + &s_bits[(i + 2) * BITS_PER_BYTE], &s_bits[(i + 3) * BITS_PER_BYTE]}; + vectorized_sbox_n_bytes(streams, sbox_inputs, 4, num_aes_inputs, + mem, bsks, ksks); + } + break; + case 8: + for (uint32_t i = 0; i < STATE_BYTES; i += 8) { + CudaRadixCiphertextFFI *sbox_inputs[] = { + &s_bits[i * BITS_PER_BYTE], &s_bits[(i + 1) * BITS_PER_BYTE], + &s_bits[(i + 2) * BITS_PER_BYTE], &s_bits[(i + 3) * BITS_PER_BYTE], + &s_bits[(i + 4) * BITS_PER_BYTE], &s_bits[(i + 5) * BITS_PER_BYTE], + &s_bits[(i + 6) * BITS_PER_BYTE], &s_bits[(i + 7) * BITS_PER_BYTE]}; + vectorized_sbox_n_bytes(streams, sbox_inputs, 8, num_aes_inputs, + mem, bsks, ksks); + } + break; + case 16: { + CudaRadixCiphertextFFI *sbox_inputs[] = { + &s_bits[0 * BITS_PER_BYTE], &s_bits[1 * BITS_PER_BYTE], + &s_bits[2 * BITS_PER_BYTE], &s_bits[3 * BITS_PER_BYTE], + &s_bits[4 * BITS_PER_BYTE], &s_bits[5 * BITS_PER_BYTE], + &s_bits[6 * BITS_PER_BYTE], &s_bits[7 * BITS_PER_BYTE], + &s_bits[8 * BITS_PER_BYTE], &s_bits[9 * BITS_PER_BYTE], + &s_bits[10 * BITS_PER_BYTE], &s_bits[11 * BITS_PER_BYTE], 
+ &s_bits[12 * BITS_PER_BYTE], &s_bits[13 * BITS_PER_BYTE], + &s_bits[14 * BITS_PER_BYTE], &s_bits[15 * BITS_PER_BYTE]}; + vectorized_sbox_n_bytes(streams, sbox_inputs, 16, num_aes_inputs, + mem, bsks, ksks); + } break; + default: + PANIC("Unsupported S-Box parallelism level selected: %u", + sbox_parallelism); + } + + vectorized_shift_rows(streams, all_states_bitsliced, num_aes_inputs, + mem); + + if (round != ROUNDS) { + vectorized_mix_columns(streams, s_bits, num_aes_inputs, mem, bsks, + ksks); + aes_flush_inplace(streams, all_states_bitsliced, mem, bsks, ksks); + } + + CudaRadixCiphertextFFI round_key_slice; + as_radix_ciphertext_slice( + &round_key_slice, (CudaRadixCiphertextFFI *)round_keys, + round * STATE_BITS, (round + 1) * STATE_BITS); + for (uint32_t block = 0; block < num_aes_inputs; ++block) { + CudaRadixCiphertextFFI tile_slice; + as_radix_ciphertext_slice( + &tile_slice, mem->main_workspaces->tmp_tiled_key_buffer, + block * STATE_BITS, (block + 1) * STATE_BITS); + copy_radix_ciphertext_async(streams.stream(0), + streams.gpu_index(0), &tile_slice, + &round_key_slice); + } + transpose_blocks_to_bitsliced( + streams.stream(0), streams.gpu_index(0), jit_transposed_key, + mem->main_workspaces->tmp_tiled_key_buffer, num_aes_inputs, STATE_BITS); + + aes_xor(streams, mem, all_states_bitsliced, all_states_bitsliced, + jit_transposed_key); + + aes_flush_inplace(streams, all_states_bitsliced, mem, bsks, ksks); + } +} + +/** + * Top-level function to perform a full AES-256-CTR encryption homomorphically. + * + * +----------+ +-------------------+ + * | IV_CT | | Plaintext Counter | + * +----------+ +-------------------+ + * | | + * V V + * +---------------------------------+ + * | Homomorphic Full Adder | + * | (IV_CT + Counter) | + * +---------------------------------+ + * | + * V + * +---------------------------------+ + * | Homomorphic AES Encryption | -> Final Output Ciphertext + * | (14 Rounds) | + * +---------------------------------+ + * + */ +template +__host__ void host_integer_aes_ctr_256_encrypt( + CudaStreams streams, CudaRadixCiphertextFFI *output, + CudaRadixCiphertextFFI const *iv, CudaRadixCiphertextFFI const *round_keys, + const Torus *counter_bits_le_all_blocks, uint32_t num_aes_inputs, + int_aes_encrypt_buffer *mem, void *const *bsks, Torus *const *ksks) { + + constexpr uint32_t NUM_BITS = 128; + + CudaRadixCiphertextFFI *initial_states = + mem->main_workspaces->initial_states_and_jit_key_workspace; + + for (uint32_t block = 0; block < num_aes_inputs; ++block) { + CudaRadixCiphertextFFI output_slice; + as_radix_ciphertext_slice(&output_slice, initial_states, + block * NUM_BITS, (block + 1) * NUM_BITS); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &output_slice, iv); + } + + CudaRadixCiphertextFFI *transposed_states = + mem->main_workspaces->main_bitsliced_states_buffer; + transpose_blocks_to_bitsliced(streams.stream(0), streams.gpu_index(0), + transposed_states, initial_states, + num_aes_inputs, NUM_BITS); + + vectorized_aes_full_adder_inplace(streams, transposed_states, + counter_bits_le_all_blocks, + num_aes_inputs, mem, bsks, ksks); + + vectorized_aes_256_encrypt_inplace( + streams, transposed_states, round_keys, num_aes_inputs, mem, bsks, ksks); + + transpose_bitsliced_to_blocks(streams.stream(0), streams.gpu_index(0), + output, transposed_states, + num_aes_inputs, NUM_BITS); +} + +template +uint64_t scratch_cuda_integer_key_expansion_256( + CudaStreams streams, int_key_expansion_256_buffer **mem_ptr, + int_radix_params params, bool 
allocate_gpu_memory) { + + uint64_t size_tracker = 0; + *mem_ptr = new int_key_expansion_256_buffer( + streams, params, allocate_gpu_memory, size_tracker); + return size_tracker; +} + +/** + * Homomorphically performs the AES-256 key expansion schedule on the GPU. + * + * This function expands an encrypted 256-bit key into 60 words (15 round keys). + * The generation logic for a new word `w_i` depends on its position (with + * KEY_WORDS = 8): + * - If (i % 8 == 0): w_i = w_{i-8} + SubWord(RotWord(w_{i-1})) + Rcon[i/8] + * - If (i % 8 == 4): w_i = w_{i-8} + SubWord(w_{i-1}) + * - Otherwise: w_i = w_{i-8} + w_{i-1} + */ +template +__host__ void host_integer_key_expansion_256( + CudaStreams streams, CudaRadixCiphertextFFI *expanded_keys, + CudaRadixCiphertextFFI const *key, int_key_expansion_256_buffer *mem, + void *const *bsks, Torus *const *ksks) { + + constexpr uint32_t BITS_PER_WORD = 32; + constexpr uint32_t BITS_PER_BYTE = 8; + constexpr uint32_t BYTES_PER_WORD = 4; + constexpr uint32_t TOTAL_WORDS = 60; + constexpr uint32_t KEY_WORDS = 8; + + const Torus rcon[] = {0x01, 0x02, 0x04, 0x08, 0x10, + 0x20, 0x40, 0x80, 0x1b, 0x36}; + + CudaRadixCiphertextFFI *words = mem->words_buffer; + + CudaRadixCiphertextFFI initial_key_dest_slice; + as_radix_ciphertext_slice(&initial_key_dest_slice, words, 0, + KEY_WORDS * BITS_PER_WORD); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &initial_key_dest_slice, key); + + for (uint32_t w = KEY_WORDS; w < TOTAL_WORDS; ++w) { + CudaRadixCiphertextFFI tmp_word_buffer, tmp_far, tmp_near; + + as_radix_ciphertext_slice(&tmp_word_buffer, mem->tmp_word_buffer, 0, + BITS_PER_WORD); + as_radix_ciphertext_slice(&tmp_far, words, (w - 8) * BITS_PER_WORD, + (w - 7) * BITS_PER_WORD); + as_radix_ciphertext_slice(&tmp_near, words, (w - 1) * BITS_PER_WORD, + w * BITS_PER_WORD); + + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &tmp_word_buffer, &tmp_near); + + if (w % KEY_WORDS == 0) { + CudaRadixCiphertextFFI rotated_word_buffer; + as_radix_ciphertext_slice( + &rotated_word_buffer, mem->tmp_rotated_word_buffer, 0, BITS_PER_WORD); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), &rotated_word_buffer, 0, + BITS_PER_WORD - BITS_PER_BYTE, &tmp_word_buffer, BITS_PER_BYTE, + BITS_PER_WORD); + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), &rotated_word_buffer, + BITS_PER_WORD - BITS_PER_BYTE, BITS_PER_WORD, &tmp_word_buffer, 0, + BITS_PER_BYTE); + + CudaRadixCiphertextFFI bit_slices[BITS_PER_WORD]; + for (uint32_t i = 0; i < BITS_PER_WORD; ++i) { + as_radix_ciphertext_slice(&bit_slices[i], &rotated_word_buffer, + i, i + 1); + } + + CudaRadixCiphertextFFI *sbox_byte_pointers[BYTES_PER_WORD]; + for (uint32_t i = 0; i < BYTES_PER_WORD; ++i) { + sbox_byte_pointers[i] = &bit_slices[i * BITS_PER_BYTE]; + } + + vectorized_sbox_n_bytes(streams, sbox_byte_pointers, + BYTES_PER_WORD, 1, mem->aes_encrypt_buffer, + bsks, ksks); + + Torus rcon_val = rcon[w / KEY_WORDS - 1]; + for (uint32_t bit = 0; bit < BITS_PER_BYTE; ++bit) { + if ((rcon_val >> (7 - bit)) & 1) { + CudaRadixCiphertextFFI first_byte_bit_slice; + as_radix_ciphertext_slice(&first_byte_bit_slice, + &rotated_word_buffer, bit, bit + 1); + host_add_scalar_one_inplace(streams, &first_byte_bit_slice, + mem->params.message_modulus, + mem->params.carry_modulus); + } + } + + aes_flush_inplace(streams, &rotated_word_buffer, mem->aes_encrypt_buffer, + bsks, ksks); + + copy_radix_ciphertext_async(streams.stream(0), + 
streams.gpu_index(0), &tmp_word_buffer, + &rotated_word_buffer); + } else if (w % KEY_WORDS == 4) { + CudaRadixCiphertextFFI bit_slices[BITS_PER_WORD]; + for (uint32_t i = 0; i < BITS_PER_WORD; ++i) { + as_radix_ciphertext_slice(&bit_slices[i], &tmp_word_buffer, i, + i + 1); + } + + CudaRadixCiphertextFFI *sbox_byte_pointers[BYTES_PER_WORD]; + for (uint32_t i = 0; i < BYTES_PER_WORD; ++i) { + sbox_byte_pointers[i] = &bit_slices[i * BITS_PER_BYTE]; + } + + vectorized_sbox_n_bytes(streams, sbox_byte_pointers, + BYTES_PER_WORD, 1, mem->aes_encrypt_buffer, + bsks, ksks); + } + + aes_xor(streams, mem->aes_encrypt_buffer, &tmp_word_buffer, &tmp_far, + &tmp_word_buffer); + aes_flush_inplace(streams, &tmp_word_buffer, mem->aes_encrypt_buffer, bsks, + ksks); + + CudaRadixCiphertextFFI dest_word; + as_radix_ciphertext_slice(&dest_word, words, w * BITS_PER_WORD, + (w + 1) * BITS_PER_WORD); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &dest_word, &tmp_word_buffer); + } + + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + expanded_keys, words); +} diff --git a/backends/tfhe-cuda-backend/src/bindings.rs b/backends/tfhe-cuda-backend/src/bindings.rs index 87e14ffe4..448bfb346 100644 --- a/backends/tfhe-cuda-backend/src/bindings.rs +++ b/backends/tfhe-cuda-backend/src/bindings.rs @@ -1854,6 +1854,54 @@ unsafe extern "C" { mem_ptr_void: *mut *mut i8, ); } +unsafe extern "C" { + pub fn cuda_integer_aes_ctr_256_encrypt_64( + streams: CudaStreamsFFI, + output: *mut CudaRadixCiphertextFFI, + iv: *const CudaRadixCiphertextFFI, + round_keys: *const CudaRadixCiphertextFFI, + counter_bits_le_all_blocks: *const u64, + num_aes_inputs: u32, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} +unsafe extern "C" { + pub fn scratch_cuda_integer_key_expansion_256_64( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + ) -> u64; +} +unsafe extern "C" { + pub fn cuda_integer_key_expansion_256_64( + streams: CudaStreamsFFI, + expanded_keys: *mut CudaRadixCiphertextFFI, + key: *const CudaRadixCiphertextFFI, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} +unsafe extern "C" { + pub fn cleanup_cuda_integer_key_expansion_256_64( + streams: CudaStreamsFFI, + mem_ptr_void: *mut *mut i8, + ); +} pub const KS_TYPE_BIG_TO_SMALL: KS_TYPE = 0; pub const KS_TYPE_SMALL_TO_BIG: KS_TYPE = 1; pub type KS_TYPE = ffi::c_uint; diff --git a/tfhe-benchmark/benches/integer/aes.rs b/tfhe-benchmark/benches/integer/aes.rs index 4fa475e2c..a842f1a12 100644 --- a/tfhe-benchmark/benches/integer/aes.rs +++ b/tfhe-benchmark/benches/integer/aes.rs @@ -1,8 +1,8 @@ #[cfg(feature = "gpu")] pub mod cuda { use benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; - use benchmark::utilities::{get_bench_type, write_to_json, BenchmarkType, OperatorType}; - use criterion::{black_box, Criterion, Throughput}; + use benchmark::utilities::{write_to_json, OperatorType}; + use criterion::{black_box, Criterion}; use tfhe::core_crypto::gpu::CudaStreams; use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; use tfhe::integer::gpu::CudaServerKey; @@ -29,114 +29,109 @@ pub 
mod cuda { let param_name = param.name(); - match get_bench_type() { - BenchmarkType::Latency => { - let streams = CudaStreams::new_multi_gpu(); - let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix); - let sks = CudaServerKey::new(&cpu_cks, &streams); - let cks = RadixClientKey::from((cpu_cks, 1)); + let streams = CudaStreams::new_multi_gpu(); + let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix); + let sks = CudaServerKey::new(&cpu_cks, &streams); + let cks = RadixClientKey::from((cpu_cks, 1)); - let ct_key = cks.encrypt_u128_for_aes_ctr(key); - let ct_iv = cks.encrypt_u128_for_aes_ctr(iv); + let ct_key = cks.encrypt_u128_for_aes_ctr(key); + let ct_iv = cks.encrypt_u128_for_aes_ctr(iv); - let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams); - let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams); + let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams); + let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams); - { - const NUM_AES_INPUTS: usize = 1; - const SBOX_PARALLELISM: usize = 16; - let bench_id = format!("{param_name}::{NUM_AES_INPUTS}_input_encryption"); + { + const NUM_AES_INPUTS: usize = 1; + const SBOX_PARALLELISM: usize = 16; + let bench_id = format!("{param_name}::{NUM_AES_INPUTS}_input_encryption"); - let round_keys = sks.key_expansion(&d_key, &streams); + let round_keys = sks.key_expansion(&d_key, &streams); - bench_group.bench_function(&bench_id, |b| { - b.iter(|| { - black_box(sks.aes_encrypt( - &d_iv, - &round_keys, - 0, - NUM_AES_INPUTS, - SBOX_PARALLELISM, - &streams, - )); - }) - }); + bench_group.bench_function(&bench_id, |b| { + b.iter(|| { + black_box(sks.aes_encrypt( + &d_iv, + &round_keys, + 0, + NUM_AES_INPUTS, + SBOX_PARALLELISM, + &streams, + )); + }) + }); - write_to_json::( - &bench_id, - atomic_param, - param.name(), - "aes_encryption", - &OperatorType::Atomic, - aes_op_bit_size, - vec![atomic_param.message_modulus().0.ilog2(); aes_op_bit_size as usize], - ); - } + write_to_json::( + &bench_id, + atomic_param, + param.name(), + "aes_encryption", + &OperatorType::Atomic, + aes_op_bit_size, + vec![atomic_param.message_modulus().0.ilog2(); aes_op_bit_size as usize], + ); + } - { - let bench_id = format!("{param_name}::key_expansion"); + { + let bench_id = format!("{param_name}::key_expansion"); - bench_group.bench_function(&bench_id, |b| { - b.iter(|| { - black_box(sks.key_expansion(&d_key, &streams)); - }) - }); + bench_group.bench_function(&bench_id, |b| { + b.iter(|| { + black_box(sks.key_expansion(&d_key, &streams)); + }) + }); - write_to_json::( - &bench_id, - atomic_param, - param.name(), - "aes_key_expansion", - &OperatorType::Atomic, - aes_op_bit_size, - vec![atomic_param.message_modulus().0.ilog2(); aes_op_bit_size as usize], - ); - } - } - BenchmarkType::Throughput => { - const NUM_AES_INPUTS: usize = 192; - const SBOX_PARALLELISM: usize = 16; - let bench_id = format!("throughput::{param_name}::{NUM_AES_INPUTS}_inputs"); + write_to_json::( + &bench_id, + atomic_param, + param.name(), + "aes_key_expansion", + &OperatorType::Atomic, + aes_op_bit_size, + vec![atomic_param.message_modulus().0.ilog2(); aes_op_bit_size as usize], + ); + } - let streams = CudaStreams::new_multi_gpu(); - let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix); - let sks = CudaServerKey::new(&cpu_cks, &streams); - let cks = RadixClientKey::from((cpu_cks, 1)); + { + const NUM_AES_INPUTS: 
usize = 192; + const SBOX_PARALLELISM: usize = 16; + let bench_id = format!("{param_name}::{NUM_AES_INPUTS}_inputs_encryption"); - bench_group.throughput(Throughput::Elements(NUM_AES_INPUTS as u64)); + let streams = CudaStreams::new_multi_gpu(); + let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix); + let sks = CudaServerKey::new(&cpu_cks, &streams); + let cks = RadixClientKey::from((cpu_cks, 1)); - let ct_key = cks.encrypt_u128_for_aes_ctr(key); - let ct_iv = cks.encrypt_u128_for_aes_ctr(iv); + let ct_key = cks.encrypt_u128_for_aes_ctr(key); + let ct_iv = cks.encrypt_u128_for_aes_ctr(iv); - let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams); - let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams); + let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams); + let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams); - let round_keys = sks.key_expansion(&d_key, &streams); + let round_keys = sks.key_expansion(&d_key, &streams); - bench_group.bench_function(&bench_id, |b| { - b.iter(|| { - black_box(sks.aes_encrypt( - &d_iv, - &round_keys, - 0, - NUM_AES_INPUTS, - SBOX_PARALLELISM, - &streams, - )); - }) - }); + bench_group.bench_function(&bench_id, |b| { + b.iter(|| { + black_box(sks.aes_encrypt( + &d_iv, + &round_keys, + 0, + NUM_AES_INPUTS, + SBOX_PARALLELISM, + &streams, + )); + }) + }); - write_to_json::( - &bench_id, - atomic_param, - param.name(), - "aes_encryption", - &OperatorType::Atomic, - aes_op_bit_size, - vec![atomic_param.message_modulus().0.ilog2(); aes_op_bit_size as usize], - ); - } - }; + write_to_json::( + &bench_id, + atomic_param, + param.name(), + "aes_encryption", + &OperatorType::Atomic, + aes_op_bit_size, + vec![atomic_param.message_modulus().0.ilog2(); aes_op_bit_size as usize], + ); + } bench_group.finish(); } diff --git a/tfhe-benchmark/benches/integer/aes256.rs b/tfhe-benchmark/benches/integer/aes256.rs new file mode 100644 index 000000000..60f88f0b6 --- /dev/null +++ b/tfhe-benchmark/benches/integer/aes256.rs @@ -0,0 +1,143 @@ +#[cfg(feature = "gpu")] +pub mod cuda { + use benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + use benchmark::utilities::{write_to_json, OperatorType}; + use criterion::{black_box, Criterion}; + use tfhe::core_crypto::gpu::CudaStreams; + use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; + use tfhe::integer::gpu::CudaServerKey; + use tfhe::integer::keycache::KEY_CACHE; + use tfhe::integer::{IntegerKeyKind, RadixClientKey}; + use tfhe::keycache::NamedParam; + use tfhe::shortint::AtomicPatternParameters; + + pub fn cuda_aes_256(c: &mut Criterion) { + let bench_name = "integer::cuda::aes_256"; + + let mut bench_group = c.benchmark_group(bench_name); + bench_group + .sample_size(15) + .measurement_time(std::time::Duration::from_secs(60)) + .warm_up_time(std::time::Duration::from_secs(60)); + + let param = BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + let atomic_param: AtomicPatternParameters = param.into(); + + let key_hi: u128 = 0x603deb1015ca71be2b73aef0857d7781; + let key_lo: u128 = 0x1f352c073b6108d72d9810a30914dff4; + let iv: u128 = 0xf0f1f2f3f4f5f6f7f8f9fafbfcfdfeff; + + let aes_block_op_bit_size = 128; + let aes_key_op_bit_size = 256; + + let param_name = param.name(); + + let streams = CudaStreams::new_multi_gpu(); + let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix); + let 
sks = CudaServerKey::new(&cpu_cks, &streams);
+        let cks = RadixClientKey::from((cpu_cks, 1));
+
+        let ct_key = cks.encrypt_2u128_for_aes_ctr_256(key_hi, key_lo);
+
+        let ct_iv = cks.encrypt_u128_for_aes_ctr(iv);
+
+        let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
+        let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
+
+        {
+            const NUM_AES_INPUTS: usize = 1;
+            const SBOX_PARALLELISM: usize = 16;
+            let bench_id = format!("{param_name}::{NUM_AES_INPUTS}_input_encryption");
+
+            let round_keys = sks.key_expansion_256(&d_key, &streams);
+
+            bench_group.bench_function(&bench_id, |b| {
+                b.iter(|| {
+                    black_box(sks.aes_256_encrypt(
+                        &d_iv,
+                        &round_keys,
+                        0,
+                        NUM_AES_INPUTS,
+                        SBOX_PARALLELISM,
+                        &streams,
+                    ));
+                })
+            });
+
+            write_to_json::<u64, _>(
+                &bench_id,
+                atomic_param,
+                param.name(),
+                "aes_256_encryption",
+                &OperatorType::Atomic,
+                aes_block_op_bit_size,
+                vec![atomic_param.message_modulus().0.ilog2(); aes_block_op_bit_size as usize],
+            );
+        }
+
+        {
+            let bench_id = format!("{param_name}::key_expansion");
+
+            bench_group.bench_function(&bench_id, |b| {
+                b.iter(|| {
+                    black_box(sks.key_expansion_256(&d_key, &streams));
+                })
+            });
+
+            write_to_json::<u64, _>(
+                &bench_id,
+                atomic_param,
+                param.name(),
+                "aes_256_key_expansion",
+                &OperatorType::Atomic,
+                aes_key_op_bit_size,
+                vec![atomic_param.message_modulus().0.ilog2(); aes_key_op_bit_size as usize],
+            );
+        }
+
+        {
+            const NUM_AES_INPUTS: usize = 192;
+            const SBOX_PARALLELISM: usize = 16;
+            let bench_id = format!("{param_name}::{NUM_AES_INPUTS}_inputs_encryption");
+
+            let streams = CudaStreams::new_multi_gpu();
+            let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix);
+            let sks = CudaServerKey::new(&cpu_cks, &streams);
+            let cks = RadixClientKey::from((cpu_cks, 1));
+
+            let ct_key = cks.encrypt_2u128_for_aes_ctr_256(key_hi, key_lo);
+
+            let ct_iv = cks.encrypt_u128_for_aes_ctr(iv);
+
+            let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
+            let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
+
+            let round_keys = sks.key_expansion_256(&d_key, &streams);
+
+            bench_group.bench_function(&bench_id, |b| {
+                b.iter(|| {
+                    black_box(sks.aes_256_encrypt(
+                        &d_iv,
+                        &round_keys,
+                        0,
+                        NUM_AES_INPUTS,
+                        SBOX_PARALLELISM,
+                        &streams,
+                    ));
+                })
+            });
+
+            write_to_json::<u64, _>(
+                &bench_id,
+                atomic_param,
+                param.name(),
+                "aes_256_encryption",
+                &OperatorType::Atomic,
+                aes_block_op_bit_size,
+                vec![atomic_param.message_modulus().0.ilog2(); aes_block_op_bit_size as usize],
+            );
+        }
+
+        bench_group.finish();
+    }
+}

diff --git a/tfhe-benchmark/benches/integer/bench.rs b/tfhe-benchmark/benches/integer/bench.rs
index 9635df14e..ae565828c 100644
--- a/tfhe-benchmark/benches/integer/bench.rs
+++ b/tfhe-benchmark/benches/integer/bench.rs
@@ -1,6 +1,7 @@
 #![allow(dead_code)]
 
 mod aes;
+mod aes256;
 mod oprf;
 mod rerand;
 
@@ -2799,6 +2800,7 @@ mod cuda {
         cuda_ilog2,
         oprf::cuda::cuda_unsigned_oprf,
         aes::cuda::cuda_aes,
+        aes256::cuda::cuda_aes_256,
     );
 
     criterion_group!(
@@ -2828,6 +2830,7 @@ mod cuda {
         cuda_scalar_rem,
         oprf::cuda::cuda_unsigned_oprf,
         aes::cuda::cuda_aes,
+        aes256::cuda::cuda_aes_256,
     );
 
     criterion_group!(

diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs
index ad89958aa..412872c9e 100644
--- a/tfhe/src/integer/gpu/mod.rs
+++ b/tfhe/src/integer/gpu/mod.rs
@@ -7627,6 +7627,104 @@ pub(crate) unsafe fn cuda_backend_unchecked_aes_ctr_encrypt<B: Numeric>(
+
+#[allow(clippy::too_many_arguments)]
+/// # Safety
+///
+/// - The data must not be moved or dropped while being used by the CUDA kernel.
+/// - This function assumes exclusive access to the passed data; violating this may lead to
+///   undefined behavior.
+pub(crate) unsafe fn cuda_backend_unchecked_aes_ctr_256_encrypt<B: Numeric>(
+    streams: &CudaStreams,
+    output: &mut CudaRadixCiphertext,
+    iv: &CudaRadixCiphertext,
+    round_keys: &CudaRadixCiphertext,
+    start_counter: u128,
+    num_aes_inputs: u32,
+    sbox_parallelism: u32,
+    bootstrapping_key: &CudaVec<B>,
+    keyswitch_key: &CudaVec<u64>,
+    message_modulus: MessageModulus,
+    carry_modulus: CarryModulus,
+    glwe_dimension: GlweDimension,
+    polynomial_size: PolynomialSize,
+    lwe_dimension: LweDimension,
+    ks_level: DecompositionLevelCount,
+    ks_base_log: DecompositionBaseLog,
+    pbs_level: DecompositionLevelCount,
+    pbs_base_log: DecompositionBaseLog,
+    grouping_factor: LweBskGroupingFactor,
+    pbs_type: PBSType,
+    ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
+) {
+    let mut output_degrees = output.info.blocks.iter().map(|b| b.degree.0).collect();
+    let mut output_noise_levels = output.info.blocks.iter().map(|b| b.noise_level.0).collect();
+    let mut cuda_ffi_output =
+        prepare_cuda_radix_ffi(output, &mut output_degrees, &mut output_noise_levels);
+
+    let mut iv_degrees = iv.info.blocks.iter().map(|b| b.degree.0).collect();
+    let mut iv_noise_levels = iv.info.blocks.iter().map(|b| b.noise_level.0).collect();
+    let cuda_ffi_iv = prepare_cuda_radix_ffi(iv, &mut iv_degrees, &mut iv_noise_levels);
+
+    let mut round_keys_degrees = round_keys.info.blocks.iter().map(|b| b.degree.0).collect();
+    let mut round_keys_noise_levels = round_keys
+        .info
+        .blocks
+        .iter()
+        .map(|b| b.noise_level.0)
+        .collect();
+    let cuda_ffi_round_keys = prepare_cuda_radix_ffi(
+        round_keys,
+        &mut round_keys_degrees,
+        &mut round_keys_noise_levels,
+    );
+
+    let noise_reduction_type = resolve_noise_reduction_type(ms_noise_reduction_configuration);
+
+    let counter_bits_le: Vec<u64> = (0..num_aes_inputs)
+        .flat_map(|i| {
+            let current_counter = start_counter + i as u128;
+            (0..128).map(move |bit_index| ((current_counter >> bit_index) & 1) as u64)
+        })
+        .collect();
+
+    let mut mem_ptr: *mut i8 = std::ptr::null_mut();
+    scratch_cuda_integer_aes_encrypt_64(
+        streams.ffi(),
+        std::ptr::addr_of_mut!(mem_ptr),
+        glwe_dimension.0 as u32,
+        polynomial_size.0 as u32,
+        lwe_dimension.0 as u32,
+        ks_level.0 as u32,
+        ks_base_log.0 as u32,
+        pbs_level.0 as u32,
+        pbs_base_log.0 as u32,
+        grouping_factor.0 as u32,
+        message_modulus.0 as u32,
+        carry_modulus.0 as u32,
+        pbs_type as u32,
+        true,
+        noise_reduction_type as u32,
+        num_aes_inputs,
+        sbox_parallelism,
+    );
+
+    cuda_integer_aes_ctr_256_encrypt_64(
+        streams.ffi(),
+        &raw mut cuda_ffi_output,
+        &raw const cuda_ffi_iv,
+        &raw const cuda_ffi_round_keys,
+        counter_bits_le.as_ptr(),
+        num_aes_inputs,
+        mem_ptr,
+        bootstrapping_key.ptr.as_ptr(),
+        keyswitch_key.ptr.as_ptr(),
+    );
+
+    cleanup_cuda_integer_aes_encrypt_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
+
+    update_noise_degree(output, &cuda_ffi_output);
+}
+
 #[allow(clippy::too_many_arguments)]
 pub(crate) fn cuda_backend_get_aes_ctr_encrypt_size_on_gpu(
     streams: &CudaStreams,
@@ -7802,3 +7900,131 @@ pub(crate) fn cuda_backend_get_aes_key_expansion_size_on_gpu(
     size
 }
+
+#[allow(clippy::too_many_arguments)]
+/// # Safety
+///
+/// - The data must not be moved or dropped while being used by the CUDA kernel.
+/// - This function assumes exclusive access to the passed data; violating this may lead to
+///   undefined behavior.
+pub(crate) unsafe fn cuda_backend_aes_key_expansion_256<B: Numeric>(
+    streams: &CudaStreams,
+    expanded_keys: &mut CudaRadixCiphertext,
+    key: &CudaRadixCiphertext,
+    bootstrapping_key: &CudaVec<B>,
+    keyswitch_key: &CudaVec<u64>,
+    message_modulus: MessageModulus,
+    carry_modulus: CarryModulus,
+    glwe_dimension: GlweDimension,
+    polynomial_size: PolynomialSize,
+    lwe_dimension: LweDimension,
+    ks_level: DecompositionLevelCount,
+    ks_base_log: DecompositionBaseLog,
+    pbs_level: DecompositionLevelCount,
+    pbs_base_log: DecompositionBaseLog,
+    grouping_factor: LweBskGroupingFactor,
+    pbs_type: PBSType,
+    ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
+) {
+    let mut expanded_keys_degrees = expanded_keys
+        .info
+        .blocks
+        .iter()
+        .map(|b| b.degree.0)
+        .collect();
+    let mut expanded_keys_noise_levels = expanded_keys
+        .info
+        .blocks
+        .iter()
+        .map(|b| b.noise_level.0)
+        .collect();
+    let mut cuda_ffi_expanded_keys = prepare_cuda_radix_ffi(
+        expanded_keys,
+        &mut expanded_keys_degrees,
+        &mut expanded_keys_noise_levels,
+    );
+
+    let mut key_degrees = key.info.blocks.iter().map(|b| b.degree.0).collect();
+    let mut key_noise_levels = key.info.blocks.iter().map(|b| b.noise_level.0).collect();
+    let cuda_ffi_key = prepare_cuda_radix_ffi(key, &mut key_degrees, &mut key_noise_levels);
+
+    let noise_reduction_type = resolve_noise_reduction_type(ms_noise_reduction_configuration);
+
+    let mut mem_ptr: *mut i8 = std::ptr::null_mut();
+    scratch_cuda_integer_key_expansion_256_64(
+        streams.ffi(),
+        std::ptr::addr_of_mut!(mem_ptr),
+        glwe_dimension.0 as u32,
+        polynomial_size.0 as u32,
+        lwe_dimension.0 as u32,
+        ks_level.0 as u32,
+        ks_base_log.0 as u32,
+        pbs_level.0 as u32,
+        pbs_base_log.0 as u32,
+        grouping_factor.0 as u32,
+        message_modulus.0 as u32,
+        carry_modulus.0 as u32,
+        pbs_type as u32,
+        true,
+        noise_reduction_type as u32,
+    );
+
+    cuda_integer_key_expansion_256_64(
+        streams.ffi(),
+        &raw mut cuda_ffi_expanded_keys,
+        &raw const cuda_ffi_key,
+        mem_ptr,
+        bootstrapping_key.ptr.as_ptr(),
+        keyswitch_key.ptr.as_ptr(),
+    );
+
+    cleanup_cuda_integer_key_expansion_256_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
+
+    update_noise_degree(expanded_keys, &cuda_ffi_expanded_keys);
+}
+
+#[allow(clippy::too_many_arguments)]
+pub(crate) fn cuda_backend_get_aes_key_expansion_256_size_on_gpu(
+    streams: &CudaStreams,
+    message_modulus: MessageModulus,
+    carry_modulus: CarryModulus,
+    glwe_dimension: GlweDimension,
+    polynomial_size: PolynomialSize,
+    lwe_dimension: LweDimension,
+    ks_level: DecompositionLevelCount,
+    ks_base_log: DecompositionBaseLog,
+    pbs_level: DecompositionLevelCount,
+    pbs_base_log: DecompositionBaseLog,
+    grouping_factor: LweBskGroupingFactor,
+    pbs_type: PBSType,
+    ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
+) -> u64 {
+    let noise_reduction_type = resolve_noise_reduction_type(ms_noise_reduction_configuration);
+
+    let mut mem_ptr: *mut i8 = std::ptr::null_mut();
+    let size = unsafe {
+        scratch_cuda_integer_key_expansion_256_64(
+            streams.ffi(),
+            std::ptr::addr_of_mut!(mem_ptr),
+            glwe_dimension.0 as u32,
+            polynomial_size.0 as u32,
+            lwe_dimension.0 as u32,
+            ks_level.0 as u32,
+            ks_base_log.0 as u32,
+            pbs_level.0 as u32,
+            pbs_base_log.0 as u32,
+            grouping_factor.0 as u32,
+            message_modulus.0 as u32,
+            carry_modulus.0 as u32,
+            pbs_type as u32,
+            true,
+            noise_reduction_type as u32,
+        )
+    };
+
+    unsafe {
+        cleanup_cuda_integer_key_expansion_256_64(streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr)) + }; + + size +} diff --git a/tfhe/src/integer/gpu/server_key/radix/aes.rs b/tfhe/src/integer/gpu/server_key/radix/aes.rs index 4d3c89500..e7f1691bc 100644 --- a/tfhe/src/integer/gpu/server_key/radix/aes.rs +++ b/tfhe/src/integer/gpu/server_key/radix/aes.rs @@ -104,6 +104,59 @@ impl RadixClientKey { } impl CudaServerKey { + /// Computes homomorphically AES-128 encryption in CTR mode. + /// + /// This function performs AES-128 encryption on an encrypted 128-bit IV + /// using an encrypted 128-bit key. It operates in Counter (CTR) mode, generating + /// `num_aes_inputs` encrypted ciphertexts starting from the `start_counter` value + /// (which is typically added to the IV). + /// + /// The key and IV must be prepared using `encrypt_u128_for_aes_ctr`, which + /// encrypts each of the 128 bits individually. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::gpu::CudaStreams; + /// use tfhe::GpuIndex; + /// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; + /// use tfhe::integer::gpu::gen_keys_radix_gpu; + /// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + /// + /// let gpu_index = 0; + /// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index)); + /// + /// // Generate the client key and the server key: + /// // AES bit-wise operations require 1-block ciphertexts (for encrypting single bits). + /// let num_blocks = 1; + /// let (cks, sks) = gen_keys_radix_gpu( + /// PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + /// num_blocks, + /// &streams, + /// ); + /// + /// let key: u128 = 0x2b7e151628aed2a6abf7158809cf4f3c; + /// let iv: u128 = 0xf0f1f2f3f4f5f6f7f8f9fafbfcfdfeff; + /// let num_aes_inputs = 2; // Produce 2 128-bits ciphertexts + /// let start_counter = 0u128; + /// + /// // Encrypt the 128-bit key and IV bit by bit + /// let ct_key = cks.encrypt_u128_for_aes_ctr(key); + /// let ct_iv = cks.encrypt_u128_for_aes_ctr(iv); + /// + /// let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams); + /// let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams); + /// + /// let d_ct_res = sks.aes_ctr(&d_key, &d_iv, start_counter, num_aes_inputs, &streams); + /// + /// let ct_res = d_ct_res.to_radix_ciphertext(&streams); + /// + /// let fhe_results = cks.decrypt_u128_from_aes_ctr(&ct_res, num_aes_inputs); + /// + /// // Verify: + /// let expected_results = vec![0xec8cdf7398607cb0f2d21675ea9ea1e4, 0x362b7c3c6773516318a077d7fc5073ae]; + /// assert_eq!(fhe_results, expected_results); + /// ``` pub fn aes_ctr( &self, key: &CudaUnsignedRadixCiphertext, diff --git a/tfhe/src/integer/gpu/server_key/radix/aes256.rs b/tfhe/src/integer/gpu/server_key/radix/aes256.rs new file mode 100644 index 000000000..05970830e --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/aes256.rs @@ -0,0 +1,364 @@ +use crate::core_crypto::gpu::{ + check_valid_cuda_malloc, check_valid_cuda_malloc_assert_oom, CudaStreams, +}; +use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext}; +use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey}; + +use crate::core_crypto::prelude::LweBskGroupingFactor; +use crate::integer::gpu::{ + cuda_backend_aes_key_expansion_256, cuda_backend_get_aes_key_expansion_256_size_on_gpu, + cuda_backend_unchecked_aes_ctr_256_encrypt, PBSType, +}; +use crate::integer::{RadixCiphertext, RadixClientKey}; + +const NUM_BITS: usize = 128; + +impl 
RadixClientKey { + pub fn encrypt_2u128_for_aes_ctr_256(&self, key_hi: u128, key_lo: u128) -> RadixCiphertext { + let ctxt_hi = self.encrypt_u128_for_aes_ctr(key_hi); + let ctxt_lo = self.encrypt_u128_for_aes_ctr(key_lo); + + let mut combined_blocks = ctxt_hi.blocks; + combined_blocks.extend(ctxt_lo.blocks); + + RadixCiphertext::from(combined_blocks) + } +} + +impl CudaServerKey { + /// Computes homomorphically AES-256 encryption in CTR mode. + /// + /// This function performs AES-256 encryption on an encrypted 128-bit IV + /// using an encrypted 256-bit key. It operates in Counter (CTR) mode, generating + /// `num_aes_inputs` encrypted ciphertexts starting from the `start_counter` value + /// (which is typically added to the IV). + /// + /// The 256-bit key must be prepared using `encrypt_2u128_for_aes_ctr_256` and + /// the 128-bit IV using `encrypt_u128_for_aes_ctr`. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::gpu::CudaStreams; + /// use tfhe::GpuIndex; + /// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; + /// use tfhe::integer::gpu::gen_keys_radix_gpu; + /// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + /// + /// let gpu_index = 0; + /// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index)); + /// + /// // Generate the client key and the server key: + /// // AES bit-wise operations require 1-block ciphertexts (for encrypting single bits). + /// let num_blocks = 1; + /// let (cks, sks) = gen_keys_radix_gpu( + /// PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + /// num_blocks, + /// &streams, + /// ); + /// + /// let key_hi: u128 = 0x603deb1015ca71be2b73aef0857d7781; + /// let key_lo: u128 = 0x1f352c073b6108d72d9810a30914dff4; + /// let iv: u128 = 0xf0f1f2f3f4f5f6f7f8f9fafbfcfdfeff; + /// let num_aes_inputs = 2; // Produce 2 128-bits ciphertexts + /// let start_counter = 0u128; + /// + /// // Encrypt the 256-bit key and 128-bit IV bit by bit + /// let ct_key = cks.encrypt_2u128_for_aes_ctr_256(key_hi, key_lo); + /// let ct_iv = cks.encrypt_u128_for_aes_ctr(iv); + /// + /// let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams); + /// let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams); + /// + /// let d_ct_res = sks.aes_ctr_256(&d_key, &d_iv, start_counter, num_aes_inputs, &streams); + /// + /// let ct_res = d_ct_res.to_radix_ciphertext(&streams); + /// + /// let fhe_results = cks.decrypt_u128_from_aes_ctr(&ct_res, num_aes_inputs); + /// + /// // Verify: + /// let expected_results: Vec = vec![ + /// 0xbdf7df1591716335e9a8b15c860c502, + /// 0x5a6e699d536119065433863c8f657b94, + /// ]; + /// assert_eq!(fhe_results, expected_results); + /// ``` + pub fn aes_ctr_256( + &self, + key: &CudaUnsignedRadixCiphertext, + iv: &CudaUnsignedRadixCiphertext, + start_counter: u128, + num_aes_inputs: usize, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext { + let gpu_index = streams.gpu_indexes[0]; + + let key_expansion_size = self.get_key_expansion_256_size_on_gpu(streams); + check_valid_cuda_malloc_assert_oom(key_expansion_size, gpu_index); + + // `parallelism` refers to level of parallelization of the S-box. + // S-box should process 16 bytes of data: sequentially, or in groups of 2, + // or in groups of 4, or in groups of 8, or all 16 at the same time. + // More parallelization leads to higher memory usage. 
Therefore, we must find a way + // to maximize parallelization while ensuring that there is still enough memory remaining on + // the GPU. + // + let mut parallelism = 16; + + while parallelism > 0 { + // `num_aes_inputs` refers to the number of 128-bit ciphertexts that AES will produce. + // + let aes_encrypt_size = + self.get_aes_encrypt_size_on_gpu(num_aes_inputs, parallelism, streams); + + if check_valid_cuda_malloc(aes_encrypt_size, streams.gpu_indexes[0]) { + let round_keys = self.key_expansion_256(key, streams); + let res = self.aes_256_encrypt( + iv, + &round_keys, + start_counter, + num_aes_inputs, + parallelism, + streams, + ); + return res; + } + parallelism /= 2; + } + + panic!("Failed to allocate GPU memory for AES, even with the lowest parallelism setting."); + } + + pub fn aes_ctr_256_with_fixed_parallelism( + &self, + key: &CudaUnsignedRadixCiphertext, + iv: &CudaUnsignedRadixCiphertext, + start_counter: u128, + num_aes_inputs: usize, + sbox_parallelism: usize, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext { + assert!( + [1, 2, 4, 8, 16].contains(&sbox_parallelism), + "Invalid S-Box parallelism: must be one of [1, 2, 4, 8, 16], got {sbox_parallelism}" + ); + + let gpu_index = streams.gpu_indexes[0]; + + let key_expansion_size = self.get_key_expansion_256_size_on_gpu(streams); + check_valid_cuda_malloc_assert_oom(key_expansion_size, gpu_index); + + let aes_encrypt_size = + self.get_aes_encrypt_size_on_gpu(num_aes_inputs, sbox_parallelism, streams); + check_valid_cuda_malloc_assert_oom(aes_encrypt_size, gpu_index); + + let round_keys = self.key_expansion_256(key, streams); + self.aes_256_encrypt( + iv, + &round_keys, + start_counter, + num_aes_inputs, + sbox_parallelism, + streams, + ) + } + + pub fn aes_256_encrypt( + &self, + iv: &CudaUnsignedRadixCiphertext, + round_keys: &CudaUnsignedRadixCiphertext, + start_counter: u128, + num_aes_inputs: usize, + sbox_parallelism: usize, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext { + let mut result: CudaUnsignedRadixCiphertext = + self.create_trivial_zero_radix(num_aes_inputs * 128, streams); + + let num_round_key_blocks = 15 * NUM_BITS; + + assert_eq!( + iv.as_ref().d_blocks.lwe_ciphertext_count().0, + NUM_BITS, + "AES IV must contain {NUM_BITS} encrypted bits, but contains {}", + iv.as_ref().d_blocks.lwe_ciphertext_count().0 + ); + assert_eq!( + round_keys.as_ref().d_blocks.lwe_ciphertext_count().0, + num_round_key_blocks, + "AES round_keys must contain {num_round_key_blocks} encrypted bits, but contains {}", + round_keys.as_ref().d_blocks.lwe_ciphertext_count().0 + ); + assert_eq!( + result.as_ref().d_blocks.lwe_ciphertext_count().0, + num_aes_inputs * 128, + "AES result must contain {} encrypted bits for {num_aes_inputs} blocks, but contains {}", + num_aes_inputs * 128, + result.as_ref().d_blocks.lwe_ciphertext_count().0 + ); + + unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_aes_ctr_256_encrypt( + streams, + result.as_mut(), + iv.as_ref(), + round_keys.as_ref(), + start_counter, + num_aes_inputs as u32, + sbox_parallelism as u32, + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + d_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + LweBskGroupingFactor(0), + PBSType::Classical, + 
d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_aes_ctr_256_encrypt( + streams, + result.as_mut(), + iv.as_ref(), + round_keys.as_ref(), + start_counter, + num_aes_inputs as u32, + sbox_parallelism as u32, + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + d_multibit_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + d_multibit_bsk.grouping_factor, + PBSType::MultiBit, + None, + ); + } + } + } + result + } + + pub fn key_expansion_256( + &self, + key: &CudaUnsignedRadixCiphertext, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext { + let num_round_keys = 15; + let input_key_bits = 256; + let round_key_bits = 128; + + let mut expanded_keys: CudaUnsignedRadixCiphertext = + self.create_trivial_zero_radix(num_round_keys * round_key_bits, streams); + + assert_eq!( + key.as_ref().d_blocks.lwe_ciphertext_count().0, + input_key_bits, + "Input key must contain {} encrypted bits, but contains {}", + input_key_bits, + key.as_ref().d_blocks.lwe_ciphertext_count().0 + ); + + unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_aes_key_expansion_256( + streams, + expanded_keys.as_mut(), + key.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + d_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + LweBskGroupingFactor(0), + PBSType::Classical, + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_aes_key_expansion_256( + streams, + expanded_keys.as_mut(), + key.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + d_multibit_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + d_multibit_bsk.grouping_factor, + PBSType::MultiBit, + None, + ); + } + } + } + expanded_keys + } + + pub fn get_key_expansion_256_size_on_gpu(&self, streams: &CudaStreams) -> u64 { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_get_aes_key_expansion_256_size_on_gpu( + streams, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + d_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + LweBskGroupingFactor(0), + PBSType::Classical, + d_bsk.ms_noise_reduction_configuration.as_ref(), + ) + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_get_aes_key_expansion_256_size_on_gpu( + streams, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + d_multibit_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + 
self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + d_multibit_bsk.grouping_factor, + PBSType::MultiBit, + None, + ) + } + } + } +} diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs index a66777dcf..b4ccf6958 100644 --- a/tfhe/src/integer/gpu/server_key/radix/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs @@ -59,6 +59,7 @@ mod vector_comparisons; mod vector_find; mod aes; +mod aes256; #[cfg(test)] mod tests_long_run; #[cfg(test)] diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs index 755cfa6ff..ed7fbe931 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs @@ -1,5 +1,6 @@ pub(crate) mod test_add; pub(crate) mod test_aes; +pub(crate) mod test_aes256; pub(crate) mod test_bitwise_op; pub(crate) mod test_cmux; pub(crate) mod test_comparison; diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_aes.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_aes.rs index 6d7e72486..804226e2a 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_aes.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_aes.rs @@ -7,23 +7,23 @@ use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{ aes_fixed_parallelism_2_inputs_test, }; use crate::shortint::parameters::{ - TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, - PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, }; create_gpu_parameterized_test!(integer_aes_fixed_parallelism_1_input { - PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, - PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 }); create_gpu_parameterized_test!(integer_aes_fixed_parallelism_2_inputs { - PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, - PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 }); create_gpu_parameterized_test!(integer_aes_dynamic_parallelism_many_inputs { - PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, - PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 }); // The following two tests are referred to as "fixed_parallelism" because the objective is to test diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_aes256.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_aes256.rs new file mode 100644 index 000000000..c75880d91 --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_aes256.rs @@ -0,0 +1,59 @@ +use crate::integer::gpu::server_key::radix::tests_unsigned::{ + create_gpu_parameterized_test, GpuFunctionExecutor, +}; +use crate::integer::gpu::CudaServerKey; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{ + aes_256_dynamic_parallelism_many_inputs_test, 
aes_256_fixed_parallelism_1_input_test,
+    aes_256_fixed_parallelism_2_inputs_test,
+};
+use crate::shortint::parameters::{
+    TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
+    PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
+};
+
+create_gpu_parameterized_test!(integer_aes_256_fixed_parallelism_1_input {
+    PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
+    PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
+});
+
+create_gpu_parameterized_test!(integer_aes_256_fixed_parallelism_2_inputs {
+    PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
+    PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
+});
+
+create_gpu_parameterized_test!(integer_aes_256_dynamic_parallelism_many_inputs {
+    PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
+    PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
+});
+
+// The following two tests are referred to as "fixed_parallelism" because the objective is to test
+// AES, in CTR mode, across all possible parallelizations of the S-box. The S-box must process 16
+// bytes; the parallelization refers to the number of bytes it will process in parallel in one call:
+// 1, 2, 4, 8, or 16.
+//
+fn integer_aes_256_fixed_parallelism_1_input<P>(param: P)
+where
+    P: Into<TestParameters>,
+{
+    let executor = GpuFunctionExecutor::new(&CudaServerKey::aes_ctr_256_with_fixed_parallelism);
+    aes_256_fixed_parallelism_1_input_test(param, executor);
+}
+
+fn integer_aes_256_fixed_parallelism_2_inputs<P>(param: P)
+where
+    P: Into<TestParameters>,
+{
+    let executor = GpuFunctionExecutor::new(&CudaServerKey::aes_ctr_256_with_fixed_parallelism);
+    aes_256_fixed_parallelism_2_inputs_test(param, executor);
+}
+
+// The test referred to as "dynamic_parallelism" will seek the maximum s-box parallelization that
+// the machine can support.
+//
+fn integer_aes_256_dynamic_parallelism_many_inputs<P>(param: P)
+where
+    P: Into<TestParameters>,
+{
+    let executor = GpuFunctionExecutor::new(&CudaServerKey::aes_ctr_256);
+    aes_256_dynamic_parallelism_many_inputs_test(param, executor);
+}
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs
index b597c076b..36b3bb3f5 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs
@@ -48,6 +48,11 @@ pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_
     aes_fixed_parallelism_2_inputs_test,
 };
 #[cfg(feature = "gpu")]
+pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_aes256::{
+    aes_256_dynamic_parallelism_many_inputs_test, aes_256_fixed_parallelism_1_input_test,
+    aes_256_fixed_parallelism_2_inputs_test,
+};
+#[cfg(feature = "gpu")]
 pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_neg::default_neg_test;
 pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_neg::unchecked_neg_test;
 #[cfg(feature = "gpu")]
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs
index 4b76a7fae..176686baa 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs
@@ -1,6 +1,7 @@
 mod modulus_switch_compression;
 pub(crate) mod test_add;
 pub(crate) mod test_aes;
+pub(crate) mod test_aes256;
 pub(crate) mod test_bitwise_op;
 mod test_block_rotate;
 mod test_block_shift;
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_aes.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_aes.rs
index 86f3bd731..a6c38f468 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_aes.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_aes.rs
@@ -216,7 +216,7 @@ where
     let ctxt_key = cks.encrypt_u128_for_aes_ctr(key);
     let ctxt_iv = cks.encrypt_u128_for_aes_ctr(iv);
 
-    for num_aes_inputs in [4, 8, 16, 32] {
+    for num_aes_inputs in [4, 8, 16] {
         let plain_results = plain_aes_ctr(num_aes_inputs, iv, key);
         let encrypted_result = executor.execute((&ctxt_key, &ctxt_iv, 0, num_aes_inputs));
         let fhe_results = cks.decrypt_u128_from_aes_ctr(&encrypted_result, num_aes_inputs);
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_aes256.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_aes256.rs
new file mode 100644
index 000000000..bd9cb7b89
--- /dev/null
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_aes256.rs
@@ -0,0 +1,242 @@
+#![cfg(feature = "gpu")]
+
+use crate::integer::keycache::KEY_CACHE;
+use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
+use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey};
+use crate::shortint::parameters::TestParameters;
+use std::sync::Arc;
+
+const S_BOX: [u8; 256] = [
+    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
+    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
+    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
+    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
+    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
+    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
+    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
+    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
+    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
+    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
+    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
+    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
+    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
+    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
+    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
+    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
+];
+
+// Cleartext reference implementation of the AES-256 key schedule (FIPS-197): the eight
+// 32-bit key words expand to 60 words, i.e. 15 round keys of 128 bits each.
+fn plain_key_expansion(key_hi: u128, key_lo: u128) -> Vec<u128> {
+    const RCON: [u32; 10] = [
+        0x01000000, 0x02000000, 0x04000000, 0x08000000, 0x10000000, 0x20000000, 0x40000000,
+        0x80000000, 0x1B000000, 0x36000000,
+    ];
+    let mut words = [0u32; 60];
+    for (i, word) in words.iter_mut().enumerate().take(4) {
+        *word = (key_hi >> (96 - (i * 32))) as u32;
+    }
+    for (i, word) in words.iter_mut().enumerate().skip(4).take(4) {
+        *word = (key_lo >> (96 - ((i - 4) * 32))) as u32;
+    }
+    for i in 8..60 {
+        let mut temp = words[i - 1];
+        if i % 8 == 0 {
+            // Every eighth word: RotWord, then SubWord, then XOR with the round constant.
+            temp = temp.rotate_left(8);
+            let mut sub_bytes = 0u32;
+            for j in 0..4 {
+                let byte = (temp >> (24 - j * 8)) as u8;
+                sub_bytes |= (S_BOX[byte as usize] as u32) << (24 - j * 8);
+            }
+            temp = sub_bytes ^ RCON[i / 8 - 1];
+        } else if i % 8 == 4 {
+            // AES-256 only: an extra SubWord halfway through each eight-word group.
+            let mut sub_bytes = 0u32;
+            for j in 0..4 {
+                let byte = (temp >> (24 - j * 8)) as u8;
+                sub_bytes |= (S_BOX[byte as usize] as u32) << (24 - j * 8);
+            }
+            temp = sub_bytes;
+        }
+        words[i] = words[i - 8] ^ temp;
+    }
+    words
+        .chunks_exact(4)
+        .map(|chunk| {
+            ((chunk[0] as u128) << 96)
+                | ((chunk[1] as u128) << 64)
+                | ((chunk[2] as u128) << 32)
+                | (chunk[3] as u128)
+        })
+        .collect()
+}
+fn sub_bytes(state: &mut [u8; 16]) {
+    for byte in state.iter_mut() {
+        *byte = S_BOX[*byte as usize];
+    }
+}
+// The 16-byte state is column-major: state[4 * c + r] holds AES state byte s[r][c].
+fn shift_rows(state: &mut [u8; 16]) {
+    let original = *state;
+    state[1] = original[5];
+    state[5] = original[9];
+    state[9] = original[13];
+    state[13] = original[1];
+    state[2] = original[10];
+    state[6] = original[14];
+    state[10] = original[2];
+    state[14] = original[6];
+    state[3] = original[15];
+    state[7] = original[3];
+    state[11] = original[7];
+    state[15] = original[11];
+}
+// Multiplication in GF(2^8) modulo the AES polynomial x^8 + x^4 + x^3 + x + 1 (0x1B).
+fn gmul(mut a: u8, mut b: u8) -> u8 {
+    let mut p = 0;
+    for _ in 0..8 {
+        if (b & 1) != 0 {
+            p ^= a;
+        }
+        let hi_bit_set = (a & 0x80) != 0;
+        a <<= 1;
+        if hi_bit_set {
+            a ^= 0x1B;
+        }
+        b >>= 1;
+    }
+    p
+}
+fn mix_columns(state: &mut [u8; 16]) {
+    let original = *state;
+    for i in 0..4 {
+        let col = i * 4;
+        state[col] = gmul(original[col], 2)
+            ^ gmul(original[col + 1], 3)
+            ^ original[col + 2]
+            ^ original[col + 3];
+        state[col + 1] = original[col]
+            ^ gmul(original[col + 1], 2)
+            ^ gmul(original[col + 2], 3)
+            ^ original[col + 3];
+        state[col + 2] = original[col]
+            ^ original[col + 1]
+            ^ gmul(original[col + 2], 2)
+            ^ gmul(original[col + 3], 3);
+        state[col + 3] = gmul(original[col], 3)
+            ^ original[col + 1]
+            ^ original[col + 2]
+            ^ gmul(original[col + 3], 2);
+    }
+}
+fn add_round_key(state: &mut [u8; 16], round_key: u128) {
+    let key_bytes = round_key.to_be_bytes();
+    for i in 0..16 {
+        state[i] ^= key_bytes[i];
+    }
+}
+// One AES-256 block encryption: 14 rounds, the final one without MixColumns.
+fn plain_aes_encrypt_block(block_bytes: &mut [u8; 16], expanded_keys: &[u128]) {
+    add_round_key(block_bytes, expanded_keys[0]);
+    for round_key in expanded_keys.iter().take(14).skip(1) {
+        sub_bytes(block_bytes);
+        shift_rows(block_bytes);
+        mix_columns(block_bytes);
+        add_round_key(block_bytes, *round_key);
+    }
+    sub_bytes(block_bytes);
+    shift_rows(block_bytes);
+    add_round_key(block_bytes, expanded_keys[14]);
+}
+// Returns the raw keystream blocks E_K(iv + i); CTR encryption proper would XOR
+// these with the plaintext.
+fn plain_aes_ctr(num_aes_inputs: usize, iv: u128, key_hi: u128, key_lo: u128) -> Vec<u128> {
+    let expanded_keys = plain_key_expansion(key_hi, key_lo);
+    let mut results = Vec::with_capacity(num_aes_inputs);
+    for i in 0..num_aes_inputs {
+        let counter_value = iv.wrapping_add(i as u128);
+        let mut block = counter_value.to_be_bytes();
+        plain_aes_encrypt_block(&mut block, &expanded_keys);
+        results.push(u128::from_be_bytes(block));
+    }
+    results
+}
+
+fn internal_aes_fixed_parallelism_test<P, E>(param: P, mut executor: E, num_aes_inputs: usize)
+where
+    P: Into<TestParameters>,
+    E: for<'a> FunctionExecutor<
+        (&'a RadixCiphertext, &'a RadixCiphertext, u128, usize, usize),
+        RadixCiphertext,
+    >,
+{
+    let param = param.into();
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, 1));
+    let sks = Arc::new(sks);
+    executor.setup(&cks, sks);
+
+    let key_hi: u128 = 0x603deb1015ca71be2b73aef0857d7781;
+    let key_lo: u128 = 0x1f352c073b6108d72d9810a30914dff4;
+    let iv: u128 = 0xf0f1f2f3f4f5f6f7f8f9fafbfcfdfeff;
+
+    let plain_results = plain_aes_ctr(num_aes_inputs, iv, key_hi, key_lo);
+
+    // The 256-bit key is encrypted as two u128 halves whose blocks are concatenated.
+    let ctxt_hi = cks.encrypt_u128_for_aes_ctr(key_hi);
+    let ctxt_lo = cks.encrypt_u128_for_aes_ctr(key_lo);
+    let mut key_blocks = ctxt_hi.blocks;
+    key_blocks.extend(ctxt_lo.blocks);
+    let ctxt_key = RadixCiphertext::from(key_blocks);
+
+    let ctxt_iv = cks.encrypt_u128_for_aes_ctr(iv);
+
+    for sbox_parallelism in [1, 2, 4, 8, 16] {
+        let encrypted_result =
+            executor.execute((&ctxt_key, &ctxt_iv, 0, num_aes_inputs, sbox_parallelism));
+        let fhe_results = cks.decrypt_u128_from_aes_ctr(&encrypted_result, num_aes_inputs);
+        assert_eq!(fhe_results, plain_results);
+    }
+}
+
+pub fn aes_256_fixed_parallelism_1_input_test<P, E>(param: P, executor: E)
+where
+    P: Into<TestParameters>,
+    E: for<'a> FunctionExecutor<
+        (&'a RadixCiphertext, &'a RadixCiphertext, u128, usize, usize),
+        RadixCiphertext,
+    >,
+{
+    internal_aes_fixed_parallelism_test(param, executor, 1);
+}
+
+pub fn aes_256_fixed_parallelism_2_inputs_test<P, E>(param: P, executor: E)
+where
+    P: Into<TestParameters>,
+    E: for<'a> FunctionExecutor<
+        (&'a RadixCiphertext, &'a RadixCiphertext, u128, usize, usize),
+        RadixCiphertext,
+    >,
+{
+    internal_aes_fixed_parallelism_test(param, executor, 2);
+}
+
+pub fn aes_256_dynamic_parallelism_many_inputs_test<P, E>(param: P, mut executor: E)
+where
+    P: Into<TestParameters>,
+    E: for<'a> FunctionExecutor<
+        (&'a RadixCiphertext, &'a RadixCiphertext, u128, usize),
+        RadixCiphertext,
+    >,
+{
+    let param = param.into();
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, 1));
+    let sks = Arc::new(sks);
+    executor.setup(&cks, sks);
+
+    let key_hi: u128 = 0x603deb1015ca71be2b73aef0857d7781;
+    let key_lo: u128 = 0x1f352c073b6108d72d9810a30914dff4;
+    let iv: u128 = 0xf0f1f2f3f4f5f6f7f8f9fafbfcfdfeff;
+
+    let ctxt_key = cks.encrypt_2u128_for_aes_ctr_256(key_hi, key_lo);
+    let ctxt_iv = cks.encrypt_u128_for_aes_ctr(iv);
+
+    for num_aes_inputs in [4, 8, 16] {
+        let plain_results = plain_aes_ctr(num_aes_inputs, iv, key_hi, key_lo);
+        let encrypted_result = executor.execute((&ctxt_key, &ctxt_iv, 0, num_aes_inputs));
+        let fhe_results = cks.decrypt_u128_from_aes_ctr(&encrypted_result, num_aes_inputs);
+        assert_eq!(fhe_results, plain_results);
+    }
+}
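
Reviewer note: a minimal end-to-end sketch of the new GPU path follows. The AES helpers
(encrypt_2u128_for_aes_ctr_256, encrypt_u128_for_aes_ctr, decrypt_u128_from_aes_ctr,
aes_ctr_256) are the ones added by this patch; the stream and key setup
(CudaStreams::new_multi_gpu, gen_keys_radix_gpu), the parameter choice, and the exact
argument order of aes_ctr_256 are assumptions inferred from the GPU test executors, not
verified against the final API.

    // Sketch only, not part of the patch. The aes_ctr_256 call shape mirrors the
    // (&RadixCiphertext, &RadixCiphertext, u128, usize) executor tuple used in the tests.
    use tfhe::core_crypto::gpu::CudaStreams;
    use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
    use tfhe::integer::gpu::gen_keys_radix_gpu;
    use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;

    fn main() {
        let streams = CudaStreams::new_multi_gpu();
        // One radix block per encrypted AES bit, as in RadixClientKey::from((cks, 1)) above.
        let (cks, sks) = gen_keys_radix_gpu(
            PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
            1,
            &streams,
        );

        let key_hi: u128 = 0x603deb1015ca71be2b73aef0857d7781;
        let key_lo: u128 = 0x1f352c073b6108d72d9810a30914dff4;
        let iv: u128 = 0xf0f1f2f3f4f5f6f7f8f9fafbfcfdfeff;

        // 256 encrypted key bits and 128 encrypted IV bits, moved to the GPU.
        let ct_key = cks.encrypt_2u128_for_aes_ctr_256(key_hi, key_lo);
        let ct_iv = cks.encrypt_u128_for_aes_ctr(iv);
        let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
        let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);

        // Homomorphic AES-256-CTR keystream for 4 blocks, starting at counter 0
        // (assumed signature, streams passed last as in key_expansion_256 above).
        let d_out = sks.aes_ctr_256(&d_key, &d_iv, 0u128, 4usize, &streams);

        let out = d_out.to_radix_ciphertext(&streams);
        let keystream = cks.decrypt_u128_from_aes_ctr(&out, 4);
        println!("{keystream:x?}");
    }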
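
Also note that both the plain reference and the decrypted FHE results above are keystream
blocks, not ciphertexts. Turning them into actual CTR encryption (or decryption, which is
the same operation) is a plain XOR; a one-line sketch with no assumptions beyond the u128
block layout used in the tests:

    /// ciphertext[i] = plaintext[i] XOR E_K(iv + i); applying it twice round-trips.
    fn ctr_xor(keystream: &[u128], data: &[u128]) -> Vec<u128> {
        keystream.iter().zip(data).map(|(ks, d)| ks ^ d).collect()
    }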
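
Finally, a quick sanity check one could drop next to the reference helpers: gmul matches
the worked multiplication example from FIPS-197 section 4.2, {57} * {83} = {c1} in GF(2^8).

    #[test]
    fn gmul_matches_fips_197_example() {
        assert_eq!(gmul(0x57, 0x83), 0xc1);
    }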