From f0f3dd76eba0bf455d462d77f56ce59617570696 Mon Sep 17 00:00:00 2001 From: Enzo Di Maria Date: Tue, 26 Aug 2025 16:07:54 +0200 Subject: [PATCH] feat(gpu): aes 128 --- backends/tfhe-cuda-backend/build.rs | 1 + .../tfhe-cuda-backend/cuda/include/aes/aes.h | 44 + .../cuda/include/aes/aes_utilities.h | 440 ++++++ .../tfhe-cuda-backend/cuda/src/aes/aes.cu | 88 ++ .../tfhe-cuda-backend/cuda/src/aes/aes.cuh | 1254 +++++++++++++++++ backends/tfhe-cuda-backend/src/bindings.rs | 72 + backends/tfhe-cuda-backend/wrapper.h | 1 + tfhe-benchmark/benches/integer/aes.rs | 154 ++ tfhe-benchmark/benches/integer/bench.rs | 3 + tfhe/src/integer/gpu/mod.rs | 286 ++++ tfhe/src/integer/gpu/server_key/radix/aes.rs | 465 ++++++ tfhe/src/integer/gpu/server_key/radix/mod.rs | 1 + .../server_key/radix/tests_unsigned/mod.rs | 93 ++ .../radix/tests_unsigned/test_aes.rs | 59 + .../radix_parallel/tests_cases_unsigned.rs | 5 + .../radix_parallel/tests_unsigned/mod.rs | 1 + .../radix_parallel/tests_unsigned/test_aes.rs | 225 +++ 17 files changed, 3192 insertions(+) create mode 100644 backends/tfhe-cuda-backend/cuda/include/aes/aes.h create mode 100644 backends/tfhe-cuda-backend/cuda/include/aes/aes_utilities.h create mode 100644 backends/tfhe-cuda-backend/cuda/src/aes/aes.cu create mode 100644 backends/tfhe-cuda-backend/cuda/src/aes/aes.cuh create mode 100644 tfhe-benchmark/benches/integer/aes.rs create mode 100644 tfhe/src/integer/gpu/server_key/radix/aes.rs create mode 100644 tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_aes.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_aes.rs diff --git a/backends/tfhe-cuda-backend/build.rs b/backends/tfhe-cuda-backend/build.rs index eefc7284b..bdcb3dcaa 100644 --- a/backends/tfhe-cuda-backend/build.rs +++ b/backends/tfhe-cuda-backend/build.rs @@ -84,6 +84,7 @@ fn main() { "cuda/include/ciphertext.h", "cuda/include/integer/compression/compression.h", "cuda/include/integer/integer.h", + "cuda/include/aes/aes.h", "cuda/include/zk/zk.h", "cuda/include/keyswitch/keyswitch.h", "cuda/include/keyswitch/ks_enums.h", diff --git a/backends/tfhe-cuda-backend/cuda/include/aes/aes.h b/backends/tfhe-cuda-backend/cuda/include/aes/aes.h new file mode 100644 index 000000000..a491e73fc --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/include/aes/aes.h @@ -0,0 +1,44 @@ +#ifndef AES_H +#define AES_H +#include "../integer/integer.h" + +extern "C" { +uint64_t scratch_cuda_integer_aes_encrypt_64( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_aes_inputs, + uint32_t sbox_parallelism); + +void cuda_integer_aes_ctr_encrypt_64(CudaStreamsFFI streams, + CudaRadixCiphertextFFI *output, + CudaRadixCiphertextFFI const *iv, + CudaRadixCiphertextFFI const *round_keys, + const uint64_t *counter_bits_le_all_blocks, + uint32_t num_aes_inputs, int8_t *mem_ptr, + void *const *bsks, void *const *ksks); + +void cleanup_cuda_integer_aes_encrypt_64(CudaStreamsFFI streams, + int8_t **mem_ptr_void); + +uint64_t scratch_cuda_integer_key_expansion_64( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t 
pbs_base_log, + uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type); + +void cuda_integer_key_expansion_64(CudaStreamsFFI streams, + CudaRadixCiphertextFFI *expanded_keys, + CudaRadixCiphertextFFI const *key, + int8_t *mem_ptr, void *const *bsks, + void *const *ksks); + +void cleanup_cuda_integer_key_expansion_64(CudaStreamsFFI streams, + int8_t **mem_ptr_void); +} + +#endif diff --git a/backends/tfhe-cuda-backend/cuda/include/aes/aes_utilities.h b/backends/tfhe-cuda-backend/cuda/include/aes/aes_utilities.h new file mode 100644 index 000000000..2ca5e9ed8 --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/include/aes/aes_utilities.h @@ -0,0 +1,440 @@ +#ifndef AES_UTILITIES +#define AES_UTILITIES +#include "../integer/integer_utilities.h" + +/** + * This structure holds pre-computed LUTs for essential bitwise operations + * required by the homomorphic AES circuit. Pre-computing these tables allows + * for efficient application of non-linear functions like AND during the PBS + * process. It includes LUTs for: + * - AND: for the non-linear part of the S-Box. + * - FLUSH: to clear carry bits and isolate the message bit (x -> x & 1). + * - CARRY: to extract the carry bit for additions (x -> (x >> 1) & 1). + */ +template struct int_aes_lut_buffers { + int_radix_lut *and_lut; + int_radix_lut *flush_lut; + int_radix_lut *carry_lut; + + int_aes_lut_buffers(CudaStreams streams, const int_radix_params ¶ms, + bool allocate_gpu_memory, uint32_t num_aes_inputs, + uint32_t sbox_parallelism, uint64_t &size_tracker) { + + constexpr uint32_t AES_STATE_BITS = 128; + constexpr uint32_t SBOX_MAX_AND_GATES = 18; + + this->and_lut = new int_radix_lut( + streams, params, 1, + SBOX_MAX_AND_GATES * num_aes_inputs * sbox_parallelism, + allocate_gpu_memory, size_tracker); + std::function and_lambda = + [](Torus a, Torus b) -> Torus { return a & b; }; + generate_device_accumulator_bivariate( + streams.stream(0), streams.gpu_index(0), this->and_lut->get_lut(0, 0), + this->and_lut->get_degree(0), this->and_lut->get_max_degree(0), + params.glwe_dimension, params.polynomial_size, params.message_modulus, + params.carry_modulus, and_lambda, allocate_gpu_memory); + auto active_streams_and_lut = streams.active_gpu_subset( + SBOX_MAX_AND_GATES * num_aes_inputs * sbox_parallelism); + this->and_lut->broadcast_lut(active_streams_and_lut); + + this->flush_lut = new int_radix_lut( + streams, params, 1, AES_STATE_BITS * num_aes_inputs, + allocate_gpu_memory, size_tracker); + std::function flush_lambda = [](Torus x) -> Torus { + return x & 1; + }; + generate_device_accumulator( + streams.stream(0), streams.gpu_index(0), this->flush_lut->get_lut(0, 0), + this->flush_lut->get_degree(0), this->flush_lut->get_max_degree(0), + params.glwe_dimension, params.polynomial_size, params.message_modulus, + params.carry_modulus, flush_lambda, allocate_gpu_memory); + auto active_streams_flush_lut = + streams.active_gpu_subset(AES_STATE_BITS * num_aes_inputs); + this->flush_lut->broadcast_lut(active_streams_flush_lut); + + this->carry_lut = new int_radix_lut( + streams, params, 1, num_aes_inputs, allocate_gpu_memory, size_tracker); + std::function carry_lambda = [](Torus x) -> Torus { + return (x >> 1) & 1; + }; + generate_device_accumulator( + streams.stream(0), streams.gpu_index(0), this->carry_lut->get_lut(0, 0), + this->carry_lut->get_degree(0), this->carry_lut->get_max_degree(0), + params.glwe_dimension, params.polynomial_size, 
params.message_modulus, + params.carry_modulus, carry_lambda, allocate_gpu_memory); + auto active_streams_carry_lut = streams.active_gpu_subset(num_aes_inputs); + this->carry_lut->broadcast_lut(active_streams_carry_lut); + } + + void release(CudaStreams streams) { + this->and_lut->release(streams); + delete this->and_lut; + this->and_lut = nullptr; + + this->flush_lut->release(streams); + delete this->flush_lut; + this->flush_lut = nullptr; + + this->carry_lut->release(streams); + delete this->carry_lut; + this->carry_lut = nullptr; + } +}; + +/** + * The operations within an AES round, particularly MixColumns, require + * intermediate storage for calculations. These buffers are designed to hold + * temporary values like copies of columns or the results of multiplications, + * avoiding overwriting data that is still needed in the same round. + */ +template struct int_aes_round_workspaces { + CudaRadixCiphertextFFI *mix_columns_col_copy_buffer; + CudaRadixCiphertextFFI *mix_columns_mul_workspace_buffer; + CudaRadixCiphertextFFI *vec_tmp_bit_buffer; + + int_aes_round_workspaces(CudaStreams streams, const int_radix_params ¶ms, + bool allocate_gpu_memory, uint32_t num_aes_inputs, + uint64_t &size_tracker) { + + constexpr uint32_t BITS_PER_BYTE = 8; + constexpr uint32_t BYTES_PER_COLUMN = 4; + constexpr uint32_t BITS_PER_COLUMN = BITS_PER_BYTE * BYTES_PER_COLUMN; + constexpr uint32_t MIX_COLUMNS_MUL_WORKSPACE_BYTES = BYTES_PER_COLUMN + 1; + + this->mix_columns_col_copy_buffer = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), + this->mix_columns_col_copy_buffer, BITS_PER_COLUMN * num_aes_inputs, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->mix_columns_mul_workspace_buffer = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), + this->mix_columns_mul_workspace_buffer, + MIX_COLUMNS_MUL_WORKSPACE_BYTES * BITS_PER_BYTE * num_aes_inputs, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->vec_tmp_bit_buffer = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->vec_tmp_bit_buffer, + num_aes_inputs, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + } + + void release(CudaStreams streams, bool allocate_gpu_memory) { + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->mix_columns_col_copy_buffer, + allocate_gpu_memory); + delete this->mix_columns_col_copy_buffer; + this->mix_columns_col_copy_buffer = nullptr; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->mix_columns_mul_workspace_buffer, + allocate_gpu_memory); + delete this->mix_columns_mul_workspace_buffer; + this->mix_columns_mul_workspace_buffer = nullptr; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->vec_tmp_bit_buffer, + allocate_gpu_memory); + delete this->vec_tmp_bit_buffer; + this->vec_tmp_bit_buffer = nullptr; + } +}; + +/** + * In CTR mode, a counter is homomorphically added to the encrypted IV. This + * structure holds the necessary buffers for this 128-bit ripple-carry + * addition, such as the buffer for the propagating carry bit + * (`vec_tmp_carry_buffer`) across the addition chain. 
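+ * The plaintext counter bits are staged, one bit position at a time, in
+ * `h_counter_bits_buffer` on the host, copied into `d_counter_bits_buffer`
+ * on the device and trivially encrypted before being added to the matching
+ * IV bit.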
+ */ +template struct int_aes_counter_workspaces { + CudaRadixCiphertextFFI *vec_tmp_carry_buffer; + CudaRadixCiphertextFFI *vec_tmp_sum_buffer; + CudaRadixCiphertextFFI *vec_trivial_b_bits_buffer; + Torus *h_counter_bits_buffer; + Torus *d_counter_bits_buffer; + + int_aes_counter_workspaces(CudaStreams streams, + const int_radix_params ¶ms, + bool allocate_gpu_memory, uint32_t num_aes_inputs, + uint64_t &size_tracker) { + + this->vec_tmp_carry_buffer = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->vec_tmp_carry_buffer, + num_aes_inputs, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + this->vec_tmp_sum_buffer = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->vec_tmp_sum_buffer, + num_aes_inputs, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + this->vec_trivial_b_bits_buffer = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), + this->vec_trivial_b_bits_buffer, num_aes_inputs, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->h_counter_bits_buffer = + (Torus *)malloc(num_aes_inputs * sizeof(Torus)); + size_tracker += num_aes_inputs * sizeof(Torus); + this->d_counter_bits_buffer = (Torus *)cuda_malloc_with_size_tracking_async( + num_aes_inputs * sizeof(Torus), streams.stream(0), streams.gpu_index(0), + size_tracker, allocate_gpu_memory); + } + + void release(CudaStreams streams, bool allocate_gpu_memory) { + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->vec_tmp_carry_buffer, + allocate_gpu_memory); + delete this->vec_tmp_carry_buffer; + this->vec_tmp_carry_buffer = nullptr; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->vec_tmp_sum_buffer, + allocate_gpu_memory); + delete this->vec_tmp_sum_buffer; + this->vec_tmp_sum_buffer = nullptr; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->vec_trivial_b_bits_buffer, + allocate_gpu_memory); + delete this->vec_trivial_b_bits_buffer; + this->vec_trivial_b_bits_buffer = nullptr; + + free(this->h_counter_bits_buffer); + if (allocate_gpu_memory) { + cuda_drop_async(this->d_counter_bits_buffer, streams.stream(0), + streams.gpu_index(0)); + streams.synchronize(); + } + } +}; + +/** + * This structure allocates the most significant memory blocks: + * - `sbox_internal_workspace`: A large workspace for the complex, parallel + * evaluation of the S-Box circuit. + * - `main_bitsliced_states_buffer`: Holds the entire set of AES states in a + * bitsliced layout, which is optimal for parallel bitwise operations on the + * GPU. + * - Other buffers are used for data layout transformations (transposition) and + * for batching small operations into larger, more efficient launches. 
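+ * In particular, `batch_processing_buffer` is sized to hold the gathered
+ * lhs/rhs operands and the outputs of up to SBOX_MAX_AND_GATES batched AND
+ * gates per S-Box instance, while `tmp_tiled_key_buffer` holds the round key
+ * tiled once per AES input before it is transposed to the bitsliced layout.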
+ */ +template struct int_aes_main_workspaces { + CudaRadixCiphertextFFI *sbox_internal_workspace; + CudaRadixCiphertextFFI *initial_states_and_jit_key_workspace; + CudaRadixCiphertextFFI *main_bitsliced_states_buffer; + CudaRadixCiphertextFFI *tmp_tiled_key_buffer; + CudaRadixCiphertextFFI *batch_processing_buffer; + + int_aes_main_workspaces(CudaStreams streams, const int_radix_params ¶ms, + bool allocate_gpu_memory, uint32_t num_aes_inputs, + uint32_t sbox_parallelism, uint64_t &size_tracker) { + + constexpr uint32_t AES_STATE_BITS = 128; + constexpr uint32_t SBOX_MAX_AND_GATES = 18; + constexpr uint32_t BATCH_BUFFER_OPERANDS = 3; + + this->sbox_internal_workspace = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->sbox_internal_workspace, + num_aes_inputs * AES_STATE_BITS * sbox_parallelism, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->initial_states_and_jit_key_workspace = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), + this->initial_states_and_jit_key_workspace, + num_aes_inputs * AES_STATE_BITS, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + this->main_bitsliced_states_buffer = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), + this->main_bitsliced_states_buffer, num_aes_inputs * AES_STATE_BITS, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->tmp_tiled_key_buffer = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->tmp_tiled_key_buffer, + num_aes_inputs * AES_STATE_BITS, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + this->batch_processing_buffer = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->batch_processing_buffer, + num_aes_inputs * SBOX_MAX_AND_GATES * BATCH_BUFFER_OPERANDS * + sbox_parallelism, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + } + + void release(CudaStreams streams, bool allocate_gpu_memory) { + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->sbox_internal_workspace, + allocate_gpu_memory); + delete this->sbox_internal_workspace; + this->sbox_internal_workspace = nullptr; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->initial_states_and_jit_key_workspace, + allocate_gpu_memory); + delete this->initial_states_and_jit_key_workspace; + this->initial_states_and_jit_key_workspace = nullptr; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->main_bitsliced_states_buffer, + allocate_gpu_memory); + delete this->main_bitsliced_states_buffer; + this->main_bitsliced_states_buffer = nullptr; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->tmp_tiled_key_buffer, + allocate_gpu_memory); + delete this->tmp_tiled_key_buffer; + this->tmp_tiled_key_buffer = nullptr; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->batch_processing_buffer, + allocate_gpu_memory); + delete this->batch_processing_buffer; + this->batch_processing_buffer = nullptr; + } +}; + +/** + * This structure acts as a container, holding instances of all the other buffer + * management structs. 
It provides a + * single object to manage the entire lifecycle of memory needed for a complete + * AES-CTR encryption operation. + */ +template struct int_aes_encrypt_buffer { + int_radix_params params; + bool allocate_gpu_memory; + uint32_t num_aes_inputs; + uint32_t sbox_parallel_instances; + + int_aes_lut_buffers *luts; + int_aes_round_workspaces *round_workspaces; + int_aes_counter_workspaces *counter_workspaces; + int_aes_main_workspaces *main_workspaces; + + int_aes_encrypt_buffer(CudaStreams streams, const int_radix_params ¶ms, + bool allocate_gpu_memory, uint32_t num_aes_inputs, + uint32_t sbox_parallelism, uint64_t &size_tracker) { + + PANIC_IF_FALSE(num_aes_inputs >= 1, + "num_aes_inputs should be greater or equal to 1"); + + this->params = params; + this->allocate_gpu_memory = allocate_gpu_memory; + this->num_aes_inputs = num_aes_inputs; + this->sbox_parallel_instances = sbox_parallelism; + + this->luts = new int_aes_lut_buffers( + streams, params, allocate_gpu_memory, num_aes_inputs, sbox_parallelism, + size_tracker); + + this->round_workspaces = new int_aes_round_workspaces( + streams, params, allocate_gpu_memory, num_aes_inputs, size_tracker); + + this->counter_workspaces = new int_aes_counter_workspaces( + streams, params, allocate_gpu_memory, num_aes_inputs, size_tracker); + + this->main_workspaces = new int_aes_main_workspaces( + streams, params, allocate_gpu_memory, num_aes_inputs, sbox_parallelism, + size_tracker); + } + + void release(CudaStreams streams) { + luts->release(streams); + delete luts; + luts = nullptr; + + round_workspaces->release(streams, allocate_gpu_memory); + delete round_workspaces; + round_workspaces = nullptr; + + counter_workspaces->release(streams, allocate_gpu_memory); + delete counter_workspaces; + counter_workspaces = nullptr; + + main_workspaces->release(streams, allocate_gpu_memory); + delete main_workspaces; + main_workspaces = nullptr; + } +}; + +/** + * This structure holds the buffer for the 44 words of the expanded key + * and temporary storage for word manipulations. + * It contains its own instance of `int_aes_encrypt_buffer` because the + * key expansion algorithm itself requires using the S-Box. + * This separation ensures that memory for key expansion can be allocated and + * freed independently of the main encryption process. 
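+ * The 44 words correspond to the 4 words of the original key plus 4 words
+ * for each of the 10 round keys; the internal encrypt buffer is created for
+ * a single AES input with an S-Box parallelism of 4.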
+ */ +template struct int_key_expansion_buffer { + int_radix_params params; + bool allocate_gpu_memory; + + CudaRadixCiphertextFFI *words_buffer; + + CudaRadixCiphertextFFI *tmp_word_buffer; + CudaRadixCiphertextFFI *tmp_rotated_word_buffer; + + int_aes_encrypt_buffer *aes_encrypt_buffer; + + int_key_expansion_buffer(CudaStreams streams, const int_radix_params ¶ms, + bool allocate_gpu_memory, uint64_t &size_tracker) { + this->params = params; + this->allocate_gpu_memory = allocate_gpu_memory; + + constexpr uint32_t TOTAL_WORDS = 44; + constexpr uint32_t BITS_PER_WORD = 32; + constexpr uint32_t TOTAL_BITS = TOTAL_WORDS * BITS_PER_WORD; + + this->words_buffer = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->words_buffer, TOTAL_BITS, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->tmp_word_buffer = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->tmp_word_buffer, + BITS_PER_WORD, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + this->tmp_rotated_word_buffer = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->tmp_rotated_word_buffer, + BITS_PER_WORD, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + this->aes_encrypt_buffer = new int_aes_encrypt_buffer( + streams, params, allocate_gpu_memory, 1, 4, size_tracker); + } + + void release(CudaStreams streams) { + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->words_buffer, allocate_gpu_memory); + delete this->words_buffer; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->tmp_word_buffer, allocate_gpu_memory); + delete this->tmp_word_buffer; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->tmp_rotated_word_buffer, + allocate_gpu_memory); + delete this->tmp_rotated_word_buffer; + + this->aes_encrypt_buffer->release(streams); + delete this->aes_encrypt_buffer; + } +}; + +#endif diff --git a/backends/tfhe-cuda-backend/cuda/src/aes/aes.cu b/backends/tfhe-cuda-backend/cuda/src/aes/aes.cu new file mode 100644 index 000000000..b64ba4557 --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/src/aes/aes.cu @@ -0,0 +1,88 @@ +#include "../../include/aes/aes.h" +#include "aes.cuh" + +uint64_t scratch_cuda_integer_aes_encrypt_64( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_aes_inputs, + uint32_t sbox_parallelism) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + glwe_dimension * polynomial_size, lwe_dimension, + ks_level, ks_base_log, pbs_level, pbs_base_log, + grouping_factor, message_modulus, carry_modulus, + noise_reduction_type); + + return scratch_cuda_integer_aes_encrypt( + CudaStreams(streams), (int_aes_encrypt_buffer **)mem_ptr, + params, allocate_gpu_memory, num_aes_inputs, sbox_parallelism); +} + +void cuda_integer_aes_ctr_encrypt_64(CudaStreamsFFI streams, + CudaRadixCiphertextFFI *output, + CudaRadixCiphertextFFI const *iv, + CudaRadixCiphertextFFI const *round_keys, + const uint64_t *counter_bits_le_all_blocks, + uint32_t 
num_aes_inputs, int8_t *mem_ptr, + void *const *bsks, void *const *ksks) { + + host_integer_aes_ctr_encrypt( + CudaStreams(streams), output, iv, round_keys, counter_bits_le_all_blocks, + num_aes_inputs, (int_aes_encrypt_buffer *)mem_ptr, bsks, + (uint64_t **)ksks); +} + +void cleanup_cuda_integer_aes_encrypt_64(CudaStreamsFFI streams, + int8_t **mem_ptr_void) { + + int_aes_encrypt_buffer *mem_ptr = + (int_aes_encrypt_buffer *)(*mem_ptr_void); + + mem_ptr->release(CudaStreams(streams)); + + delete mem_ptr; + *mem_ptr_void = nullptr; +} + +uint64_t scratch_cuda_integer_key_expansion_64( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + glwe_dimension * polynomial_size, lwe_dimension, + ks_level, ks_base_log, pbs_level, pbs_base_log, + grouping_factor, message_modulus, carry_modulus, + noise_reduction_type); + + return scratch_cuda_integer_key_expansion( + CudaStreams(streams), (int_key_expansion_buffer **)mem_ptr, + params, allocate_gpu_memory); +} + +void cuda_integer_key_expansion_64(CudaStreamsFFI streams, + CudaRadixCiphertextFFI *expanded_keys, + CudaRadixCiphertextFFI const *key, + int8_t *mem_ptr, void *const *bsks, + void *const *ksks) { + + host_integer_key_expansion( + CudaStreams(streams), expanded_keys, key, + (int_key_expansion_buffer *)mem_ptr, bsks, (uint64_t **)ksks); +} + +void cleanup_cuda_integer_key_expansion_64(CudaStreamsFFI streams, + int8_t **mem_ptr_void) { + int_key_expansion_buffer *mem_ptr = + (int_key_expansion_buffer *)(*mem_ptr_void); + + mem_ptr->release(CudaStreams(streams)); + delete mem_ptr; + *mem_ptr_void = nullptr; +} diff --git a/backends/tfhe-cuda-backend/cuda/src/aes/aes.cuh b/backends/tfhe-cuda-backend/cuda/src/aes/aes.cuh new file mode 100644 index 000000000..0365acaa4 --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/src/aes/aes.cuh @@ -0,0 +1,1254 @@ +#ifndef AES_CUH +#define AES_CUH + +#include "../../include/aes/aes_utilities.h" +#include "../integer/integer.cuh" +#include "../integer/radix_ciphertext.cuh" +#include "../integer/scalar_addition.cuh" +#include "../linearalgebra/addition.cuh" + +template +uint64_t scratch_cuda_integer_aes_encrypt( + CudaStreams streams, int_aes_encrypt_buffer **mem_ptr, + int_radix_params params, bool allocate_gpu_memory, uint32_t num_aes_inputs, + uint32_t sbox_parallelism) { + + uint64_t size_tracker = 0; + *mem_ptr = new int_aes_encrypt_buffer( + streams, params, allocate_gpu_memory, num_aes_inputs, sbox_parallelism, + size_tracker); + return size_tracker; +} + +/** + * Transposes a collection of AES states from a block-oriented layout to a + * bit-sliced layout. This is a crucial data restructuring step for efficient + * homomorphic bitwise operations. + * + * Source (Block-oriented) Destination (Bitsliced) + * Block 0: [B0b0, B0b1, B0b2, ...] Slice 0: [B0b0, B1b0, B2b0, ...] + * Block 1: [B1b0, B1b1, B1b2, ...] -----> Slice 1: [B0b1, B1b1, B2b1, ...] + * Block 2: [B2b0, B2b1, B2b2, ...] Slice 2: [B0b2, B1b2, B2b2, ...] + * ... ... 
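+ *
+ * Concretely, the bit stored at index `block * block_size_bits + bit` in the
+ * source is copied to index `bit * num_aes_inputs + block` in the
+ * destination, one radix block at a time.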
+ * + */ +template +__host__ void +transpose_blocks_to_bitsliced(cudaStream_t stream, uint32_t gpu_index, + CudaRadixCiphertextFFI *dest_bitsliced, + const CudaRadixCiphertextFFI *source_blocks, + uint32_t num_aes_inputs, + uint32_t block_size_bits) { + + PANIC_IF_FALSE(dest_bitsliced != source_blocks, + "transpose_blocks_to_bitsliced is not an in-place function."); + + for (uint32_t i = 0; i < block_size_bits; ++i) { + for (uint32_t j = 0; j < num_aes_inputs; ++j) { + uint32_t src_idx = j * block_size_bits + i; + uint32_t dest_idx = i * num_aes_inputs + j; + copy_radix_ciphertext_slice_async( + stream, gpu_index, dest_bitsliced, dest_idx, dest_idx + 1, + source_blocks, src_idx, src_idx + 1); + } + } +} + +/** + * Transposes a collection of AES states from a bit-sliced layout back to a + * block-oriented layout. This is the inverse of + * 'transpose_blocks_to_bitsliced'. + * + */ +template +__host__ void +transpose_bitsliced_to_blocks(cudaStream_t stream, uint32_t gpu_index, + CudaRadixCiphertextFFI *dest_blocks, + const CudaRadixCiphertextFFI *source_bitsliced, + uint32_t num_aes_inputs, + uint32_t block_size_bits) { + + PANIC_IF_FALSE(dest_blocks != source_bitsliced, + "transpose_bitsliced_to_blocks is not an in-place function."); + + for (uint32_t i = 0; i < block_size_bits; ++i) { + for (uint32_t j = 0; j < num_aes_inputs; ++j) { + uint32_t src_idx = i * num_aes_inputs + j; + uint32_t dest_idx = j * block_size_bits + i; + copy_radix_ciphertext_slice_async( + stream, gpu_index, dest_blocks, dest_idx, dest_idx + 1, + source_bitsliced, src_idx, src_idx + 1); + } + } +} + +/** + * Performs a vectorized homomorphic XOR operation on two sets of ciphertexts. + * + */ +template +__host__ __forceinline__ void +aes_xor(CudaStreams streams, int_aes_encrypt_buffer *mem, + CudaRadixCiphertextFFI *out, const CudaRadixCiphertextFFI *lhs, + const CudaRadixCiphertextFFI *rhs) { + + host_addition(streams.stream(0), streams.gpu_index(0), out, lhs, rhs, + out->num_radix_blocks, mem->params.message_modulus, + mem->params.carry_modulus); +} + +/** + * Applies a "flush" Look-Up Table (LUT) to a vector of ciphertexts. + * This operation isolates the first message bit (the LSB) by applying the + * identity function to it, while discarding any higher-order bits + * that may have resulted from previous additions. This effectively cleans the + * result. + * + */ +template +__host__ __forceinline__ void +aes_flush_inplace(CudaStreams streams, CudaRadixCiphertextFFI *data, + int_aes_encrypt_buffer *mem, void *const *bsks, + Torus *const *ksks) { + + integer_radix_apply_univariate_lookup_table_kb( + streams, data, data, bsks, ksks, mem->luts->flush_lut, + data->num_radix_blocks); +} + +/** + * Performs an operation: homomorphically adds a plaintext 1 to a + * ciphertext, then flushes the result to ensure it's a valid bit. + * + */ +template +__host__ __forceinline__ void aes_scalar_add_one_flush_inplace( + CudaStreams streams, CudaRadixCiphertextFFI *data, + int_aes_encrypt_buffer *mem, void *const *bsks, Torus *const *ksks) { + + host_integer_radix_add_scalar_one_inplace( + streams, data, mem->params.message_modulus, mem->params.carry_modulus); + + aes_flush_inplace(streams, data, mem, bsks, ksks); +} + +/** + * Batches multiple "flush" operations into a single operation. + * This is done in three steps: + * 1. GATHER: All target ciphertexts are copied into one large, contiguous + * buffer. + * 2. PROCESS: A single flush operation is executed on the entire buffer. + * 3. 
SCATTER: The results are copied from the buffer back to the original + * ciphertext locations. + * + */ +template +__host__ void +batch_vec_flush_inplace(CudaStreams streams, CudaRadixCiphertextFFI **targets, + size_t count, int_aes_encrypt_buffer *mem, + void *const *bsks, Torus *const *ksks) { + + uint32_t num_radix_blocks = targets[0]->num_radix_blocks; + + CudaRadixCiphertextFFI batch_in, batch_out; + as_radix_ciphertext_slice( + &batch_in, mem->main_workspaces->batch_processing_buffer, 0, + count * num_radix_blocks); + as_radix_ciphertext_slice( + &batch_out, mem->main_workspaces->batch_processing_buffer, + count * num_radix_blocks, (2 * count) * num_radix_blocks); + + for (size_t i = 0; i < count; ++i) { + CudaRadixCiphertextFFI dest_slice; + as_radix_ciphertext_slice(&dest_slice, &batch_in, + i * num_radix_blocks, + (i + 1) * num_radix_blocks); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &dest_slice, targets[i]); + } + + integer_radix_apply_univariate_lookup_table_kb( + streams, &batch_out, &batch_in, bsks, ksks, mem->luts->flush_lut, + batch_out.num_radix_blocks); + + for (size_t i = 0; i < count; ++i) { + CudaRadixCiphertextFFI src_slice; + as_radix_ciphertext_slice(&src_slice, &batch_out, + i * num_radix_blocks, + (i + 1) * num_radix_blocks); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + targets[i], &src_slice); + } +} + +/** + * Batches multiple "and" operations into a single, large launch. + * + */ +template +__host__ void batch_vec_and_inplace(CudaStreams streams, + CudaRadixCiphertextFFI **outs, + CudaRadixCiphertextFFI **lhs, + CudaRadixCiphertextFFI **rhs, size_t count, + int_aes_encrypt_buffer *mem, + void *const *bsks, Torus *const *ksks) { + + uint32_t num_aes_inputs = outs[0]->num_radix_blocks; + + CudaRadixCiphertextFFI batch_lhs, batch_rhs, batch_out; + as_radix_ciphertext_slice( + &batch_lhs, mem->main_workspaces->batch_processing_buffer, 0, + count * num_aes_inputs); + as_radix_ciphertext_slice( + &batch_rhs, mem->main_workspaces->batch_processing_buffer, + count * num_aes_inputs, (2 * count) * num_aes_inputs); + as_radix_ciphertext_slice( + &batch_out, mem->main_workspaces->batch_processing_buffer, + (2 * count) * num_aes_inputs, (3 * count) * num_aes_inputs); + + for (size_t i = 0; i < count; ++i) { + CudaRadixCiphertextFFI dest_lhs_slice, dest_rhs_slice; + as_radix_ciphertext_slice(&dest_lhs_slice, &batch_lhs, + i * num_aes_inputs, + (i + 1) * num_aes_inputs); + as_radix_ciphertext_slice(&dest_rhs_slice, &batch_rhs, + i * num_aes_inputs, + (i + 1) * num_aes_inputs); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &dest_lhs_slice, lhs[i]); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &dest_rhs_slice, rhs[i]); + } + + integer_radix_apply_bivariate_lookup_table_kb( + streams, &batch_out, &batch_lhs, &batch_rhs, bsks, ksks, + mem->luts->and_lut, batch_out.num_radix_blocks, + mem->params.message_modulus); + + for (size_t i = 0; i < count; ++i) { + CudaRadixCiphertextFFI src_slice; + as_radix_ciphertext_slice(&src_slice, &batch_out, i * num_aes_inputs, + (i + 1) * num_aes_inputs); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + outs[i], &src_slice); + } +} + +/** + * Implements the AES S-Box substitution for two bytes in parallel using a + * bitsliced circuit design. 
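+ * The same circuit is applied to however many bytes are gathered per call
+ * (1, 2, 4, 8 or 16, selected by the S-Box parallelism). Its non-linear AND
+ * gates are evaluated as batched bivariate PBS (up to SBOX_MAX_AND_GATES
+ * gates per launch), while the XOR gates are plain LWE additions followed by
+ * periodic flushes to clear the carries.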
+ * + * Boyar-Peralta circuit: https://eprint.iacr.org/2011/332.pdf + * + * sbox_io_bytes (Input: Array of pointers to separate bytes) + * [ptr] -> [B0b0, B0b1, B0b2, B0b3, B0b4, B0b5, B0b6, B0b7] + * [ptr] -> [B1b0, B1b1, B1b2, B1b3, B1b4, B1b5, B1b6, B1b7] + * [ptr] -> [B2b0, B2b1, B2b2, B2b3, B2b4, B2b5, B2b6, B2b7] + * ... + * | + * | GATHER + * V + * Internal Bitsliced Buffer (input_bits) + * Slice 0: [B0b0, B1b0, B2b0, ...] (All the 0th bits) + * Slice 1: [B0b1, B1b1, B2b1, ...] (All the 1st bits) + * Slice 2: [B0b2, B1b2, B2b2, ...] (All the 2nd bits) + * ... + * | + * V + * +----------------------------------+ + * | Homomorphic S-Box Evaluation | + * +----------------------------------+ + * | + * V + * Internal Bitsliced Buffer (output_bits) + * Result Slice 0: [R0b0, R1b0, R2b0, ...] + * Result Slice 1: [R0b1, R1b1, R2b1, ...] + * Result Slice 2: [R0b2, R1b2, R2b2, ...] + * ... + * | + * | SCATTER + * V + * sbox_io_bytes (Output: Results written back in-place) + * [ptr] -> [R0b0, R0b1, R0b2, R0b3, R0b4, R0b5, R0b6, R0b7] + * [ptr] -> [R1b0, R1b1, R1b2, R1b3, R1b4, R1b5, R1b6, R1b7] + * [ptr] -> [R2b0, R2b1, R2b2, R2b3, R2b4, R2b5, R2b6, R2b7] + * ... + */ +template +__host__ void vectorized_sbox_n_bytes(CudaStreams streams, + CudaRadixCiphertextFFI **sbox_io_bytes, + uint32_t num_bytes_parallel, + uint32_t num_aes_inputs, + int_aes_encrypt_buffer *mem, + void *const *bsks, Torus *const *ksks) { + + uint32_t num_sbox_blocks = num_bytes_parallel * num_aes_inputs; + + constexpr uint32_t INPUT_BITS_LEN = 8; + constexpr uint32_t OUTPUT_BITS_LEN = 8; + constexpr uint32_t WIRES_A_LEN = 22; + constexpr uint32_t WIRES_B_LEN = 68; + constexpr uint32_t WIRES_C_LEN = 18; + + CudaRadixCiphertextFFI wires_a[WIRES_A_LEN], wires_b[WIRES_B_LEN], + wires_c[WIRES_C_LEN]; + + for (uint32_t i = 0; i < WIRES_A_LEN; ++i) + as_radix_ciphertext_slice( + &wires_a[i], mem->main_workspaces->sbox_internal_workspace, + i * num_sbox_blocks, (i + 1) * num_sbox_blocks); + for (uint32_t i = 0; i < WIRES_B_LEN; ++i) + as_radix_ciphertext_slice( + &wires_b[i], mem->main_workspaces->sbox_internal_workspace, + (WIRES_A_LEN + i) * num_sbox_blocks, + (WIRES_A_LEN + i + 1) * num_sbox_blocks); + for (uint32_t i = 0; i < WIRES_C_LEN; ++i) + as_radix_ciphertext_slice( + &wires_c[i], mem->main_workspaces->sbox_internal_workspace, + (WIRES_A_LEN + WIRES_B_LEN + i) * num_sbox_blocks, + (WIRES_A_LEN + WIRES_B_LEN + i + 1) * num_sbox_blocks); + + // Input Reordering (Gather) + // + + CudaRadixCiphertextFFI input_bits[INPUT_BITS_LEN]; + CudaRadixCiphertextFFI *reordered_input_buffer = + mem->main_workspaces->tmp_tiled_key_buffer; + + for (uint32_t bit = 0; bit < INPUT_BITS_LEN; ++bit) { + as_radix_ciphertext_slice(&input_bits[bit], reordered_input_buffer, + bit * num_sbox_blocks, + (bit + 1) * num_sbox_blocks); + + for (uint32_t byte_idx = 0; byte_idx < num_bytes_parallel; ++byte_idx) { + CudaRadixCiphertextFFI *current_source_byte = sbox_io_bytes[byte_idx]; + CudaRadixCiphertextFFI dest_slice; + as_radix_ciphertext_slice(&dest_slice, &input_bits[bit], + byte_idx * num_aes_inputs, + (byte_idx + 1) * num_aes_inputs); + copy_radix_ciphertext_async(streams.stream(0), + streams.gpu_index(0), &dest_slice, + ¤t_source_byte[bit]); + } + } + +#define XOR(out, a, b) \ + do { \ + aes_xor(streams, mem, out, a, b); \ + } while (0) + +#define FLUSH(...) 
\ + do { \ + CudaRadixCiphertextFFI *targets[] = {__VA_ARGS__}; \ + batch_vec_flush_inplace(streams, targets, \ + sizeof(targets) / sizeof(targets[0]), mem, bsks, \ + ksks); \ + } while (0) + +#define AND(outs, lhs, rhs) \ + do { \ + batch_vec_and_inplace(streams, outs, lhs, rhs, \ + sizeof(outs) / sizeof(outs[0]), mem, bsks, ksks); \ + } while (0) + +#define ADD_ONE_FLUSH(target) \ + do { \ + aes_scalar_add_one_flush_inplace(streams, target, mem, bsks, ksks); \ + } while (0) + +#define ADD_ONE(target) \ + do { \ + host_integer_radix_add_scalar_one_inplace( \ + streams, target, mem->params.message_modulus, \ + mem->params.carry_modulus); \ + } while (0) + + // Homomorphic S-Box Circuit Evaluation + // + + XOR(&wires_a[14], &input_bits[3], &input_bits[5]); + XOR(&wires_a[13], &input_bits[0], &input_bits[6]); + XOR(&wires_a[9], &input_bits[0], &input_bits[3]); + XOR(&wires_a[8], &input_bits[0], &input_bits[5]); + XOR(&wires_b[0], &input_bits[1], &input_bits[2]); + FLUSH(&wires_a[14], &wires_a[13], &wires_a[9], &wires_a[8]); + XOR(&wires_a[1], &wires_b[0], &input_bits[7]); + FLUSH(&wires_a[1]); + XOR(&wires_a[12], &wires_a[13], &wires_a[14]); + XOR(&wires_a[4], &wires_a[1], &input_bits[3]); + XOR(&wires_a[2], &wires_a[1], &input_bits[0]); + XOR(&wires_a[5], &wires_a[1], &input_bits[6]); + FLUSH(&wires_a[12], &wires_a[4], &wires_a[2], &wires_a[5]); + XOR(&wires_a[3], &wires_a[5], &wires_a[8]); + XOR(&wires_b[1], &input_bits[4], &wires_a[12]); + FLUSH(&wires_a[3]); + XOR(&wires_a[15], &wires_b[1], &input_bits[5]); + XOR(&wires_a[20], &wires_b[1], &input_bits[1]); + FLUSH(&wires_a[15], &wires_a[20]); + XOR(&wires_a[6], &wires_a[15], &input_bits[7]); + XOR(&wires_a[10], &wires_a[15], &wires_b[0]); + XOR(&wires_a[11], &wires_a[20], &wires_a[9]); + FLUSH(&wires_a[6], &wires_a[10]); + XOR(&wires_a[7], &input_bits[7], &wires_a[11]); + FLUSH(&wires_a[7]); + XOR(&wires_a[17], &wires_a[10], &wires_a[11]); + XOR(&wires_a[19], &wires_a[10], &wires_a[8]); + XOR(&wires_a[16], &wires_b[0], &wires_a[11]); + FLUSH(&wires_a[17], &wires_a[19], &wires_a[16]); + XOR(&wires_a[21], &wires_a[13], &wires_a[16]); + XOR(&wires_a[18], &input_bits[0], &wires_a[16]); + + CudaRadixCiphertextFFI *and_outs_1[] = { + &wires_b[2], &wires_b[3], &wires_b[5], &wires_b[7], &wires_b[8], + &wires_b[10], &wires_b[12], &wires_b[13], &wires_b[15]}; + CudaRadixCiphertextFFI *and_lhs_1[] = { + &wires_a[15], &wires_a[3], &input_bits[7], &wires_a[13], &wires_a[1], + &wires_a[2], &wires_a[9], &wires_a[14], &wires_a[8]}; + CudaRadixCiphertextFFI *and_rhs_1[] = { + &wires_a[12], &wires_a[6], &wires_a[4], &wires_a[16], &wires_a[5], + &wires_a[7], &wires_a[11], &wires_a[17], &wires_a[10]}; + AND(and_outs_1, and_lhs_1, and_rhs_1); + + FLUSH(&wires_a[21], &wires_a[18]); + XOR(&wires_b[4], &wires_b[3], &wires_b[2]); + XOR(&wires_b[6], &wires_b[5], &wires_b[2]); + XOR(&wires_b[9], &wires_b[8], &wires_b[7]); + XOR(&wires_b[11], &wires_b[10], &wires_b[7]); + XOR(&wires_b[14], &wires_b[13], &wires_b[12]); + XOR(&wires_b[16], &wires_b[15], &wires_b[12]); + XOR(&wires_b[17], &wires_b[4], &wires_b[14]); + XOR(&wires_b[18], &wires_b[6], &wires_b[16]); + XOR(&wires_b[19], &wires_b[9], &wires_b[14]); + XOR(&wires_b[20], &wires_b[11], &wires_b[16]); + XOR(&wires_b[21], &wires_b[17], &wires_a[20]); + XOR(&wires_b[22], &wires_b[18], &wires_a[19]); + XOR(&wires_b[23], &wires_b[19], &wires_a[21]); + XOR(&wires_b[24], &wires_b[20], &wires_a[18]); + FLUSH(&wires_b[21], &wires_b[23], &wires_b[24]); + XOR(&wires_b[25], &wires_b[21], &wires_b[22]); + 
FLUSH(&wires_b[25]); + + CudaRadixCiphertextFFI *and_outs_2[] = {&wires_b[26]}; + CudaRadixCiphertextFFI *and_lhs_2[] = {&wires_b[21]}; + CudaRadixCiphertextFFI *and_rhs_2[] = {&wires_b[23]}; + AND(and_outs_2, and_lhs_2, and_rhs_2); + + XOR(&wires_b[27], &wires_b[24], &wires_b[26]); + XOR(&wires_b[30], &wires_b[23], &wires_b[24]); + XOR(&wires_b[31], &wires_b[22], &wires_b[26]); + FLUSH(&wires_b[27], &wires_b[30], &wires_b[31]); + + CudaRadixCiphertextFFI *and_outs_3[] = {&wires_b[28]}; + CudaRadixCiphertextFFI *and_lhs_3[] = {&wires_b[25]}; + CudaRadixCiphertextFFI *and_rhs_3[] = {&wires_b[27]}; + AND(and_outs_3, and_lhs_3, and_rhs_3); + + XOR(&wires_b[29], &wires_b[28], &wires_b[22]); + + CudaRadixCiphertextFFI *and_outs_4[] = {&wires_b[32]}; + CudaRadixCiphertextFFI *and_lhs_4[] = {&wires_b[30]}; + CudaRadixCiphertextFFI *and_rhs_4[] = {&wires_b[31]}; + AND(and_outs_4, and_lhs_4, and_rhs_4); + + FLUSH(&wires_b[29]); + XOR(&wires_b[33], &wires_b[32], &wires_b[24]); + FLUSH(&wires_b[33]); + XOR(&wires_b[42], &wires_b[29], &wires_b[33]); + FLUSH(&wires_b[42]); + XOR(&wires_b[34], &wires_b[23], &wires_b[33]); + XOR(&wires_b[35], &wires_b[27], &wires_b[33]); + FLUSH(&wires_b[34], &wires_b[35]); + + CudaRadixCiphertextFFI *and_outs_5[] = {&wires_b[36]}; + CudaRadixCiphertextFFI *and_lhs_5[] = {&wires_b[24]}; + CudaRadixCiphertextFFI *and_rhs_5[] = {&wires_b[35]}; + AND(and_outs_5, and_lhs_5, and_rhs_5); + + XOR(&wires_b[37], &wires_b[36], &wires_b[34]); + XOR(&wires_b[38], &wires_b[27], &wires_b[36]); + FLUSH(&wires_b[38]); + XOR(&wires_b[44], &wires_b[33], &wires_b[37]); + + CudaRadixCiphertextFFI *and_outs_6[] = {&wires_b[39]}; + CudaRadixCiphertextFFI *and_lhs_6[] = {&wires_b[38]}; + CudaRadixCiphertextFFI *and_rhs_6[] = {&wires_b[29]}; + AND(and_outs_6, and_lhs_6, and_rhs_6); + + XOR(&wires_b[40], &wires_b[25], &wires_b[39]); + XOR(&wires_b[41], &wires_b[40], &wires_b[37]); + XOR(&wires_b[43], &wires_b[29], &wires_b[40]); + FLUSH(&wires_b[41]); + XOR(&wires_b[45], &wires_b[42], &wires_b[41]); + FLUSH(&wires_b[45]); + + CudaRadixCiphertextFFI *and_outs_7[] = { + &wires_c[0], &wires_c[1], &wires_c[2], &wires_c[3], &wires_c[4], + &wires_c[5], &wires_c[6], &wires_c[7], &wires_c[8], &wires_c[9], + &wires_c[10], &wires_c[11], &wires_c[12], &wires_c[13], &wires_c[14], + &wires_c[15], &wires_c[16], &wires_c[17]}; + CudaRadixCiphertextFFI *and_lhs_7[] = { + &wires_a[15], &wires_a[6], &wires_b[33], &wires_a[16], &wires_a[1], + &wires_b[29], &wires_b[42], &wires_a[17], &wires_a[10], &wires_a[12], + &wires_a[3], &wires_b[33], &wires_a[13], &wires_a[5], &wires_b[29], + &wires_b[42], &wires_b[45], &wires_b[41]}; + CudaRadixCiphertextFFI *and_rhs_7[] = { + &wires_b[44], &wires_b[37], &input_bits[7], &wires_b[43], &wires_b[40], + &wires_a[7], &wires_a[11], &wires_b[45], &wires_b[41], &wires_b[44], + &wires_b[37], &wires_a[4], &wires_b[43], &wires_b[40], &wires_a[2], + &wires_a[9], &wires_a[14], &wires_a[8]}; + AND(and_outs_7, and_lhs_7, and_rhs_7); + + XOR(&wires_b[46], &wires_c[15], &wires_c[16]); + XOR(&wires_b[47], &wires_c[10], &wires_c[11]); + XOR(&wires_b[48], &wires_c[5], &wires_c[13]); + XOR(&wires_b[49], &wires_c[9], &wires_c[10]); + XOR(&wires_b[50], &wires_c[2], &wires_c[12]); + XOR(&wires_b[51], &wires_c[2], &wires_c[5]); + XOR(&wires_b[52], &wires_c[7], &wires_c[8]); + XOR(&wires_b[53], &wires_c[0], &wires_c[3]); + XOR(&wires_b[54], &wires_c[6], &wires_c[7]); + XOR(&wires_b[55], &wires_c[16], &wires_c[17]); + XOR(&wires_b[56], &wires_c[12], &wires_b[48]); + XOR(&wires_b[57], &wires_b[50], 
&wires_b[53]); + XOR(&wires_b[58], &wires_c[4], &wires_b[46]); + XOR(&wires_b[59], &wires_c[3], &wires_b[54]); + XOR(&wires_b[60], &wires_b[46], &wires_b[57]); + XOR(&wires_b[61], &wires_c[14], &wires_b[57]); + XOR(&wires_b[62], &wires_b[52], &wires_b[58]); + XOR(&wires_b[63], &wires_b[49], &wires_b[58]); + XOR(&wires_b[64], &wires_c[4], &wires_b[59]); + FLUSH(&wires_b[61], &wires_b[63], &wires_b[64]); + XOR(&wires_b[65], &wires_b[61], &wires_b[62]); + FLUSH(&wires_b[65]); + XOR(&wires_b[66], &wires_c[1], &wires_b[63]); + FLUSH(&wires_b[66]); + + // Final Output Combination + // + + CudaRadixCiphertextFFI output_bits[OUTPUT_BITS_LEN]; + for (uint32_t i = 0; i < OUTPUT_BITS_LEN; i++) + as_radix_ciphertext_slice( + &output_bits[i], mem->main_workspaces->sbox_internal_workspace, + i * num_sbox_blocks, (i + 1) * num_sbox_blocks); + + CudaRadixCiphertextFFI single_bit_buffer; + as_radix_ciphertext_slice( + &single_bit_buffer, mem->main_workspaces->sbox_internal_workspace, + (OUTPUT_BITS_LEN * num_sbox_blocks), + (OUTPUT_BITS_LEN * num_sbox_blocks) + num_sbox_blocks); + + XOR(&output_bits[0], &wires_b[59], &wires_b[63]); + XOR(&wires_b[67], &wires_b[64], &wires_b[65]); + XOR(&output_bits[3], &wires_b[53], &wires_b[66]); + XOR(&output_bits[4], &wires_b[51], &wires_b[66]); + XOR(&output_bits[5], &wires_b[47], &wires_b[65]); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &single_bit_buffer, &wires_b[62]); + + ADD_ONE_FLUSH(&single_bit_buffer); + XOR(&output_bits[6], &wires_b[56], &single_bit_buffer); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &single_bit_buffer, &wires_b[60]); + + ADD_ONE_FLUSH(&single_bit_buffer); + XOR(&output_bits[7], &wires_b[48], &single_bit_buffer); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &single_bit_buffer, &output_bits[3]); + + ADD_ONE(&single_bit_buffer); + XOR(&output_bits[1], &wires_b[64], &single_bit_buffer); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &single_bit_buffer, &wires_b[67]); + + ADD_ONE_FLUSH(&single_bit_buffer); + XOR(&output_bits[2], &wires_b[55], &single_bit_buffer); + + FLUSH(&output_bits[0], &output_bits[1], &output_bits[2], &output_bits[3], + &output_bits[4], &output_bits[5], &output_bits[6], &output_bits[7]); + + // Output Reordering (Scatter) + // + + for (uint32_t bit = 0; bit < OUTPUT_BITS_LEN; ++bit) { + for (uint32_t byte_idx = 0; byte_idx < num_bytes_parallel; ++byte_idx) { + CudaRadixCiphertextFFI *current_dest_byte = sbox_io_bytes[byte_idx]; + CudaRadixCiphertextFFI src_slice; + as_radix_ciphertext_slice(&src_slice, &output_bits[bit], + byte_idx * num_aes_inputs, + (byte_idx + 1) * num_aes_inputs); + copy_radix_ciphertext_async(streams.stream(0), + streams.gpu_index(0), + ¤t_dest_byte[bit], &src_slice); + } + } + +#undef XOR +#undef FLUSH +#undef AND +#undef ADD_ONE_FLUSH +} + +/** + * Implements the ShiftRows step of AES on bitsliced data. 
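+ * Since the state is bitsliced, the byte rotations reduce to a fixed
+ * permutation of ciphertext slices (pure copies, no PBS involved).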
+ * + * Before ShiftRows (Input State): + * +----+----+----+----+ + * | A | B | C | D | + * +----+----+----+----+ + * | E | F | G | H | + * +----+----+----+----+ + * | I | J | K | L | + * +----+----+----+----+ + * | M | N | O | P | + * +----+----+----+----+ + * + * After ShiftRows (Output State): + * +----+----+----+----+ + * | A | B | C | D | <- No shift + * +----+----+----+----+ + * | F | G | H | E | <- 1 byte left shift + * +----+----+----+----+ + * | K | L | I | J | <- 2 bytes left shift + * +----+----+----+----+ + * | P | M | N | O | <- 3 bytes left shift + * +----+----+----+----+ + * + * + */ +template +__host__ void vectorized_shift_rows(CudaStreams streams, + CudaRadixCiphertextFFI *state_bitsliced, + uint32_t num_aes_inputs, + int_aes_encrypt_buffer *mem) { + constexpr uint32_t NUM_BYTES = 16; + constexpr uint32_t LEN_BYTE = 8; + constexpr uint32_t NUM_BITS = NUM_BYTES * LEN_BYTE; + + CudaRadixCiphertextFFI tmp_full_state_bitsliced_slice; + as_radix_ciphertext_slice( + &tmp_full_state_bitsliced_slice, + mem->main_workspaces->sbox_internal_workspace, 0, + state_bitsliced->num_radix_blocks); + + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &tmp_full_state_bitsliced_slice, + state_bitsliced); + + CudaRadixCiphertextFFI s_bits[NUM_BITS]; + for (int i = 0; i < NUM_BITS; i++) { + as_radix_ciphertext_slice(&s_bits[i], state_bitsliced, + i * num_aes_inputs, + (i + 1) * num_aes_inputs); + } + + CudaRadixCiphertextFFI tmp_s_bits_slices[NUM_BITS]; + for (int i = 0; i < NUM_BITS; i++) { + as_radix_ciphertext_slice( + &tmp_s_bits_slices[i], &tmp_full_state_bitsliced_slice, + i * num_aes_inputs, (i + 1) * num_aes_inputs); + } + + const int shift_rows_map[] = {0, 5, 10, 15, 4, 9, 14, 3, + 8, 13, 2, 7, 12, 1, 6, 11}; + + for (int i = 0; i < NUM_BYTES; i++) { + for (int bit = 0; bit < LEN_BYTE; bit++) { + CudaRadixCiphertextFFI *dest_slice = &s_bits[i * LEN_BYTE + bit]; + CudaRadixCiphertextFFI *src_slice = + &tmp_s_bits_slices[shift_rows_map[i] * LEN_BYTE + bit]; + copy_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), dest_slice, src_slice); + } + } +} + +/** + * Helper for MixColumns. Homomorphically multiplies an 8-bit byte by 2. + * + */ +template +__host__ void vectorized_mul_by_2(CudaStreams streams, + CudaRadixCiphertextFFI *res_byte, + CudaRadixCiphertextFFI *in_byte, + int_aes_encrypt_buffer *mem) { + + constexpr uint32_t LEN_BYTE = 8; + + CudaRadixCiphertextFFI *msb = &in_byte[0]; + + for (int i = 0; i < LEN_BYTE - 1; ++i) { + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &res_byte[i], &in_byte[i + 1]); + } + + set_zero_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), &res_byte[LEN_BYTE - 1], 0, + res_byte[LEN_BYTE - 1].num_radix_blocks); + + const int indices_to_xor[] = {3, 4, 6, 7}; + for (int index : indices_to_xor) { + aes_xor(streams, mem, &res_byte[index], &res_byte[index], msb); + } +} + +/** + * Implements the MixColumns step of AES. It performs a matrix multiplication + * on each column of the AES state. 
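+ * The matrix coefficients are interpreted in GF(2^8): multiplication by 02
+ * (xtime) is a shift of the bit-slices with the former MSB XORed back into
+ * bits 0, 1, 3 and 4 (the 0x1B reduction), and multiplication by 03 is the
+ * mul-by-2 result XORed with the original byte.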
+ * + * [ s'_0 ] [ 02 03 01 01 ] [ s_0 ] + * [ s'_1 ] = [ 01 02 03 01 ] * [ s_1 ] + * [ s'_2 ] [ 01 01 02 03 ] [ s_2 ] + * [ s'_3 ] [ 03 01 01 02 ] [ s_3 ] + * + */ +template +__host__ void vectorized_mix_columns(CudaStreams streams, + CudaRadixCiphertextFFI *s_bits, + uint32_t num_aes_inputs, + int_aes_encrypt_buffer *mem, + void *const *bsks, Torus *const *ksks) { + + constexpr uint32_t BITS_PER_BYTE = 8; + constexpr uint32_t BYTES_PER_COLUMN = 4; + constexpr uint32_t NUM_COLUMNS = 4; + constexpr uint32_t BITS_PER_COLUMN = BYTES_PER_COLUMN * BITS_PER_BYTE; + + for (uint32_t col = 0; col < NUM_COLUMNS; ++col) { + CudaRadixCiphertextFFI *col_copy_buffer = + mem->round_workspaces->mix_columns_col_copy_buffer; + for (uint32_t i = 0; i < BITS_PER_COLUMN; ++i) { + CudaRadixCiphertextFFI dest_slice, src_slice; + as_radix_ciphertext_slice(&dest_slice, col_copy_buffer, + i * num_aes_inputs, + (i + 1) * num_aes_inputs); + as_radix_ciphertext_slice( + &src_slice, s_bits, (col * BITS_PER_COLUMN + i) * num_aes_inputs, + (col * BITS_PER_COLUMN + i + 1) * num_aes_inputs); + copy_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), &dest_slice, &src_slice); + } + + CudaRadixCiphertextFFI b_orig[BYTES_PER_COLUMN][BITS_PER_BYTE]; + for (uint32_t i = 0; i < BYTES_PER_COLUMN; ++i) { + for (uint32_t j = 0; j < BITS_PER_BYTE; j++) { + as_radix_ciphertext_slice( + &b_orig[i][j], col_copy_buffer, + (i * BITS_PER_BYTE + j) * num_aes_inputs, + (i * BITS_PER_BYTE + j + 1) * num_aes_inputs); + } + } + + CudaRadixCiphertextFFI *mul_workspace = + mem->round_workspaces->mix_columns_mul_workspace_buffer; + CudaRadixCiphertextFFI b_mul2[BYTES_PER_COLUMN][BITS_PER_BYTE]; + CudaRadixCiphertextFFI b_mul2_tmp_buffers[BYTES_PER_COLUMN]; + for (uint32_t i = 0; i < BYTES_PER_COLUMN; i++) { + as_radix_ciphertext_slice(&b_mul2_tmp_buffers[i], mul_workspace, + (i * BITS_PER_BYTE) * num_aes_inputs, + (i * BITS_PER_BYTE + BITS_PER_BYTE) * + num_aes_inputs); + for (uint32_t j = 0; j < BITS_PER_BYTE; j++) { + as_radix_ciphertext_slice(&b_mul2[i][j], &b_mul2_tmp_buffers[i], + j * num_aes_inputs, + (j + 1) * num_aes_inputs); + } + } + + for (uint32_t i = 0; i < BYTES_PER_COLUMN; ++i) { + vectorized_mul_by_2(streams, b_mul2[i], b_orig[i], mem); + } + aes_flush_inplace(streams, mul_workspace, mem, bsks, ksks); + + CudaRadixCiphertextFFI b0_mul2_copy_buffer; + as_radix_ciphertext_slice( + &b0_mul2_copy_buffer, mul_workspace, + (BYTES_PER_COLUMN * BITS_PER_BYTE) * num_aes_inputs, + ((BYTES_PER_COLUMN * BITS_PER_BYTE) + BITS_PER_BYTE) * num_aes_inputs); + CudaRadixCiphertextFFI b0_mul2_copy[BITS_PER_BYTE]; + for (uint32_t j = 0; j < BITS_PER_BYTE; j++) { + as_radix_ciphertext_slice(&b0_mul2_copy[j], &b0_mul2_copy_buffer, + j * num_aes_inputs, + (j + 1) * num_aes_inputs); + copy_radix_ciphertext_async(streams.stream(0), + streams.gpu_index(0), &b0_mul2_copy[j], + &b_mul2[0][j]); + } + + for (uint32_t bit = 0; bit < BITS_PER_BYTE; bit++) { + CudaRadixCiphertextFFI *dest_bit_0 = + &s_bits[(col * BYTES_PER_COLUMN + 0) * BITS_PER_BYTE + bit]; + CudaRadixCiphertextFFI *dest_bit_1 = + &s_bits[(col * BYTES_PER_COLUMN + 1) * BITS_PER_BYTE + bit]; + CudaRadixCiphertextFFI *dest_bit_2 = + &s_bits[(col * BYTES_PER_COLUMN + 2) * BITS_PER_BYTE + bit]; + CudaRadixCiphertextFFI *dest_bit_3 = + &s_bits[(col * BYTES_PER_COLUMN + 3) * BITS_PER_BYTE + bit]; + +#define VEC_XOR_INPLACE(DEST, SRC) aes_xor(streams, mem, DEST, DEST, SRC) + + copy_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), dest_bit_0, &b_mul2[0][bit]); + 
VEC_XOR_INPLACE(dest_bit_0, &b_mul2[1][bit]); + VEC_XOR_INPLACE(dest_bit_0, &b_orig[1][bit]); + VEC_XOR_INPLACE(dest_bit_0, &b_orig[2][bit]); + VEC_XOR_INPLACE(dest_bit_0, &b_orig[3][bit]); + + copy_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), dest_bit_1, &b_orig[0][bit]); + VEC_XOR_INPLACE(dest_bit_1, &b_mul2[1][bit]); + VEC_XOR_INPLACE(dest_bit_1, &b_mul2[2][bit]); + VEC_XOR_INPLACE(dest_bit_1, &b_orig[2][bit]); + VEC_XOR_INPLACE(dest_bit_1, &b_orig[3][bit]); + + copy_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), dest_bit_2, &b_orig[0][bit]); + VEC_XOR_INPLACE(dest_bit_2, &b_orig[1][bit]); + VEC_XOR_INPLACE(dest_bit_2, &b_mul2[2][bit]); + VEC_XOR_INPLACE(dest_bit_2, &b_orig[3][bit]); + VEC_XOR_INPLACE(dest_bit_2, &b_mul2[3][bit]); + + copy_radix_ciphertext_async(streams.stream(0), + streams.gpu_index(0), dest_bit_3, + &b0_mul2_copy[bit]); + VEC_XOR_INPLACE(dest_bit_3, &b_orig[0][bit]); + VEC_XOR_INPLACE(dest_bit_3, &b_orig[1][bit]); + VEC_XOR_INPLACE(dest_bit_3, &b_orig[2][bit]); + VEC_XOR_INPLACE(dest_bit_3, &b_mul2[3][bit]); +#undef VEC_XOR_INPLACE + } + } +} + +/** + * The main AES encryption function. It orchestrates the full 10-round AES-128 + * encryption process on the bitsliced state. + * + * The process is broken down into three phases: + * + * 1. Initial Round (Round 0): + * - AddRoundKey, which is a XOR + * + * 2. Main Rounds (Rounds 1-9): + * This sequence is repeated 9 times. + * - SubBytes + * - ShiftRows + * - MixColumns + * - AddRoundKey + * + * 3. Final Round (Round 10): + * - SubBytes + * - ShiftRows + * - AddRoundKey + * + */ +template +__host__ void vectorized_aes_encrypt_inplace( + CudaStreams streams, CudaRadixCiphertextFFI *all_states_bitsliced, + CudaRadixCiphertextFFI const *round_keys, uint32_t num_aes_inputs, + int_aes_encrypt_buffer *mem, void *const *bsks, Torus *const *ksks) { + + constexpr uint32_t BITS_PER_BYTE = 8; + constexpr uint32_t STATE_BYTES = 16; + constexpr uint32_t STATE_BITS = STATE_BYTES * BITS_PER_BYTE; + constexpr uint32_t ROUNDS = 10; + + CudaRadixCiphertextFFI *jit_transposed_key = + mem->main_workspaces->initial_states_and_jit_key_workspace; + + CudaRadixCiphertextFFI round_0_key_slice; + as_radix_ciphertext_slice( + &round_0_key_slice, (CudaRadixCiphertextFFI *)round_keys, 0, STATE_BITS); + for (uint32_t block = 0; block < num_aes_inputs; ++block) { + CudaRadixCiphertextFFI tile_slice; + as_radix_ciphertext_slice( + &tile_slice, mem->main_workspaces->tmp_tiled_key_buffer, + block * STATE_BITS, (block + 1) * STATE_BITS); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &tile_slice, &round_0_key_slice); + } + transpose_blocks_to_bitsliced( + streams.stream(0), streams.gpu_index(0), jit_transposed_key, + mem->main_workspaces->tmp_tiled_key_buffer, num_aes_inputs, STATE_BITS); + + aes_xor(streams, mem, all_states_bitsliced, all_states_bitsliced, + jit_transposed_key); + + aes_flush_inplace(streams, all_states_bitsliced, mem, bsks, ksks); + + for (uint32_t round = 1; round <= ROUNDS; ++round) { + CudaRadixCiphertextFFI s_bits[STATE_BITS]; + for (uint32_t i = 0; i < STATE_BITS; i++) { + as_radix_ciphertext_slice(&s_bits[i], all_states_bitsliced, + i * num_aes_inputs, + (i + 1) * num_aes_inputs); + } + + uint32_t sbox_parallelism = mem->sbox_parallel_instances; + switch (sbox_parallelism) { + case 1: + for (uint32_t i = 0; i < STATE_BYTES; ++i) { + CudaRadixCiphertextFFI *sbox_inputs[] = {&s_bits[i * BITS_PER_BYTE]}; + vectorized_sbox_n_bytes(streams, sbox_inputs, 1, 
num_aes_inputs, + mem, bsks, ksks); + } + break; + case 2: + for (uint32_t i = 0; i < STATE_BYTES; i += 2) { + CudaRadixCiphertextFFI *sbox_inputs[] = { + &s_bits[i * BITS_PER_BYTE], &s_bits[(i + 1) * BITS_PER_BYTE]}; + vectorized_sbox_n_bytes(streams, sbox_inputs, 2, num_aes_inputs, + mem, bsks, ksks); + } + break; + case 4: + for (uint32_t i = 0; i < STATE_BYTES; i += 4) { + CudaRadixCiphertextFFI *sbox_inputs[] = { + &s_bits[i * BITS_PER_BYTE], &s_bits[(i + 1) * BITS_PER_BYTE], + &s_bits[(i + 2) * BITS_PER_BYTE], &s_bits[(i + 3) * BITS_PER_BYTE]}; + vectorized_sbox_n_bytes(streams, sbox_inputs, 4, num_aes_inputs, + mem, bsks, ksks); + } + break; + case 8: + for (uint32_t i = 0; i < STATE_BYTES; i += 8) { + CudaRadixCiphertextFFI *sbox_inputs[] = { + &s_bits[i * BITS_PER_BYTE], &s_bits[(i + 1) * BITS_PER_BYTE], + &s_bits[(i + 2) * BITS_PER_BYTE], &s_bits[(i + 3) * BITS_PER_BYTE], + &s_bits[(i + 4) * BITS_PER_BYTE], &s_bits[(i + 5) * BITS_PER_BYTE], + &s_bits[(i + 6) * BITS_PER_BYTE], &s_bits[(i + 7) * BITS_PER_BYTE]}; + vectorized_sbox_n_bytes(streams, sbox_inputs, 8, num_aes_inputs, + mem, bsks, ksks); + } + break; + case 16: { + CudaRadixCiphertextFFI *sbox_inputs[] = { + &s_bits[0 * BITS_PER_BYTE], &s_bits[1 * BITS_PER_BYTE], + &s_bits[2 * BITS_PER_BYTE], &s_bits[3 * BITS_PER_BYTE], + &s_bits[4 * BITS_PER_BYTE], &s_bits[5 * BITS_PER_BYTE], + &s_bits[6 * BITS_PER_BYTE], &s_bits[7 * BITS_PER_BYTE], + &s_bits[8 * BITS_PER_BYTE], &s_bits[9 * BITS_PER_BYTE], + &s_bits[10 * BITS_PER_BYTE], &s_bits[11 * BITS_PER_BYTE], + &s_bits[12 * BITS_PER_BYTE], &s_bits[13 * BITS_PER_BYTE], + &s_bits[14 * BITS_PER_BYTE], &s_bits[15 * BITS_PER_BYTE]}; + vectorized_sbox_n_bytes(streams, sbox_inputs, 16, num_aes_inputs, + mem, bsks, ksks); + } break; + default: + PANIC("Unsupported S-Box parallelism level selected: %u", + sbox_parallelism); + } + + vectorized_shift_rows(streams, all_states_bitsliced, num_aes_inputs, + mem); + + if (round != ROUNDS) { + vectorized_mix_columns(streams, s_bits, num_aes_inputs, mem, bsks, + ksks); + aes_flush_inplace(streams, all_states_bitsliced, mem, bsks, ksks); + } + + CudaRadixCiphertextFFI round_key_slice; + as_radix_ciphertext_slice( + &round_key_slice, (CudaRadixCiphertextFFI *)round_keys, + round * STATE_BITS, (round + 1) * STATE_BITS); + for (uint32_t block = 0; block < num_aes_inputs; ++block) { + CudaRadixCiphertextFFI tile_slice; + as_radix_ciphertext_slice( + &tile_slice, mem->main_workspaces->tmp_tiled_key_buffer, + block * STATE_BITS, (block + 1) * STATE_BITS); + copy_radix_ciphertext_async(streams.stream(0), + streams.gpu_index(0), &tile_slice, + &round_key_slice); + } + transpose_blocks_to_bitsliced( + streams.stream(0), streams.gpu_index(0), jit_transposed_key, + mem->main_workspaces->tmp_tiled_key_buffer, num_aes_inputs, STATE_BITS); + + aes_xor(streams, mem, all_states_bitsliced, all_states_bitsliced, + jit_transposed_key); + + aes_flush_inplace(streams, all_states_bitsliced, mem, bsks, ksks); + } +} + +/** + * Performs the homomorphic addition of the plaintext counter to the encrypted + * IV. + * + * It functions as a 128-bit ripple-carry adder. For each bit $i$ from LSB to + * MSB, it computes the sum $S_i$ and the output carry $C_i$ based on the state + * bit ($IV_i$), the counter bit ($Counter_i$), and the incoming carry + * ($C_{i-1}$). 
The logical formulas are: + * + * $S_i = IV_i + Counter_i + C_{i-1}$ + * $C_i = (IV_i * Counter_i) + (IV_i * C_{i-1}) + (Counter_i * C_{i-1})$ + * + * The "transposed_states" buffer is updated in-place with the sum bits $S_i$. + * + */ +template +__host__ void vectorized_aes_full_adder_inplace( + CudaStreams streams, CudaRadixCiphertextFFI *transposed_states, + const Torus *counter_bits_le_all_blocks, uint32_t num_aes_inputs, + int_aes_encrypt_buffer *mem, void *const *bsks, Torus *const *ksks) { + + constexpr uint32_t NUM_BITS = 128; + + // --- Initialization --- + CudaRadixCiphertextFFI *carry_vec = + mem->counter_workspaces->vec_tmp_carry_buffer; + CudaRadixCiphertextFFI *trivial_b_bits_vec = + mem->counter_workspaces->vec_trivial_b_bits_buffer; + CudaRadixCiphertextFFI *sum_plus_carry_vec = + mem->counter_workspaces->vec_tmp_sum_buffer; + + // Initialize the carry vector to 0. + set_zero_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), carry_vec, 0, num_aes_inputs); + + // Main loop iterating over the 128 bits, from LSB (i=0) to MSB (i=127). + // Each iteration implements one stage of the full adder. + for (uint32_t i = 0; i < NUM_BITS; ++i) { + // The index in the state buffer is reversed (127-i), + // because of the LSB -> MSB logic. + const uint32_t state_bit_index = NUM_BITS - 1 - i; + + // --- Step 1: Prepare the adder inputs --- + + // a_i_vec: The first operand (ciphertext). This is the i-th bit of the IV. + CudaRadixCiphertextFFI a_i_vec; + as_radix_ciphertext_slice(&a_i_vec, transposed_states, + state_bit_index * num_aes_inputs, + (state_bit_index + 1) * num_aes_inputs); + + // Prepare the second operand (plaintext, then trivially encrypted). + // This is the i-th bit of the counter. + for (uint32_t block = 0; block < num_aes_inputs; ++block) { + mem->counter_workspaces->h_counter_bits_buffer[block] = + counter_bits_le_all_blocks[block * NUM_BITS + i]; + } + cuda_memcpy_async_to_gpu(mem->counter_workspaces->d_counter_bits_buffer, + mem->counter_workspaces->h_counter_bits_buffer, + num_aes_inputs * sizeof(Torus), streams.stream(0), + streams.gpu_index(0)); + set_trivial_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), trivial_b_bits_vec, + mem->counter_workspaces->d_counter_bits_buffer, + mem->counter_workspaces->h_counter_bits_buffer, num_aes_inputs, + mem->params.message_modulus, mem->params.carry_modulus); + + // carry_vec: The third operand (ciphertext). + // This is the carry from the previous stage (C_{i-1}). + + // --- Step 2: Compute the sum and carry --- + + // Compute the temporary sum of the first two operands: IV_i + Counter_i + CudaRadixCiphertextFFI tmp_sum_vec; + as_radix_ciphertext_slice(&tmp_sum_vec, + mem->round_workspaces->vec_tmp_bit_buffer, + 0, num_aes_inputs); + aes_xor(streams, mem, &tmp_sum_vec, &a_i_vec, trivial_b_bits_vec); + + // Compute the sum of all three operands: (IV_i + Counter_i) + C_{i-1} + aes_xor(streams, mem, sum_plus_carry_vec, &tmp_sum_vec, carry_vec); + + // Compute the new carry (C_i) for the next iteration. + // The carry_lut applies the function f(x) = (x >> 1) & 1, which + // extracts the carry bit from the previous sum. The result is stored + // in carry_vec for the next iteration (i+1). + integer_radix_apply_univariate_lookup_table_kb( + streams, carry_vec, sum_plus_carry_vec, bsks, ksks, + mem->luts->carry_lut, num_aes_inputs); + + // Compute the final sum bit (S_i). + // The flush_lut applies the function f(x) = x & 1, which extracts + // the least significant bit of the sum. 
The result is written + // directly into the state buffer, updating the IV in-place. + integer_radix_apply_univariate_lookup_table_kb( + streams, &a_i_vec, sum_plus_carry_vec, bsks, ksks, mem->luts->flush_lut, + num_aes_inputs); + } +} + +/** + * Top-level function to perform a full AES-128-CTR encryption homomorphically. + * + * +----------+ +-------------------+ + * | IV_CT | | Plaintext Counter | + * +----------+ +-------------------+ + * | | + * V V + * +---------------------------------+ + * | Homomorphic Full Adder | + * | (IV_CT + Counter) | + * +---------------------------------+ + * | + * V + * +---------------------------------+ + * | Homomorphic AES Encryption | -> Final Output Ciphertext + * | (10 Rounds) | + * +---------------------------------+ + * + */ +template +__host__ void host_integer_aes_ctr_encrypt( + CudaStreams streams, CudaRadixCiphertextFFI *output, + CudaRadixCiphertextFFI const *iv, CudaRadixCiphertextFFI const *round_keys, + const Torus *counter_bits_le_all_blocks, uint32_t num_aes_inputs, + int_aes_encrypt_buffer *mem, void *const *bsks, Torus *const *ksks) { + + constexpr uint32_t NUM_BITS = 128; + + CudaRadixCiphertextFFI *initial_states = + mem->main_workspaces->initial_states_and_jit_key_workspace; + + for (uint32_t block = 0; block < num_aes_inputs; ++block) { + CudaRadixCiphertextFFI output_slice; + as_radix_ciphertext_slice(&output_slice, initial_states, + block * NUM_BITS, (block + 1) * NUM_BITS); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &output_slice, iv); + } + + CudaRadixCiphertextFFI *transposed_states = + mem->main_workspaces->main_bitsliced_states_buffer; + transpose_blocks_to_bitsliced(streams.stream(0), streams.gpu_index(0), + transposed_states, initial_states, + num_aes_inputs, NUM_BITS); + + vectorized_aes_full_adder_inplace(streams, transposed_states, + counter_bits_le_all_blocks, + num_aes_inputs, mem, bsks, ksks); + + vectorized_aes_encrypt_inplace(streams, transposed_states, round_keys, + num_aes_inputs, mem, bsks, ksks); + + transpose_bitsliced_to_blocks(streams.stream(0), streams.gpu_index(0), + output, transposed_states, + num_aes_inputs, NUM_BITS); +} + +template +uint64_t scratch_cuda_integer_key_expansion( + CudaStreams streams, int_key_expansion_buffer **mem_ptr, + int_radix_params params, bool allocate_gpu_memory) { + + uint64_t size_tracker = 0; + *mem_ptr = new int_key_expansion_buffer( + streams, params, allocate_gpu_memory, size_tracker); + return size_tracker; +} + +/** + * Homomorphically performs the AES-128 key expansion schedule on the GPU. + * + * This function expands an encrypted 128-bit key into 44 words (11 round keys). 
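+ * Each word is 32 bits wide and is stored as 32 individually encrypted bit
+ * ciphertexts, so the fully expanded key occupies 44 * 32 = 1408 radix blocks.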
+ * The generation logic for a new word `w_i` depends on its position: + * - If (i % 4 == 0): w_i = w_{i-4} + SubWord(RotWord(w_{i-1})) + Rcon[i/4] + * - If (i % 4 != 0): w_i = w_{i-4} + w_{i-1} + */ +template +__host__ void host_integer_key_expansion(CudaStreams streams, + CudaRadixCiphertextFFI *expanded_keys, + CudaRadixCiphertextFFI const *key, + int_key_expansion_buffer *mem, + void *const *bsks, + Torus *const *ksks) { + + constexpr uint32_t BITS_PER_WORD = 32; + constexpr uint32_t BITS_PER_BYTE = 8; + constexpr uint32_t BYTES_PER_WORD = 4; + constexpr uint32_t TOTAL_WORDS = 44; + constexpr uint32_t KEY_WORDS = 4; + + const Torus rcon[] = {0x01, 0x02, 0x04, 0x08, 0x10, + 0x20, 0x40, 0x80, 0x1b, 0x36}; + + CudaRadixCiphertextFFI *words = mem->words_buffer; + + CudaRadixCiphertextFFI initial_key_dest_slice; + as_radix_ciphertext_slice(&initial_key_dest_slice, words, 0, + KEY_WORDS * BITS_PER_WORD); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &initial_key_dest_slice, key); + + for (uint32_t w = KEY_WORDS; w < TOTAL_WORDS; ++w) { + CudaRadixCiphertextFFI tmp_word_buffer, tmp_far, tmp_near; + + as_radix_ciphertext_slice(&tmp_word_buffer, mem->tmp_word_buffer, 0, + BITS_PER_WORD); + as_radix_ciphertext_slice(&tmp_far, words, (w - 4) * BITS_PER_WORD, + (w - 3) * BITS_PER_WORD); + as_radix_ciphertext_slice(&tmp_near, words, (w - 1) * BITS_PER_WORD, + w * BITS_PER_WORD); + + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &tmp_word_buffer, &tmp_near); + + if (w % KEY_WORDS == 0) { + CudaRadixCiphertextFFI rotated_word_buffer; + as_radix_ciphertext_slice( + &rotated_word_buffer, mem->tmp_rotated_word_buffer, 0, BITS_PER_WORD); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), &rotated_word_buffer, 0, + BITS_PER_WORD - BITS_PER_BYTE, &tmp_word_buffer, BITS_PER_BYTE, + BITS_PER_WORD); + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), &rotated_word_buffer, + BITS_PER_WORD - BITS_PER_BYTE, BITS_PER_WORD, &tmp_word_buffer, 0, + BITS_PER_BYTE); + + CudaRadixCiphertextFFI bit_slices[BITS_PER_WORD]; + for (uint32_t i = 0; i < BITS_PER_WORD; ++i) { + as_radix_ciphertext_slice(&bit_slices[i], &rotated_word_buffer, + i, i + 1); + } + + CudaRadixCiphertextFFI *sbox_byte_pointers[BYTES_PER_WORD]; + for (uint32_t i = 0; i < BYTES_PER_WORD; ++i) { + sbox_byte_pointers[i] = &bit_slices[i * BITS_PER_BYTE]; + } + + vectorized_sbox_n_bytes(streams, sbox_byte_pointers, + BYTES_PER_WORD, 1, mem->aes_encrypt_buffer, + bsks, ksks); + + Torus rcon_val = rcon[w / KEY_WORDS - 1]; + for (uint32_t bit = 0; bit < BITS_PER_BYTE; ++bit) { + if ((rcon_val >> (7 - bit)) & 1) { + CudaRadixCiphertextFFI first_byte_bit_slice; + as_radix_ciphertext_slice(&first_byte_bit_slice, + &rotated_word_buffer, bit, bit + 1); + host_integer_radix_add_scalar_one_inplace( + streams, &first_byte_bit_slice, mem->params.message_modulus, + mem->params.carry_modulus); + } + } + + aes_flush_inplace(streams, &rotated_word_buffer, mem->aes_encrypt_buffer, + bsks, ksks); + + copy_radix_ciphertext_async(streams.stream(0), + streams.gpu_index(0), &tmp_word_buffer, + &rotated_word_buffer); + } + + aes_xor(streams, mem->aes_encrypt_buffer, &tmp_word_buffer, &tmp_far, + &tmp_word_buffer); + aes_flush_inplace(streams, &tmp_word_buffer, mem->aes_encrypt_buffer, bsks, + ksks); + + CudaRadixCiphertextFFI dest_word; + as_radix_ciphertext_slice(&dest_word, words, w * BITS_PER_WORD, + (w + 1) * BITS_PER_WORD); + 
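+    // dest_word aliases the slot reserved for word w in the words buffer;
+    // the freshly derived word is written into it below.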
copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &dest_word, &tmp_word_buffer); + } + + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + expanded_keys, words); +} + +#endif diff --git a/backends/tfhe-cuda-backend/src/bindings.rs b/backends/tfhe-cuda-backend/src/bindings.rs index 75d7c963a..f8cb80860 100644 --- a/backends/tfhe-cuda-backend/src/bindings.rs +++ b/backends/tfhe-cuda-backend/src/bindings.rs @@ -1729,6 +1729,78 @@ unsafe extern "C" { mem_ptr_void: *mut *mut i8, ); } +unsafe extern "C" { + pub fn scratch_cuda_integer_aes_encrypt_64( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + num_aes_inputs: u32, + sbox_parallelism: u32, + ) -> u64; +} +unsafe extern "C" { + pub fn cuda_integer_aes_ctr_encrypt_64( + streams: CudaStreamsFFI, + output: *mut CudaRadixCiphertextFFI, + iv: *const CudaRadixCiphertextFFI, + round_keys: *const CudaRadixCiphertextFFI, + counter_bits_le_all_blocks: *const u64, + num_aes_inputs: u32, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} +unsafe extern "C" { + pub fn cleanup_cuda_integer_aes_encrypt_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8); +} +unsafe extern "C" { + pub fn scratch_cuda_integer_key_expansion_64( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + ) -> u64; +} +unsafe extern "C" { + pub fn cuda_integer_key_expansion_64( + streams: CudaStreamsFFI, + expanded_keys: *mut CudaRadixCiphertextFFI, + key: *const CudaRadixCiphertextFFI, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} +unsafe extern "C" { + pub fn cleanup_cuda_integer_key_expansion_64( + streams: CudaStreamsFFI, + mem_ptr_void: *mut *mut i8, + ); +} pub const KS_TYPE_BIG_TO_SMALL: KS_TYPE = 0; pub const KS_TYPE_SMALL_TO_BIG: KS_TYPE = 1; pub type KS_TYPE = ffi::c_uint; diff --git a/backends/tfhe-cuda-backend/wrapper.h b/backends/tfhe-cuda-backend/wrapper.h index 4ba1fbbf5..0b0a00023 100644 --- a/backends/tfhe-cuda-backend/wrapper.h +++ b/backends/tfhe-cuda-backend/wrapper.h @@ -2,6 +2,7 @@ #include "cuda/include/ciphertext.h" #include "cuda/include/integer/compression/compression.h" #include "cuda/include/integer/integer.h" +#include "cuda/include/aes/aes.h" #include "cuda/include/zk/zk.h" #include "cuda/include/keyswitch/keyswitch.h" #include "cuda/include/keyswitch/ks_enums.h" diff --git a/tfhe-benchmark/benches/integer/aes.rs b/tfhe-benchmark/benches/integer/aes.rs new file mode 100644 index 000000000..5c5228d19 --- /dev/null +++ b/tfhe-benchmark/benches/integer/aes.rs @@ -0,0 +1,154 @@ +#[cfg(feature = "gpu")] +pub mod cuda { + use benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + use benchmark::utilities::{get_bench_type, write_to_json, BenchmarkType, OperatorType}; + use criterion::{black_box, Criterion, Throughput}; + use tfhe::core_crypto::gpu::CudaStreams; + use 
tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; + use tfhe::integer::gpu::CudaServerKey; + use tfhe::integer::keycache::KEY_CACHE; + use tfhe::integer::{IntegerKeyKind, RadixClientKey}; + use tfhe::keycache::NamedParam; + use tfhe::shortint::AtomicPatternParameters; + + pub fn cuda_aes(c: &mut Criterion) { + let bench_name = "integer::cuda::aes"; + + let mut bench_group = c.benchmark_group(bench_name); + bench_group + .sample_size(15) + .measurement_time(std::time::Duration::from_secs(60)) + .warm_up_time(std::time::Duration::from_secs(60)); + + let param = BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + let atomic_param: AtomicPatternParameters = param.into(); + + let key: u128 = 0x2b7e151628aed2a6abf7158809cf4f3c; + let iv: u128 = 0xf0f1f2f3f4f5f6f7f8f9fafbfcfdfeff; + let aes_op_bit_size = 128; + + let param_name = param.name(); + + match get_bench_type() { + BenchmarkType::Latency => { + let streams = CudaStreams::new_multi_gpu(); + let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix); + let sks = CudaServerKey::new(&cpu_cks, &streams); + let cks = RadixClientKey::from((cpu_cks, 1)); + + let ct_key = cks.encrypt_u128_for_aes_ctr(key); + let ct_iv = cks.encrypt_u128_for_aes_ctr(iv); + + let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams); + let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams); + + { + const NUM_AES_INPUTS: usize = 1; + const SBOX_PARALLELISM: usize = 16; + let bench_id = format!("{param_name}::{NUM_AES_INPUTS}_input_encryption"); + + let round_keys = unsafe { sks.key_expansion_async(&d_key, &streams) }; + streams.synchronize(); + + bench_group.bench_function(&bench_id, |b| { + b.iter(|| { + unsafe { + black_box(sks.aes_encrypt_async( + &d_iv, + &round_keys, + 0, + NUM_AES_INPUTS, + SBOX_PARALLELISM, + &streams, + )); + } + streams.synchronize(); + }) + }); + + write_to_json::( + &bench_id, + atomic_param, + param.name(), + "aes_encryption", + &OperatorType::Atomic, + aes_op_bit_size, + vec![atomic_param.message_modulus().0.ilog2(); aes_op_bit_size as usize], + ); + } + + { + let bench_id = format!("{param_name}::key_expansion"); + + bench_group.bench_function(&bench_id, |b| { + b.iter(|| { + unsafe { + black_box(sks.key_expansion_async(&d_key, &streams)); + } + streams.synchronize(); + }) + }); + + write_to_json::( + &bench_id, + atomic_param, + param.name(), + "aes_key_expansion", + &OperatorType::Atomic, + aes_op_bit_size, + vec![atomic_param.message_modulus().0.ilog2(); aes_op_bit_size as usize], + ); + } + } + BenchmarkType::Throughput => { + const NUM_AES_INPUTS: usize = 192; + const SBOX_PARALLELISM: usize = 16; + let bench_id = format!("throughput::{param_name}::{NUM_AES_INPUTS}_inputs"); + + let streams = CudaStreams::new_multi_gpu(); + let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix); + let sks = CudaServerKey::new(&cpu_cks, &streams); + let cks = RadixClientKey::from((cpu_cks, 1)); + + bench_group.throughput(Throughput::Elements(NUM_AES_INPUTS as u64)); + + let ct_key = cks.encrypt_u128_for_aes_ctr(key); + let ct_iv = cks.encrypt_u128_for_aes_ctr(iv); + + let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams); + let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams); + + let round_keys = unsafe { sks.key_expansion_async(&d_key, &streams) }; + streams.synchronize(); + + bench_group.bench_function(&bench_id, |b| { + b.iter(|| { + unsafe { + 
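+                            // Each measured iteration performs one complete homomorphic
+                            // AES-CTR encryption over NUM_AES_INPUTS blocks; the key
+                            // schedule was computed once above.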
black_box(sks.aes_encrypt_async( + &d_iv, + &round_keys, + 0, + NUM_AES_INPUTS, + SBOX_PARALLELISM, + &streams, + )); + } + streams.synchronize(); + }) + }); + + write_to_json::( + &bench_id, + atomic_param, + param.name(), + "aes_encryption", + &OperatorType::Atomic, + aes_op_bit_size, + vec![atomic_param.message_modulus().0.ilog2(); aes_op_bit_size as usize], + ); + } + }; + + bench_group.finish(); + } +} diff --git a/tfhe-benchmark/benches/integer/bench.rs b/tfhe-benchmark/benches/integer/bench.rs index 668e19f7e..c311fed0d 100644 --- a/tfhe-benchmark/benches/integer/bench.rs +++ b/tfhe-benchmark/benches/integer/bench.rs @@ -1,5 +1,6 @@ #![allow(dead_code)] +mod aes; mod oprf; use benchmark::params::ParamsAndNumBlocksIter; @@ -2795,6 +2796,7 @@ mod cuda { cuda_trailing_ones, cuda_ilog2, oprf::cuda::cuda_unsigned_oprf, + aes::cuda::cuda_aes, ); criterion_group!( @@ -2823,6 +2825,7 @@ mod cuda { cuda_scalar_div, cuda_scalar_rem, oprf::cuda::cuda_unsigned_oprf, + aes::cuda::cuda_aes, ); criterion_group!( diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index 6c474c55a..ab9a2ea2f 100644 --- a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -7403,3 +7403,289 @@ pub unsafe fn expand_async( ); cleanup_expand_without_verification_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); } + +#[allow(clippy::too_many_arguments)] +/// # Safety +/// +/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization +/// is required +pub unsafe fn unchecked_aes_ctr_encrypt_integer_radix_kb_assign_async< + T: UnsignedInteger, + B: Numeric, +>( + streams: &CudaStreams, + output: &mut CudaRadixCiphertext, + iv: &CudaRadixCiphertext, + round_keys: &CudaRadixCiphertext, + start_counter: u128, + num_aes_inputs: u32, + sbox_parallelism: u32, + bootstrapping_key: &CudaVec, + keyswitch_key: &CudaVec, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + lwe_dimension: LweDimension, + ks_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + pbs_base_log: DecompositionBaseLog, + grouping_factor: LweBskGroupingFactor, + pbs_type: PBSType, + ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>, +) { + let mut output_degrees = output.info.blocks.iter().map(|b| b.degree.0).collect(); + let mut output_noise_levels = output.info.blocks.iter().map(|b| b.noise_level.0).collect(); + let mut cuda_ffi_output = + prepare_cuda_radix_ffi(output, &mut output_degrees, &mut output_noise_levels); + + let mut iv_degrees = iv.info.blocks.iter().map(|b| b.degree.0).collect(); + let mut iv_noise_levels = iv.info.blocks.iter().map(|b| b.noise_level.0).collect(); + let cuda_ffi_iv = prepare_cuda_radix_ffi(iv, &mut iv_degrees, &mut iv_noise_levels); + + let mut round_keys_degrees = round_keys.info.blocks.iter().map(|b| b.degree.0).collect(); + let mut round_keys_noise_levels = round_keys + .info + .blocks + .iter() + .map(|b| b.noise_level.0) + .collect(); + let cuda_ffi_round_keys = prepare_cuda_radix_ffi( + round_keys, + &mut round_keys_degrees, + &mut round_keys_noise_levels, + ); + + let noise_reduction_type = resolve_noise_reduction_type(ms_noise_reduction_configuration); + + let counter_bits_le: Vec = (0..num_aes_inputs) + .flat_map(|i| { + let current_counter = start_counter + i as u128; + (0..128).map(move |bit_index| ((current_counter >> bit_index) & 1) as u64) + }) + .collect(); + + let mut 
mem_ptr: *mut i8 = std::ptr::null_mut(); + scratch_cuda_integer_aes_encrypt_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + noise_reduction_type as u32, + num_aes_inputs, + sbox_parallelism, + ); + + cuda_integer_aes_ctr_encrypt_64( + streams.ffi(), + &raw mut cuda_ffi_output, + &raw const cuda_ffi_iv, + &raw const cuda_ffi_round_keys, + counter_bits_le.as_ptr(), + num_aes_inputs, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + + cleanup_cuda_integer_aes_encrypt_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); + + update_noise_degree(output, &cuda_ffi_output); +} + +#[allow(clippy::too_many_arguments)] +/// # Safety +/// +/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization +/// is required +pub unsafe fn get_aes_ctr_encrypt_integer_radix_size_on_gpu( + streams: &CudaStreams, + num_aes_inputs: u32, + sbox_parallelism: u32, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + lwe_dimension: LweDimension, + ks_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + pbs_base_log: DecompositionBaseLog, + grouping_factor: LweBskGroupingFactor, + pbs_type: PBSType, + ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>, +) -> u64 { + let noise_reduction_type = resolve_noise_reduction_type(ms_noise_reduction_configuration); + + let mut mem_ptr: *mut i8 = std::ptr::null_mut(); + let size = unsafe { + scratch_cuda_integer_aes_encrypt_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + false, + noise_reduction_type as u32, + num_aes_inputs, + sbox_parallelism, + ) + }; + + unsafe { cleanup_cuda_integer_aes_encrypt_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)) }; + + size +} + +#[allow(clippy::too_many_arguments)] +/// # Safety +/// +/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization +/// is required +pub unsafe fn unchecked_key_expansion_integer_radix_kb_assign_async< + T: UnsignedInteger, + B: Numeric, +>( + streams: &CudaStreams, + expanded_keys: &mut CudaRadixCiphertext, + key: &CudaRadixCiphertext, + bootstrapping_key: &CudaVec, + keyswitch_key: &CudaVec, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + lwe_dimension: LweDimension, + ks_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + pbs_base_log: DecompositionBaseLog, + grouping_factor: LweBskGroupingFactor, + pbs_type: PBSType, + ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>, +) { + let mut expanded_keys_degrees = expanded_keys + .info + .blocks + .iter() + .map(|b| b.degree.0) + .collect(); + let mut expanded_keys_noise_levels = expanded_keys + .info + .blocks + .iter() + .map(|b| 
b.noise_level.0) + .collect(); + let mut cuda_ffi_expanded_keys = prepare_cuda_radix_ffi( + expanded_keys, + &mut expanded_keys_degrees, + &mut expanded_keys_noise_levels, + ); + + let mut key_degrees = key.info.blocks.iter().map(|b| b.degree.0).collect(); + let mut key_noise_levels = key.info.blocks.iter().map(|b| b.noise_level.0).collect(); + let cuda_ffi_key = prepare_cuda_radix_ffi(key, &mut key_degrees, &mut key_noise_levels); + + let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration); + + let mut mem_ptr: *mut i8 = std::ptr::null_mut(); + scratch_cuda_integer_key_expansion_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + noise_reduction_type as u32, + ); + + cuda_integer_key_expansion_64( + streams.ffi(), + &raw mut cuda_ffi_expanded_keys, + &raw const cuda_ffi_key, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + + cleanup_cuda_integer_key_expansion_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); + + update_noise_degree(expanded_keys, &cuda_ffi_expanded_keys); +} + +#[allow(clippy::too_many_arguments)] +/// # Safety +/// +/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization +/// is required +pub unsafe fn get_key_expansion_integer_radix_size_on_gpu( + streams: &CudaStreams, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + lwe_dimension: LweDimension, + ks_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + pbs_base_log: DecompositionBaseLog, + grouping_factor: LweBskGroupingFactor, + pbs_type: PBSType, + ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>, +) -> u64 { + let noise_reduction_type = resolve_noise_reduction_type(ms_noise_reduction_configuration); + + let mut mem_ptr: *mut i8 = std::ptr::null_mut(); + let size = { + scratch_cuda_integer_key_expansion_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + noise_reduction_type as u32, + ) + }; + + unsafe { + cleanup_cuda_integer_key_expansion_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)) + }; + + size +} diff --git a/tfhe/src/integer/gpu/server_key/radix/aes.rs b/tfhe/src/integer/gpu/server_key/radix/aes.rs new file mode 100644 index 000000000..5f85152f5 --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/aes.rs @@ -0,0 +1,465 @@ +use crate::core_crypto::gpu::{ + check_valid_cuda_malloc, check_valid_cuda_malloc_assert_oom, CudaStreams, +}; +use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext}; +use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey}; + +use crate::core_crypto::prelude::LweBskGroupingFactor; +use crate::integer::gpu::{ + get_aes_ctr_encrypt_integer_radix_size_on_gpu, get_key_expansion_integer_radix_size_on_gpu, + unchecked_aes_ctr_encrypt_integer_radix_kb_assign_async, + 
unchecked_key_expansion_integer_radix_kb_assign_async, PBSType, +}; +use crate::integer::{RadixCiphertext, RadixClientKey}; +use crate::shortint::Ciphertext; + +const NUM_BITS: usize = 128; + +impl RadixClientKey { + /// Encrypts a 128-bit block for homomorphic AES evaluation. + /// + /// This function prepares a 128-bit plaintext block (like an AES key or IV) + /// for homomorphic processing by decomposing it into its 128 constituent bits + /// and encrypting each bit individually with FHE. + /// + /// The process is as follows: + /// ```text + /// // INPUT: A 128-bit plaintext block + /// Plaintext block (u128): 0x2b7e1516... + /// | + /// V + /// // 1. Decompose the block into individual bits + /// Individual bits: [b127, b126, ..., b1, b0] + /// | + /// V + /// // 2. Encrypt each bit individually using FHE + /// `self.encrypt(bit)` is applied to each bit + /// | + /// V + /// // 3. Collect the resulting bit-ciphertexts + /// Ciphertexts: [Ct(b127), Ct(b126), ..., Ct(b0)] + /// | + /// V + /// // 4. Group the bit-ciphertexts into a single RadixCiphertext + /// // representing the full encrypted block. + /// // OUTPUT: A RadixCiphertext + /// ``` + pub fn encrypt_u128_for_aes_ctr(&self, data: u128) -> RadixCiphertext { + let mut blocks: Vec = Vec::with_capacity(NUM_BITS); + for i in 0..NUM_BITS { + let bit = ((data >> (NUM_BITS - 1 - i)) & 1) as u64; + blocks.extend(self.encrypt(bit).blocks); + } + RadixCiphertext::from(blocks) + } + + /// Decrypts a `RadixCiphertext` containing one or more 128-bit blocks + /// that were homomorphically processed. + /// + /// This function reverses the encryption process by decrypting each individual + /// bit-ciphertext and reassembling them into 128-bit plaintext blocks. + /// + /// The process is as follows: + /// ```text + /// // INPUT: RadixCiphertext containing one or more encrypted blocks + /// Ciphertext collection: [Ct(b127), ..., Ct(b0), Ct(b'127), ..., Ct(b'0), ...] + /// | + /// | (For each sequence of 128 bit-ciphertexts) + /// V + /// // 1. Decrypt each bit's ciphertext individually + /// `self.decrypt(Ct)` is applied to each bit-ciphertext + /// | + /// V + /// // 2. Collect the resulting plaintext bits + /// Plaintext bits: [b127, b126, ..., b0] + /// | + /// V + /// // 3. Assemble the bits back into a 128-bit block + /// Reconstruction: ( ...((b127 << 1) | b126) << 1 | ... ) | b0 + /// | + /// V + /// // OUTPUT: A vector of plaintext u128 blocks + /// Plaintext u128s: [0x..., ...] 
+ /// ``` + pub fn decrypt_u128_from_aes_ctr( + &self, + encrypted_result: &RadixCiphertext, + num_aes_inputs: usize, + ) -> Vec { + let mut plaintext_results = Vec::with_capacity(num_aes_inputs); + for i in 0..num_aes_inputs { + let mut current_block_plaintext: u128 = 0; + let block_start_index = i * NUM_BITS; + for j in 0..NUM_BITS { + let block_slice = + &encrypted_result.blocks[block_start_index + j..block_start_index + j + 1]; + let block_radix_ct = RadixCiphertext::from(block_slice.to_vec()); + let decrypted_bit: u128 = self.decrypt(&block_radix_ct); + current_block_plaintext = (current_block_plaintext << 1) | decrypted_bit; + } + plaintext_results.push(current_block_plaintext); + } + plaintext_results + } +} + +impl CudaServerKey { + pub fn aes_ctr( + &self, + key: &CudaUnsignedRadixCiphertext, + iv: &CudaUnsignedRadixCiphertext, + start_counter: u128, + num_aes_inputs: usize, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext { + let gpu_index = streams.gpu_indexes[0]; + + let key_expansion_size = self.get_key_expansion_size_on_gpu(streams); + check_valid_cuda_malloc_assert_oom(key_expansion_size, gpu_index); + + // `parallelism` refers to level of parallelization of the S-box. + // S-box should process 16 bytes of data: sequentially, or in groups of 2, + // or in groups of 4, or in groups of 8, or all 16 at the same time. + // More parallelization leads to higher memory usage. Therefore, we must find a way + // to maximize parallelization while ensuring that there is still enough memory remaining on + // the GPU. + // + let mut parallelism = 16; + + while parallelism > 0 { + // `num_aes_inputs` refers to the number of 128-bit ciphertexts that AES will produce. + // + let aes_encrypt_size = + self.get_aes_encrypt_size_on_gpu(num_aes_inputs, parallelism, streams); + + if check_valid_cuda_malloc(aes_encrypt_size, streams.gpu_indexes[0]) { + let round_keys = unsafe { self.key_expansion_async(key, streams) }; + let res = unsafe { + self.aes_encrypt_async( + iv, + &round_keys, + start_counter, + num_aes_inputs, + parallelism, + streams, + ) + }; + streams.synchronize(); + return res; + } + parallelism /= 2; + } + + panic!("Failed to allocate GPU memory for AES, even with the lowest parallelism setting."); + } + + pub fn aes_ctr_with_fixed_parallelism( + &self, + key: &CudaUnsignedRadixCiphertext, + iv: &CudaUnsignedRadixCiphertext, + start_counter: u128, + num_aes_inputs: usize, + sbox_parallelism: usize, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext { + assert!( + [1, 2, 4, 8, 16].contains(&sbox_parallelism), + "Invalid S-Box parallelism: must be one of [1, 2, 4, 8, 16], got {sbox_parallelism}" + ); + + let gpu_index = streams.gpu_indexes[0]; + + let key_expansion_size = self.get_key_expansion_size_on_gpu(streams); + check_valid_cuda_malloc_assert_oom(key_expansion_size, gpu_index); + + let aes_encrypt_size = + self.get_aes_encrypt_size_on_gpu(num_aes_inputs, sbox_parallelism, streams); + check_valid_cuda_malloc_assert_oom(aes_encrypt_size, gpu_index); + + let round_keys = unsafe { self.key_expansion_async(key, streams) }; + let res = unsafe { + self.aes_encrypt_async( + iv, + &round_keys, + start_counter, + num_aes_inputs, + sbox_parallelism, + streams, + ) + }; + streams.synchronize(); + res + } + + /// # Safety + /// + /// - [CudaStreams::synchronize] __must__ be called after this function as soon as + /// synchronization is required + pub unsafe fn aes_encrypt_async( + &self, + iv: &CudaUnsignedRadixCiphertext, + round_keys: &CudaUnsignedRadixCiphertext, + 
start_counter: u128, + num_aes_inputs: usize, + sbox_parallelism: usize, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext { + let mut result: CudaUnsignedRadixCiphertext = + self.create_trivial_zero_radix(num_aes_inputs * 128, streams); + + let num_round_key_blocks = 11 * NUM_BITS; + + assert_eq!( + iv.as_ref().d_blocks.lwe_ciphertext_count().0, + NUM_BITS, + "AES IV must contain {NUM_BITS} encrypted bits, but contains {}", + iv.as_ref().d_blocks.lwe_ciphertext_count().0 + ); + assert_eq!( + round_keys.as_ref().d_blocks.lwe_ciphertext_count().0, + num_round_key_blocks, + "AES round_keys must contain {num_round_key_blocks} encrypted bits, but contains {}", + round_keys.as_ref().d_blocks.lwe_ciphertext_count().0 + ); + assert_eq!( + result.as_ref().d_blocks.lwe_ciphertext_count().0, + num_aes_inputs * 128, + "AES result must contain {} encrypted bits for {num_aes_inputs} blocks, but contains {}", + num_aes_inputs * 128, + result.as_ref().d_blocks.lwe_ciphertext_count().0 + ); + + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + unchecked_aes_ctr_encrypt_integer_radix_kb_assign_async( + streams, + result.as_mut(), + iv.as_ref(), + round_keys.as_ref(), + start_counter, + num_aes_inputs as u32, + sbox_parallelism as u32, + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + d_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + LweBskGroupingFactor(0), + PBSType::Classical, + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + unchecked_aes_ctr_encrypt_integer_radix_kb_assign_async( + streams, + result.as_mut(), + iv.as_ref(), + round_keys.as_ref(), + start_counter, + num_aes_inputs as u32, + sbox_parallelism as u32, + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + d_multibit_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + d_multibit_bsk.grouping_factor, + PBSType::MultiBit, + None, + ); + } + } + result + } + + fn get_aes_encrypt_size_on_gpu( + &self, + num_aes_inputs: usize, + sbox_parallelism: usize, + streams: &CudaStreams, + ) -> u64 { + let size = unsafe { + self.get_aes_encrypt_size_on_gpu_async(num_aes_inputs, sbox_parallelism, streams) + }; + streams.synchronize(); + size + } + + /// # Safety + /// + /// - [CudaStreams::synchronize] __must__ be called after this function as soon as + /// synchronization is required + unsafe fn get_aes_encrypt_size_on_gpu_async( + &self, + num_aes_inputs: usize, + sbox_parallelism: usize, + streams: &CudaStreams, + ) -> u64 { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => get_aes_ctr_encrypt_integer_radix_size_on_gpu( + streams, + num_aes_inputs as u32, + sbox_parallelism as u32, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + d_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + LweBskGroupingFactor(0), + PBSType::Classical, + 
d_bsk.ms_noise_reduction_configuration.as_ref(), + ), + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + get_aes_ctr_encrypt_integer_radix_size_on_gpu( + streams, + num_aes_inputs as u32, + sbox_parallelism as u32, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + d_multibit_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + d_multibit_bsk.grouping_factor, + PBSType::MultiBit, + None, + ) + } + } + } + + /// # Safety + /// + /// - [CudaStreams::synchronize] __must__ be called after this function as soon as + /// synchronization is required + pub unsafe fn key_expansion_async( + &self, + key: &CudaUnsignedRadixCiphertext, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext { + let num_round_keys = 11; + let num_key_bits = 128; + let mut expanded_keys: CudaUnsignedRadixCiphertext = + self.create_trivial_zero_radix(num_round_keys * num_key_bits, streams); + + assert_eq!( + key.as_ref().d_blocks.lwe_ciphertext_count().0, + num_key_bits, + "Input key must contain {} encrypted bits, but contains {}", + num_key_bits, + key.as_ref().d_blocks.lwe_ciphertext_count().0 + ); + + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + unchecked_key_expansion_integer_radix_kb_assign_async( + streams, + expanded_keys.as_mut(), + key.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + d_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + LweBskGroupingFactor(0), + PBSType::Classical, + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + unchecked_key_expansion_integer_radix_kb_assign_async( + streams, + expanded_keys.as_mut(), + key.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + d_multibit_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + d_multibit_bsk.grouping_factor, + PBSType::MultiBit, + None, + ); + } + } + expanded_keys + } + + fn get_key_expansion_size_on_gpu(&self, streams: &CudaStreams) -> u64 { + let size = unsafe { self.get_key_expansion_size_on_gpu_async(streams) }; + streams.synchronize(); + size + } + + /// # Safety + /// + /// - [CudaStreams::synchronize] __must__ be called after this function as soon as + /// synchronization is required + unsafe fn get_key_expansion_size_on_gpu_async(&self, streams: &CudaStreams) -> u64 { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => get_key_expansion_integer_radix_size_on_gpu( + streams, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + d_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + LweBskGroupingFactor(0), + PBSType::Classical, + d_bsk.ms_noise_reduction_configuration.as_ref(), + ), + 
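+            // Multi-bit PBS path: the grouping factor is taken from the bootstrapping
+            // key and no modulus-switch noise-reduction configuration is passed.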
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + get_key_expansion_integer_radix_size_on_gpu( + streams, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + d_multibit_bsk.input_lwe_dimension, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + d_multibit_bsk.grouping_factor, + PBSType::MultiBit, + None, + ) + } + } + } +} diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs index c74587730..4e390caaf 100644 --- a/tfhe/src/integer/gpu/server_key/radix/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs @@ -57,6 +57,7 @@ mod sub; mod vector_comparisons; mod vector_find; +mod aes; #[cfg(test)] mod tests_long_run; #[cfg(test)] diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs index d2382d6a7..f0954837d 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod test_add; +pub(crate) mod test_aes; pub(crate) mod test_bitwise_op; pub(crate) mod test_cmux; pub(crate) mod test_comparison; @@ -82,6 +83,98 @@ impl GpuFunctionExecutor { } } +impl<'a, F> + FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext, u128, usize, usize), + RadixCiphertext, + > for GpuFunctionExecutor +where + F: Fn( + &CudaServerKey, + &CudaUnsignedRadixCiphertext, + &CudaUnsignedRadixCiphertext, + u128, + usize, + usize, + &CudaStreams, + ) -> CudaUnsignedRadixCiphertext, +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute( + &mut self, + input: (&'a RadixCiphertext, &'a RadixCiphertext, u128, usize, usize), + ) -> RadixCiphertext { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let d_ctxt_1 = + CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams); + let d_ctxt_2 = + CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams); + + let gpu_result = (self.func)( + &context.sks, + &d_ctxt_1, + &d_ctxt_2, + input.2, + input.3, + input.4, + &context.streams, + ); + + gpu_result.to_radix_ciphertext(&context.streams) + } +} + +impl<'a, F> + FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, u128, usize), RadixCiphertext> + for GpuFunctionExecutor +where + F: Fn( + &CudaServerKey, + &CudaUnsignedRadixCiphertext, + &CudaUnsignedRadixCiphertext, + u128, + usize, + &CudaStreams, + ) -> CudaUnsignedRadixCiphertext, +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute( + &mut self, + input: (&'a RadixCiphertext, &'a RadixCiphertext, u128, usize), + ) -> RadixCiphertext { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let d_ctxt_1 = + CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams); + let d_ctxt_2 = + CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams); + + let gpu_result = (self.func)( + &context.sks, + &d_ctxt_1, + &d_ctxt_2, + input.2, + input.3, + &context.streams, + ); + + gpu_result.to_radix_ciphertext(&context.streams) + } +} + /// For default/unchecked binary functions impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext> for 
GpuFunctionExecutor diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_aes.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_aes.rs new file mode 100644 index 000000000..6d7e72486 --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_aes.rs @@ -0,0 +1,59 @@ +use crate::integer::gpu::server_key::radix::tests_unsigned::{ + create_gpu_parameterized_test, GpuFunctionExecutor, +}; +use crate::integer::gpu::CudaServerKey; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{ + aes_dynamic_parallelism_many_inputs_test, aes_fixed_parallelism_1_input_test, + aes_fixed_parallelism_2_inputs_test, +}; +use crate::shortint::parameters::{ + TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, +}; + +create_gpu_parameterized_test!(integer_aes_fixed_parallelism_1_input { + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 +}); + +create_gpu_parameterized_test!(integer_aes_fixed_parallelism_2_inputs { + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 +}); + +create_gpu_parameterized_test!(integer_aes_dynamic_parallelism_many_inputs { + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 +}); + +// The following two tests are referred to as "fixed_parallelism" because the objective is to test +// AES, in CTR mode, across all possible parallelizations of the S-box. The S-box must process 16 +// bytes; the parallelization refers to the number of bytes it will process in parallel in one call: +// 1, 2, 4, 8, or 16. +// +fn integer_aes_fixed_parallelism_1_input
<P>
(param: P) +where + P: Into<TestParameters>, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::aes_ctr_with_fixed_parallelism); + aes_fixed_parallelism_1_input_test(param, executor); +} + +fn integer_aes_fixed_parallelism_2_inputs
<P>
(param: P) +where + P: Into<TestParameters>, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::aes_ctr_with_fixed_parallelism); + aes_fixed_parallelism_2_inputs_test(param, executor); +} + +// The test referred to as "dynamic_parallelism" will seek the maximum S-box parallelization that +// the machine can support. +// +fn integer_aes_dynamic_parallelism_many_inputs
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::aes_ctr); + aes_dynamic_parallelism_many_inputs_test(param, executor); +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs index c9327450b..b597c076b 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs @@ -43,6 +43,11 @@ pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_ default_add_test, unchecked_add_assign_test, }; #[cfg(feature = "gpu")] +pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_aes::{ + aes_dynamic_parallelism_many_inputs_test, aes_fixed_parallelism_1_input_test, + aes_fixed_parallelism_2_inputs_test, +}; +#[cfg(feature = "gpu")] pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_neg::default_neg_test; pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_neg::unchecked_neg_test; #[cfg(feature = "gpu")] diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs index a3ce79cc9..ebf45c3fa 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs @@ -1,5 +1,6 @@ mod modulus_switch_compression; pub(crate) mod test_add; +pub(crate) mod test_aes; pub(crate) mod test_bitwise_op; mod test_block_rotate; mod test_block_shift; diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_aes.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_aes.rs new file mode 100644 index 000000000..86f3bd731 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_aes.rs @@ -0,0 +1,225 @@ +#![cfg(feature = "gpu")] + +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey}; +use crate::shortint::parameters::TestParameters; +use std::sync::Arc; + +const S_BOX: [u8; 256] = [ + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, + 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, + 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, + 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, + 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, + 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, + 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, + 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, + 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, + 0x70, 
0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, + 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, +]; + +fn plain_key_expansion(key: u128) -> Vec { + const RCON: [u32; 10] = [ + 0x01000000, 0x02000000, 0x04000000, 0x08000000, 0x10000000, 0x20000000, 0x40000000, + 0x80000000, 0x1B000000, 0x36000000, + ]; + let mut words = [0u32; 44]; + for (i, word) in words.iter_mut().enumerate().take(4) { + *word = (key >> (96 - (i * 32))) as u32; + } + for i in 4..44 { + let mut temp = words[i - 1]; + if i % 4 == 0 { + temp = temp.rotate_left(8); + let mut sub_bytes = 0u32; + for j in 0..4 { + let byte = (temp >> (24 - j * 8)) as u8; + sub_bytes |= (S_BOX[byte as usize] as u32) << (24 - j * 8); + } + temp = sub_bytes ^ RCON[i / 4 - 1]; + } + words[i] = words[i - 4] ^ temp; + } + words + .chunks_exact(4) + .map(|chunk| { + ((chunk[0] as u128) << 96) + | ((chunk[1] as u128) << 64) + | ((chunk[2] as u128) << 32) + | (chunk[3] as u128) + }) + .collect() +} +fn sub_bytes(state: &mut [u8; 16]) { + for byte in state.iter_mut() { + *byte = S_BOX[*byte as usize]; + } +} +fn shift_rows(state: &mut [u8; 16]) { + let original = *state; + state[1] = original[5]; + state[5] = original[9]; + state[9] = original[13]; + state[13] = original[1]; + state[2] = original[10]; + state[6] = original[14]; + state[10] = original[2]; + state[14] = original[6]; + state[3] = original[15]; + state[7] = original[3]; + state[11] = original[7]; + state[15] = original[11]; +} +fn gmul(mut a: u8, mut b: u8) -> u8 { + let mut p = 0; + for _ in 0..8 { + if (b & 1) != 0 { + p ^= a; + } + let hi_bit_set = (a & 0x80) != 0; + a <<= 1; + if hi_bit_set { + a ^= 0x1B; + } + b >>= 1; + } + p +} +fn mix_columns(state: &mut [u8; 16]) { + let original = *state; + for i in 0..4 { + let col = i * 4; + state[col] = gmul(original[col], 2) + ^ gmul(original[col + 1], 3) + ^ original[col + 2] + ^ original[col + 3]; + state[col + 1] = original[col] + ^ gmul(original[col + 1], 2) + ^ gmul(original[col + 2], 3) + ^ original[col + 3]; + state[col + 2] = original[col] + ^ original[col + 1] + ^ gmul(original[col + 2], 2) + ^ gmul(original[col + 3], 3); + state[col + 3] = gmul(original[col], 3) + ^ original[col + 1] + ^ original[col + 2] + ^ gmul(original[col + 3], 2); + } +} +fn add_round_key(state: &mut [u8; 16], round_key: u128) { + let key_bytes = round_key.to_be_bytes(); + for i in 0..16 { + state[i] ^= key_bytes[i]; + } +} +fn plain_aes_encrypt_block(block_bytes: &mut [u8; 16], expanded_keys: &[u128]) { + add_round_key(block_bytes, expanded_keys[0]); + for round_key in expanded_keys.iter().take(10).skip(1) { + sub_bytes(block_bytes); + shift_rows(block_bytes); + mix_columns(block_bytes); + add_round_key(block_bytes, *round_key); + } + sub_bytes(block_bytes); + shift_rows(block_bytes); + add_round_key(block_bytes, expanded_keys[10]); +} +fn plain_aes_ctr(num_aes_inputs: usize, iv: u128, key: u128) -> Vec { + let expanded_keys = plain_key_expansion(key); + let mut results = Vec::with_capacity(num_aes_inputs); + for i in 0..num_aes_inputs { + let counter_value = iv.wrapping_add(i as u128); + let mut block = counter_value.to_be_bytes(); + plain_aes_encrypt_block(&mut block, &expanded_keys); + results.push(u128::from_be_bytes(block)); + } + results +} + +fn internal_aes_fixed_parallelism_test(param: P, mut executor: E, num_aes_inputs: usize) +where + P: Into, + E: for<'a> 
FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext, u128, usize, usize), + RadixCiphertext, + >, +{ + let param = param.into(); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, 1)); + let sks = Arc::new(sks); + executor.setup(&cks, sks); + + let key: u128 = 0x2b7e151628aed2a6abf7158809cf4f3c; + let iv: u128 = 0xf0f1f2f3f4f5f6f7f8f9fafbfcfdfeff; + + let plain_results = plain_aes_ctr(num_aes_inputs, iv, key); + + let ctxt_key = cks.encrypt_u128_for_aes_ctr(key); + let ctxt_iv = cks.encrypt_u128_for_aes_ctr(iv); + + for sbox_parallelism in [1, 2, 4, 8, 16] { + let encrypted_result = + executor.execute((&ctxt_key, &ctxt_iv, 0, num_aes_inputs, sbox_parallelism)); + let fhe_results = cks.decrypt_u128_from_aes_ctr(&encrypted_result, num_aes_inputs); + assert_eq!(fhe_results, plain_results); + } +} + +pub fn aes_fixed_parallelism_1_input_test(param: P, executor: E) +where + P: Into, + E: for<'a> FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext, u128, usize, usize), + RadixCiphertext, + >, +{ + internal_aes_fixed_parallelism_test(param, executor, 1); +} + +pub fn aes_fixed_parallelism_2_inputs_test(param: P, executor: E) +where + P: Into, + E: for<'a> FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext, u128, usize, usize), + RadixCiphertext, + >, +{ + internal_aes_fixed_parallelism_test(param, executor, 2); +} + +pub fn aes_dynamic_parallelism_many_inputs_test(param: P, mut executor: E) +where + P: Into, + E: for<'a> FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext, u128, usize), + RadixCiphertext, + >, +{ + let param = param.into(); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, 1)); + let sks = Arc::new(sks); + executor.setup(&cks, sks); + + let key: u128 = 0x2b7e151628aed2a6abf7158809cf4f3c; + let iv: u128 = 0xf0f1f2f3f4f5f6f7f8f9fafbfcfdfeff; + + let ctxt_key = cks.encrypt_u128_for_aes_ctr(key); + let ctxt_iv = cks.encrypt_u128_for_aes_ctr(iv); + + for num_aes_inputs in [4, 8, 16, 32] { + let plain_results = plain_aes_ctr(num_aes_inputs, iv, key); + let encrypted_result = executor.execute((&ctxt_key, &ctxt_iv, 0, num_aes_inputs)); + let fhe_results = cks.decrypt_u128_from_aes_ctr(&encrypted_result, num_aes_inputs); + assert_eq!(fhe_results, plain_results); + } +}
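
For context, the following is a minimal usage sketch of the API added by this patch; it is not part of the diff itself. It assumes the `gpu` feature is enabled and that the client key, GPU server key, and CUDA streams are built as in the benchmark above; the helper name `aes_ctr_demo` is illustrative only.

```rust
use tfhe::core_crypto::gpu::CudaStreams;
use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use tfhe::integer::gpu::CudaServerKey;
use tfhe::integer::RadixClientKey;

/// Hypothetical helper (not part of the patch): encrypts a key/IV on the
/// client, runs AES-128 in CTR mode homomorphically on the GPU, and decrypts
/// the resulting blocks. `cks`, `sks` and `streams` are assumed to be
/// constructed as in the benchmark above.
fn aes_ctr_demo(
    cks: &RadixClientKey,
    sks: &CudaServerKey,
    streams: &CudaStreams,
    num_aes_inputs: usize,
) -> Vec<u128> {
    let key: u128 = 0x2b7e151628aed2a6abf7158809cf4f3c;
    let iv: u128 = 0xf0f1f2f3f4f5f6f7f8f9fafbfcfdfeff;

    // Client side: each of the 128 bits of the key and IV is encrypted
    // individually, then the bit-ciphertexts are moved to the GPU.
    let ct_key = cks.encrypt_u128_for_aes_ctr(key);
    let ct_iv = cks.encrypt_u128_for_aes_ctr(iv);
    let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, streams);
    let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, streams);

    // Server side: key expansion followed by CTR-mode encryption starting at
    // counter 0, with the S-box parallelism chosen from available GPU memory.
    let d_result = sks.aes_ctr(&d_key, &d_iv, 0, num_aes_inputs, streams);

    // Client side: bring the result back and reassemble the 128-bit blocks.
    let result = d_result.to_radix_ciphertext(streams);
    cks.decrypt_u128_from_aes_ctr(&result, num_aes_inputs)
}
```

The `aes_ctr` entry point picks the largest S-box parallelism that fits in GPU memory; `aes_ctr_with_fixed_parallelism` can be used instead when a specific level (1, 2, 4, 8, or 16) is required.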