From 8aeddb6b6ae5fc1f3f1fc0db0877823cf9b2bbcd Mon Sep 17 00:00:00 2001 From: Enzo Di Maria Date: Fri, 19 Dec 2025 14:54:54 +0100 Subject: [PATCH] feat(gpu): trivium --- Makefile | 7 + backends/tfhe-cuda-backend/build.rs | 1 + .../cuda/include/trivium/trivium.h | 24 ++ .../cuda/include/trivium/trivium_utilities.h | 277 ++++++++++++++++ .../cuda/src/trivium/trivium.cu | 45 +++ .../cuda/src/trivium/trivium.cuh | 311 ++++++++++++++++++ backends/tfhe-cuda-backend/src/bindings.rs | 36 ++ backends/tfhe-cuda-backend/wrapper.h | 1 + tfhe-benchmark/Cargo.toml | 6 + tfhe-benchmark/benches/integer/bench.rs | 1 + tfhe-benchmark/benches/integer/trivium.rs | 100 ++++++ tfhe-benchmark/src/utilities.rs | 2 - tfhe/src/integer/gpu/mod.rs | 95 ++++++ tfhe/src/integer/gpu/server_key/radix/mod.rs | 1 + .../server_key/radix/tests_unsigned/mod.rs | 42 +++ .../radix/tests_unsigned/test_trivium.rs | 71 ++++ .../integer/gpu/server_key/radix/trivium.rs | 176 ++++++++++ .../radix_parallel/tests_unsigned/mod.rs | 1 + .../tests_unsigned/test_trivium.rs | 260 +++++++++++++++ 19 files changed, 1455 insertions(+), 2 deletions(-) create mode 100644 backends/tfhe-cuda-backend/cuda/include/trivium/trivium.h create mode 100644 backends/tfhe-cuda-backend/cuda/include/trivium/trivium_utilities.h create mode 100644 backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cu create mode 100644 backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cuh create mode 100644 tfhe-benchmark/benches/integer/trivium.rs create mode 100644 tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_trivium.rs create mode 100644 tfhe/src/integer/gpu/server_key/radix/trivium.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_trivium.rs diff --git a/Makefile b/Makefile index bfb84e2fd..4555f2897 100644 --- a/Makefile +++ b/Makefile @@ -1452,6 +1452,13 @@ bench_integer_aes256_gpu: install_rs_check_toolchain --bench integer-aes256 \ --features=integer,internal-keycache,gpu, -p tfhe-benchmark 
--profile release_lto_off -- +.PHONY: bench_integer_trivium_gpu # Run benchmarks for trivium on GPU backend +bench_integer_trivium_gpu: install_rs_check_toolchain + RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \ + cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \ + --bench integer-trivium \ + --features=integer,internal-keycache,gpu, -p tfhe-benchmark --profile release_lto_off -- + .PHONY: bench_integer_multi_bit # Run benchmarks for unsigned integer using multi-bit parameters bench_integer_multi_bit: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_PARAM_TYPE=MULTI_BIT __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \ diff --git a/backends/tfhe-cuda-backend/build.rs b/backends/tfhe-cuda-backend/build.rs index b066ab83e..00676c109 100644 --- a/backends/tfhe-cuda-backend/build.rs +++ b/backends/tfhe-cuda-backend/build.rs @@ -86,6 +86,7 @@ fn main() { "cuda/include/integer/integer.h", "cuda/include/integer/rerand.h", "cuda/include/aes/aes.h", + "cuda/include/trivium/trivium.h", "cuda/include/zk/zk.h", "cuda/include/keyswitch/keyswitch.h", "cuda/include/keyswitch/ks_enums.h", diff --git a/backends/tfhe-cuda-backend/cuda/include/trivium/trivium.h b/backends/tfhe-cuda-backend/cuda/include/trivium/trivium.h new file mode 100644 index 000000000..a23b005ef --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/include/trivium/trivium.h @@ -0,0 +1,24 @@ +#ifndef TRIVIUM_H +#define TRIVIUM_H + +#include "../integer/integer.h" + +extern "C" { +uint64_t scratch_cuda_trivium_64( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs); + +void cuda_trivium_generate_keystream_64( + CudaStreamsFFI streams, CudaRadixCiphertextFFI 
*keystream_output, + const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv, + uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks, + void *const *ksks); + +void cleanup_cuda_trivium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void); +} + +#endif diff --git a/backends/tfhe-cuda-backend/cuda/include/trivium/trivium_utilities.h b/backends/tfhe-cuda-backend/cuda/include/trivium/trivium_utilities.h new file mode 100644 index 000000000..acbe004b0 --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/include/trivium/trivium_utilities.h @@ -0,0 +1,277 @@ +#ifndef TRIVIUM_UTILITIES_H +#define TRIVIUM_UTILITIES_H +#include "../integer/integer_utilities.h" + +template struct int_trivium_lut_buffers { + int_radix_lut *and_lut; + int_radix_lut *flush_lut; + + int_trivium_lut_buffers(CudaStreams streams, const int_radix_params ¶ms, + bool allocate_gpu_memory, uint32_t num_trivium_inputs, + uint64_t &size_tracker) { + + constexpr uint32_t BATCH_SIZE = 64; + constexpr uint32_t MAX_AND_PER_STEP = 3; + uint32_t total_lut_ops = num_trivium_inputs * BATCH_SIZE * MAX_AND_PER_STEP; + + this->and_lut = new int_radix_lut(streams, params, 1, total_lut_ops, + allocate_gpu_memory, size_tracker); + + std::function and_lambda = + [](Torus a, Torus b) -> Torus { return a & b; }; + + generate_device_accumulator_bivariate( + streams.stream(0), streams.gpu_index(0), this->and_lut->get_lut(0, 0), + this->and_lut->get_degree(0), this->and_lut->get_max_degree(0), + params.glwe_dimension, params.polynomial_size, params.message_modulus, + params.carry_modulus, and_lambda, allocate_gpu_memory); + + auto active_streams_and = + streams.active_gpu_subset(total_lut_ops, params.pbs_type); + this->and_lut->broadcast_lut(active_streams_and); + this->and_lut->setup_gemm_batch_ks_temp_buffers(size_tracker); + + uint32_t total_flush_ops = num_trivium_inputs * BATCH_SIZE * 4; + + this->flush_lut = new int_radix_lut( + streams, params, 1, total_flush_ops, allocate_gpu_memory, 
size_tracker); + + std::function flush_lambda = [](Torus x) -> Torus { + return x & 1; + }; + + generate_device_accumulator( + streams.stream(0), streams.gpu_index(0), this->flush_lut->get_lut(0, 0), + this->flush_lut->get_degree(0), this->flush_lut->get_max_degree(0), + params.glwe_dimension, params.polynomial_size, params.message_modulus, + params.carry_modulus, flush_lambda, allocate_gpu_memory); + + auto active_streams_flush = + streams.active_gpu_subset(total_flush_ops, params.pbs_type); + this->flush_lut->broadcast_lut(active_streams_flush); + this->flush_lut->setup_gemm_batch_ks_temp_buffers(size_tracker); + } + + void release(CudaStreams streams) { + this->and_lut->release(streams); + delete this->and_lut; + this->and_lut = nullptr; + + this->flush_lut->release(streams); + delete this->flush_lut; + this->flush_lut = nullptr; + cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0)); + } +}; + +template struct int_trivium_state_workspaces { + CudaRadixCiphertextFFI *a_reg; + CudaRadixCiphertextFFI *b_reg; + CudaRadixCiphertextFFI *c_reg; + + CudaRadixCiphertextFFI *shift_workspace; + + CudaRadixCiphertextFFI *temp_t1; + CudaRadixCiphertextFFI *temp_t2; + CudaRadixCiphertextFFI *temp_t3; + CudaRadixCiphertextFFI *new_a; + CudaRadixCiphertextFFI *new_b; + CudaRadixCiphertextFFI *new_c; + + CudaRadixCiphertextFFI *packed_pbs_lhs; + CudaRadixCiphertextFFI *packed_pbs_rhs; + CudaRadixCiphertextFFI *packed_pbs_out; + + CudaRadixCiphertextFFI *packed_flush_in; + CudaRadixCiphertextFFI *packed_flush_out; + + int_trivium_state_workspaces(CudaStreams streams, + const int_radix_params ¶ms, + bool allocate_gpu_memory, uint32_t num_inputs, + uint64_t &size_tracker) { + + this->a_reg = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->a_reg, 93 * num_inputs, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->b_reg = new CudaRadixCiphertextFFI; + 
create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->b_reg, 84 * num_inputs, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->c_reg = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->c_reg, 111 * num_inputs, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->shift_workspace = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->shift_workspace, + 128 * num_inputs, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + uint32_t batch_blocks = 64 * num_inputs; + + this->temp_t1 = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->temp_t1, batch_blocks, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->temp_t2 = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->temp_t2, batch_blocks, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->temp_t3 = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->temp_t3, batch_blocks, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->new_a = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->new_a, batch_blocks, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->new_b = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->new_b, batch_blocks, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->new_c = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->new_c, batch_blocks, + params.big_lwe_dimension, size_tracker, 
allocate_gpu_memory); + + this->packed_pbs_lhs = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->packed_pbs_lhs, + 3 * batch_blocks, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + this->packed_pbs_rhs = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->packed_pbs_rhs, + 3 * batch_blocks, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + this->packed_pbs_out = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->packed_pbs_out, + 3 * batch_blocks, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + this->packed_flush_in = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->packed_flush_in, + 4 * batch_blocks, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + this->packed_flush_out = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->packed_flush_out, + 4 * batch_blocks, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + } + + void release(CudaStreams streams, bool allocate_gpu_memory) { + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->a_reg, allocate_gpu_memory); + delete this->a_reg; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->b_reg, allocate_gpu_memory); + delete this->b_reg; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->c_reg, allocate_gpu_memory); + delete this->c_reg; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->shift_workspace, allocate_gpu_memory); + delete this->shift_workspace; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->temp_t1, allocate_gpu_memory); + delete 
this->temp_t1; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->temp_t2, allocate_gpu_memory); + delete this->temp_t2; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->temp_t3, allocate_gpu_memory); + delete this->temp_t3; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->new_a, allocate_gpu_memory); + delete this->new_a; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->new_b, allocate_gpu_memory); + delete this->new_b; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->new_c, allocate_gpu_memory); + delete this->new_c; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->packed_pbs_lhs, allocate_gpu_memory); + delete this->packed_pbs_lhs; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->packed_pbs_rhs, allocate_gpu_memory); + delete this->packed_pbs_rhs; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->packed_pbs_out, allocate_gpu_memory); + delete this->packed_pbs_out; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->packed_flush_in, allocate_gpu_memory); + delete this->packed_flush_in; + + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->packed_flush_out, allocate_gpu_memory); + delete this->packed_flush_out; + + cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0)); + } +}; + +template struct int_trivium_buffer { + int_radix_params params; + bool allocate_gpu_memory; + uint32_t num_inputs; + + int_trivium_lut_buffers *luts; + int_trivium_state_workspaces *state; + + int_trivium_buffer(CudaStreams streams, const int_radix_params ¶ms, + bool allocate_gpu_memory, uint32_t num_inputs, + uint64_t &size_tracker) { + this->params = params; + this->allocate_gpu_memory = allocate_gpu_memory; + this->num_inputs = num_inputs; + 
+ this->luts = new int_trivium_lut_buffers( + streams, params, allocate_gpu_memory, num_inputs, size_tracker); + + this->state = new int_trivium_state_workspaces( + streams, params, allocate_gpu_memory, num_inputs, size_tracker); + } + + void release(CudaStreams streams) { + luts->release(streams); + delete luts; + luts = nullptr; + + state->release(streams, allocate_gpu_memory); + delete state; + state = nullptr; + + cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0)); + } +}; + +#endif diff --git a/backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cu b/backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cu new file mode 100644 index 000000000..39efb35f8 --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cu @@ -0,0 +1,45 @@ +#include "../../include/trivium/trivium.h" +#include "trivium.cuh" + +uint64_t scratch_cuda_trivium_64( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + glwe_dimension * polynomial_size, lwe_dimension, + ks_level, ks_base_log, pbs_level, pbs_base_log, + grouping_factor, message_modulus, carry_modulus, + noise_reduction_type); + + return scratch_cuda_trivium_encrypt( + CudaStreams(streams), (int_trivium_buffer **)mem_ptr, params, + allocate_gpu_memory, num_inputs); +} + +void cuda_trivium_generate_keystream_64( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output, + const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv, + uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks, + void *const *ksks) { + + auto buffer = (int_trivium_buffer *)mem_ptr; 
+ + host_trivium_generate_keystream( + CudaStreams(streams), keystream_output, key, iv, num_inputs, num_steps, + buffer, bsks, (uint64_t *const *)ksks); +} + +void cleanup_cuda_trivium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void) { + + int_trivium_buffer *mem_ptr = + (int_trivium_buffer *)(*mem_ptr_void); + + mem_ptr->release(CudaStreams(streams)); + + delete mem_ptr; + *mem_ptr_void = nullptr; +} diff --git a/backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cuh b/backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cuh new file mode 100644 index 000000000..8b573101b --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cuh @@ -0,0 +1,311 @@ +#ifndef TRIVIUM_CUH +#define TRIVIUM_CUH + +#include "../../include/trivium/trivium_utilities.h" +#include "../integer/integer.cuh" +#include "../integer/radix_ciphertext.cuh" +#include "../integer/scalar_addition.cuh" +#include "../linearalgebra/addition.cuh" + +template +void reverse_bitsliced_radix_inplace(CudaStreams streams, + int_trivium_buffer *mem, + CudaRadixCiphertextFFI *radix, + uint32_t num_bits_in_reg) { + uint32_t N = mem->num_inputs; + CudaRadixCiphertextFFI *temp = mem->state->shift_workspace; + + for (uint32_t i = 0; i < num_bits_in_reg; i++) { + uint32_t src_start = i * N; + uint32_t src_end = (i + 1) * N; + + uint32_t dest_start = (num_bits_in_reg - 1 - i) * N; + uint32_t dest_end = (num_bits_in_reg - i) * N; + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), temp, dest_start, dest_end, + radix, src_start, src_end); + } + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), radix, 0, num_bits_in_reg * N, + temp, 0, num_bits_in_reg * N); +} + +template +__host__ __forceinline__ void +trivium_xor(CudaStreams streams, int_trivium_buffer *mem, + CudaRadixCiphertextFFI *out, const CudaRadixCiphertextFFI *lhs, + const CudaRadixCiphertextFFI *rhs) { + host_addition(streams.stream(0), streams.gpu_index(0), out, lhs, rhs, + 
out->num_radix_blocks, mem->params.message_modulus, + mem->params.carry_modulus); +} + +template +__host__ __forceinline__ void +trivium_flush(CudaStreams streams, int_trivium_buffer *mem, + CudaRadixCiphertextFFI *target, void *const *bsks, + KSTorus *const *ksks) { + integer_radix_apply_univariate_lookup_table( + streams, target, target, bsks, ksks, mem->luts->flush_lut, + target->num_radix_blocks); +} + +template +__host__ void slice_reg_batch(CudaRadixCiphertextFFI *slice, + const CudaRadixCiphertextFFI *reg, + uint32_t start_bit_idx, uint32_t num_bits, + uint32_t num_inputs) { + as_radix_ciphertext_slice(slice, reg, start_bit_idx * num_inputs, + (start_bit_idx + num_bits) * num_inputs); +} + +template +__host__ void shift_and_insert_batch(CudaStreams streams, + int_trivium_buffer *mem, + CudaRadixCiphertextFFI *reg, + CudaRadixCiphertextFFI *new_bits, + uint32_t reg_size, uint32_t num_inputs) { + + constexpr uint32_t BATCH = 64; + CudaRadixCiphertextFFI *temp = mem->state->shift_workspace; + + uint32_t num_blocks_to_keep = (reg_size - BATCH) * num_inputs; + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), temp, 0, BATCH * num_inputs, + new_bits, 0, BATCH * num_inputs); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), temp, BATCH * num_inputs, + reg_size * num_inputs, reg, 0, num_blocks_to_keep); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), reg, 0, reg_size * num_inputs, + temp, 0, reg_size * num_inputs); +} + +template +__host__ void +trivium_compute_64_steps(CudaStreams streams, int_trivium_buffer *mem, + CudaRadixCiphertextFFI *output_dest, void *const *bsks, + KSTorus *const *ksks) { + + uint32_t N = mem->num_inputs; + constexpr uint32_t BATCH = 64; + uint32_t batch_size_blocks = BATCH * N; + auto s = mem->state; + + CudaRadixCiphertextFFI a65_slice, a92_slice, a91_slice, a90_slice, a68_slice; + slice_reg_batch(&a65_slice, s->a_reg, 2, BATCH, N); + 
slice_reg_batch(&a92_slice, s->a_reg, 29, BATCH, N); + slice_reg_batch(&a91_slice, s->a_reg, 28, BATCH, N); + slice_reg_batch(&a90_slice, s->a_reg, 27, BATCH, N); + slice_reg_batch(&a68_slice, s->a_reg, 5, BATCH, N); + + CudaRadixCiphertextFFI b68_slice, b83_slice, b82_slice, b81_slice, b77_slice; + slice_reg_batch(&b68_slice, s->b_reg, 5, BATCH, N); + slice_reg_batch(&b83_slice, s->b_reg, 20, BATCH, N); + slice_reg_batch(&b82_slice, s->b_reg, 19, BATCH, N); + slice_reg_batch(&b81_slice, s->b_reg, 18, BATCH, N); + slice_reg_batch(&b77_slice, s->b_reg, 14, BATCH, N); + + CudaRadixCiphertextFFI c65_slice, c110_slice, c109_slice, c108_slice, + c86_slice; + slice_reg_batch(&c65_slice, s->c_reg, 2, BATCH, N); + slice_reg_batch(&c110_slice, s->c_reg, 47, BATCH, N); + slice_reg_batch(&c109_slice, s->c_reg, 46, BATCH, N); + slice_reg_batch(&c108_slice, s->c_reg, 45, BATCH, N); + slice_reg_batch(&c86_slice, s->c_reg, 23, BATCH, N); + + trivium_xor(streams, mem, s->temp_t1, &a65_slice, &a92_slice); + trivium_xor(streams, mem, s->temp_t2, &b68_slice, &b83_slice); + trivium_xor(streams, mem, s->temp_t3, &c65_slice, &c110_slice); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs, 0, + batch_size_blocks, &c109_slice, 0, batch_size_blocks); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs, + batch_size_blocks, 2 * batch_size_blocks, &a91_slice, 0, + batch_size_blocks); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs, + 2 * batch_size_blocks, 3 * batch_size_blocks, &b82_slice, 0, + batch_size_blocks); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs, 0, + batch_size_blocks, &c108_slice, 0, batch_size_blocks); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs, + batch_size_blocks, 2 * batch_size_blocks, &a90_slice, 0, + 
batch_size_blocks); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs, + 2 * batch_size_blocks, 3 * batch_size_blocks, &b81_slice, 0, + batch_size_blocks); + + integer_radix_apply_bivariate_lookup_table( + streams, s->packed_pbs_out, s->packed_pbs_lhs, s->packed_pbs_rhs, bsks, + ksks, mem->luts->and_lut, 3 * batch_size_blocks, + mem->params.message_modulus); + + CudaRadixCiphertextFFI and_res_a, and_res_b, and_res_c; + as_radix_ciphertext_slice(&and_res_a, s->packed_pbs_out, 0, + batch_size_blocks); + as_radix_ciphertext_slice(&and_res_b, s->packed_pbs_out, + batch_size_blocks, 2 * batch_size_blocks); + as_radix_ciphertext_slice(&and_res_c, s->packed_pbs_out, + 2 * batch_size_blocks, + 3 * batch_size_blocks); + + trivium_xor(streams, mem, s->new_a, s->temp_t3, &a68_slice); + trivium_xor(streams, mem, s->new_a, s->new_a, &and_res_a); + + trivium_xor(streams, mem, s->new_b, s->temp_t1, &b77_slice); + trivium_xor(streams, mem, s->new_b, s->new_b, &and_res_b); + + trivium_xor(streams, mem, s->new_c, s->temp_t2, &c86_slice); + trivium_xor(streams, mem, s->new_c, s->new_c, &and_res_c); + + if (output_dest != nullptr) { + trivium_xor(streams, mem, output_dest, s->temp_t1, s->temp_t2); + trivium_xor(streams, mem, output_dest, output_dest, s->temp_t3); + } + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_flush_in, 0, + batch_size_blocks, s->new_a, 0, batch_size_blocks); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_flush_in, + batch_size_blocks, 2 * batch_size_blocks, s->new_b, 0, batch_size_blocks); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_flush_in, + 2 * batch_size_blocks, 3 * batch_size_blocks, s->new_c, 0, + batch_size_blocks); + + uint32_t total_flush_blocks = 3 * batch_size_blocks; + + if (output_dest != nullptr) { + copy_radix_ciphertext_slice_async( + streams.stream(0), 
streams.gpu_index(0), s->packed_flush_in, + 3 * batch_size_blocks, 4 * batch_size_blocks, output_dest, 0, + batch_size_blocks); + total_flush_blocks += batch_size_blocks; + } + + integer_radix_apply_univariate_lookup_table( + streams, s->packed_flush_out, s->packed_flush_in, bsks, ksks, + mem->luts->flush_lut, total_flush_blocks); + + CudaRadixCiphertextFFI flushed_a, flushed_b, flushed_c; + as_radix_ciphertext_slice(&flushed_a, s->packed_flush_out, 0, + batch_size_blocks); + as_radix_ciphertext_slice(&flushed_b, s->packed_flush_out, + batch_size_blocks, 2 * batch_size_blocks); + as_radix_ciphertext_slice(&flushed_c, s->packed_flush_out, + 2 * batch_size_blocks, + 3 * batch_size_blocks); + + shift_and_insert_batch(streams, mem, s->a_reg, &flushed_a, 93, N); + shift_and_insert_batch(streams, mem, s->b_reg, &flushed_b, 84, N); + shift_and_insert_batch(streams, mem, s->c_reg, &flushed_c, 111, N); + + if (output_dest != nullptr) { + CudaRadixCiphertextFFI flushed_out; + as_radix_ciphertext_slice(&flushed_out, s->packed_flush_out, + 3 * batch_size_blocks, + 4 * batch_size_blocks); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), output_dest, 0, + batch_size_blocks, &flushed_out, 0, batch_size_blocks); + + reverse_bitsliced_radix_inplace(streams, mem, output_dest, 64); + } +} + +template +__host__ void trivium_init(CudaStreams streams, int_trivium_buffer *mem, + CudaRadixCiphertextFFI const *key_bitsliced, + CudaRadixCiphertextFFI const *iv_bitsliced, + void *const *bsks, KSTorus *const *ksks) { + uint32_t N = mem->num_inputs; + auto s = mem->state; + + CudaRadixCiphertextFFI src_key_slice; + slice_reg_batch(&src_key_slice, key_bitsliced, 0, 80, N); + + CudaRadixCiphertextFFI dest_a_slice; + slice_reg_batch(&dest_a_slice, s->a_reg, 0, 80, N); + + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &dest_a_slice, &src_key_slice); + + reverse_bitsliced_radix_inplace(streams, mem, s->a_reg, 80); + + CudaRadixCiphertextFFI 
src_iv_slice; + slice_reg_batch(&src_iv_slice, iv_bitsliced, 0, 80, N); + + CudaRadixCiphertextFFI dest_b_slice; + slice_reg_batch(&dest_b_slice, s->b_reg, 0, 80, N); + + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &dest_b_slice, &src_iv_slice); + + reverse_bitsliced_radix_inplace(streams, mem, s->b_reg, 80); + + CudaRadixCiphertextFFI dest_c_ones; + slice_reg_batch(&dest_c_ones, s->c_reg, 108, 3, N); + + host_add_scalar_one_inplace(streams, &dest_c_ones, + mem->params.message_modulus, + mem->params.carry_modulus); + trivium_flush(streams, mem, &dest_c_ones, bsks, ksks); + + for (int i = 0; i < 18; i++) { + trivium_compute_64_steps(streams, mem, nullptr, bsks, ksks); + } +} + +template +__host__ void host_trivium_generate_keystream( + CudaStreams streams, CudaRadixCiphertextFFI *keystream_output, + CudaRadixCiphertextFFI const *key_bitsliced, + CudaRadixCiphertextFFI const *iv_bitsliced, uint32_t num_inputs, + uint32_t num_steps, int_trivium_buffer *mem, void *const *bsks, + KSTorus *const *ksks) { + + trivium_init(streams, mem, key_bitsliced, iv_bitsliced, bsks, ksks); + + uint32_t num_batches = num_steps / 64; + for (uint32_t i = 0; i < num_batches; i++) { + CudaRadixCiphertextFFI batch_out_slice; + slice_reg_batch(&batch_out_slice, keystream_output, i * 64, 64, + num_inputs); + trivium_compute_64_steps(streams, mem, &batch_out_slice, bsks, ksks); + } +} + +template +uint64_t scratch_cuda_trivium_encrypt(CudaStreams streams, + int_trivium_buffer **mem_ptr, + int_radix_params params, + bool allocate_gpu_memory, + uint32_t num_inputs) { + + uint64_t size_tracker = 0; + *mem_ptr = new int_trivium_buffer(streams, params, allocate_gpu_memory, + num_inputs, size_tracker); + return size_tracker; +} + +#endif diff --git a/backends/tfhe-cuda-backend/src/bindings.rs b/backends/tfhe-cuda-backend/src/bindings.rs index d58512a03..f858a9d2e 100644 --- a/backends/tfhe-cuda-backend/src/bindings.rs +++ b/backends/tfhe-cuda-backend/src/bindings.rs @@ 
-2495,6 +2495,42 @@ unsafe extern "C" { mem_ptr_void: *mut *mut i8, ); } +unsafe extern "C" { + pub fn scratch_cuda_trivium_64( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + num_inputs: u32, + ) -> u64; +} +unsafe extern "C" { + pub fn cuda_trivium_generate_keystream_64( + streams: CudaStreamsFFI, + keystream_output: *mut CudaRadixCiphertextFFI, + key: *const CudaRadixCiphertextFFI, + iv: *const CudaRadixCiphertextFFI, + num_inputs: u32, + num_steps: u32, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} +unsafe extern "C" { + pub fn cleanup_cuda_trivium_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8); +} pub const KS_TYPE_BIG_TO_SMALL: KS_TYPE = 0; pub const KS_TYPE_SMALL_TO_BIG: KS_TYPE = 1; pub type KS_TYPE = ffi::c_uint; diff --git a/backends/tfhe-cuda-backend/wrapper.h b/backends/tfhe-cuda-backend/wrapper.h index da1347951..d4f6f9fc8 100644 --- a/backends/tfhe-cuda-backend/wrapper.h +++ b/backends/tfhe-cuda-backend/wrapper.h @@ -4,6 +4,7 @@ #include "cuda/include/integer/integer.h" #include "cuda/include/integer/rerand.h" #include "cuda/include/aes/aes.h" +#include "cuda/include/trivium/trivium.h" #include "cuda/include/zk/zk.h" #include "cuda/include/keyswitch/keyswitch.h" #include "cuda/include/keyswitch/ks_enums.h" diff --git a/tfhe-benchmark/Cargo.toml b/tfhe-benchmark/Cargo.toml index 1fb042d09..4f04edbe3 100644 --- a/tfhe-benchmark/Cargo.toml +++ b/tfhe-benchmark/Cargo.toml @@ -133,6 +133,12 @@ path = "benches/integer/aes.rs" harness = false required-features = ["integer", "internal-keycache"] +[[bench]] +name = "integer-trivium" +path = "benches/integer/trivium.rs" +harness = false 
+required-features = ["integer", "internal-keycache"] + [[bench]] name = "integer-aes256" path = "benches/integer/aes256.rs" diff --git a/tfhe-benchmark/benches/integer/bench.rs b/tfhe-benchmark/benches/integer/bench.rs index d48d58bd5..fa12b5b5e 100644 --- a/tfhe-benchmark/benches/integer/bench.rs +++ b/tfhe-benchmark/benches/integer/bench.rs @@ -3,6 +3,7 @@ mod aes; mod aes256; mod oprf; +mod trivium; mod vector_find; mod rerand; diff --git a/tfhe-benchmark/benches/integer/trivium.rs b/tfhe-benchmark/benches/integer/trivium.rs new file mode 100644 index 000000000..02e359966 --- /dev/null +++ b/tfhe-benchmark/benches/integer/trivium.rs @@ -0,0 +1,100 @@ +use criterion::Criterion; + +#[cfg(feature = "gpu")] +pub mod cuda { + use benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + use benchmark::utilities::{write_to_json, OperatorType}; + use criterion::{black_box, criterion_group, Criterion}; + use tfhe::core_crypto::gpu::CudaStreams; + use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; + use tfhe::integer::gpu::CudaServerKey; + use tfhe::integer::keycache::KEY_CACHE; + use tfhe::integer::{IntegerKeyKind, RadixClientKey}; + use tfhe::keycache::NamedParam; + use tfhe::shortint::AtomicPatternParameters; + + pub fn cuda_trivium(c: &mut Criterion) { + let bench_name = "integer::cuda::trivium"; + + let mut bench_group = c.benchmark_group(bench_name); + bench_group + .sample_size(15) + .measurement_time(std::time::Duration::from_secs(60)) + .warm_up_time(std::time::Duration::from_secs(5)); + + let param = BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + let atomic_param: AtomicPatternParameters = param.into(); + + let key_bits = vec![0u64; 80]; + let iv_bits = vec![0u64; 80]; + + let param_name = param.name(); + + let streams = CudaStreams::new_multi_gpu(); + let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix); + let sks = 
CudaServerKey::new(&cpu_cks, &streams); + let cks = RadixClientKey::from((cpu_cks, 1)); + + let ct_key = cks.encrypt_bits_for_trivium(&key_bits); + let ct_iv = cks.encrypt_bits_for_trivium(&iv_bits); + + let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams); + let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams); + + { + let num_steps = 64; + let bench_id = format!("{bench_name}::{param_name}::generate_{num_steps}_bits"); + + bench_group.bench_function(&bench_id, |b| { + b.iter(|| { + black_box(sks.trivium_generate_keystream(&d_key, &d_iv, num_steps, &streams)); + }) + }); + + write_to_json::( + &bench_id, + atomic_param, + param.name(), + "trivium_generation_64_bits", + &OperatorType::Atomic, + 80, + vec![atomic_param.message_modulus().0.ilog2(); 80], + ); + } + + { + let num_steps = 512; + let bench_id = format!("{bench_name}::{param_name}::generate_{num_steps}_bits"); + + bench_group.bench_function(&bench_id, |b| { + b.iter(|| { + black_box(sks.trivium_generate_keystream(&d_key, &d_iv, num_steps, &streams)); + }) + }); + + write_to_json::( + &bench_id, + atomic_param, + param.name(), + "trivium_generation_512_bits", + &OperatorType::Atomic, + 80, + vec![atomic_param.message_modulus().0.ilog2(); 80], + ); + } + + bench_group.finish(); + } + + criterion_group!(gpu_trivium, cuda_trivium); +} + +#[cfg(feature = "gpu")] +use cuda::gpu_trivium; + +fn main() { + #[cfg(feature = "gpu")] + gpu_trivium(); + + Criterion::default().configure_from_args().final_summary(); +} diff --git a/tfhe-benchmark/src/utilities.rs b/tfhe-benchmark/src/utilities.rs index fc16e098c..47a351d2a 100644 --- a/tfhe-benchmark/src/utilities.rs +++ b/tfhe-benchmark/src/utilities.rs @@ -5,8 +5,6 @@ use std::{env, fs}; #[cfg(feature = "gpu")] use tfhe::core_crypto::gpu::{get_number_of_gpus, get_number_of_sms}; use tfhe::core_crypto::prelude::*; -#[cfg(feature = "integer")] -use tfhe::prelude::*; #[cfg(feature = "boolean")] pub mod boolean_utils 
{ diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index 14bdd6508..3292412aa 100644 --- a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -10423,3 +10423,98 @@ pub unsafe fn unchecked_small_scalar_mul_integer_async( carry_modulus.0 as u32, ); } + +#[allow(clippy::too_many_arguments)] +/// # Safety +/// +/// - The data must not be moved or dropped while being used by the CUDA kernel. +/// - This function assumes exclusive access to the passed data; violating this may lead to +/// undefined behavior. +pub(crate) unsafe fn cuda_backend_trivium_generate_keystream( + streams: &CudaStreams, + keystream_output: &mut CudaRadixCiphertext, + key: &CudaRadixCiphertext, + iv: &CudaRadixCiphertext, + bootstrapping_key: &CudaVec, + keyswitch_key: &CudaVec, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + lwe_dimension: LweDimension, + ks_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + pbs_base_log: DecompositionBaseLog, + grouping_factor: LweBskGroupingFactor, + pbs_type: PBSType, + ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>, + num_steps: u32, +) { + let mut keystream_degrees = keystream_output + .info + .blocks + .iter() + .map(|b| b.degree.0) + .collect(); + let mut keystream_noise_levels = keystream_output + .info + .blocks + .iter() + .map(|b| b.noise_level.0) + .collect(); + let mut cuda_ffi_keystream = prepare_cuda_radix_ffi( + keystream_output, + &mut keystream_degrees, + &mut keystream_noise_levels, + ); + + let mut key_degrees = key.info.blocks.iter().map(|b| b.degree.0).collect(); + let mut key_noise_levels = key.info.blocks.iter().map(|b| b.noise_level.0).collect(); + let cuda_ffi_key = prepare_cuda_radix_ffi(key, &mut key_degrees, &mut key_noise_levels); + + let mut iv_degrees = iv.info.blocks.iter().map(|b| b.degree.0).collect(); + let 
mut iv_noise_levels = iv.info.blocks.iter().map(|b| b.noise_level.0).collect(); + let cuda_ffi_iv = prepare_cuda_radix_ffi(iv, &mut iv_degrees, &mut iv_noise_levels); + + let num_inputs = (key.info.blocks.len() / 80) as u32; + + let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration); + + let mut mem_ptr: *mut i8 = std::ptr::null_mut(); + + scratch_cuda_trivium_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + noise_reduction_type as u32, + num_inputs, + ); + + cuda_trivium_generate_keystream_64( + streams.ffi(), + &raw mut cuda_ffi_keystream, + &raw const cuda_ffi_key, + &raw const cuda_ffi_iv, + num_inputs, + num_steps, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + + cleanup_cuda_trivium_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); + + update_noise_degree(keystream_output, &cuda_ffi_keystream); +} diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs index c83af97cf..2ce07d824 100644 --- a/tfhe/src/integer/gpu/server_key/radix/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs @@ -66,6 +66,7 @@ mod tests_noise_distribution; mod tests_signed; #[cfg(test)] mod tests_unsigned; +mod trivium; impl CudaServerKey { /// Create a trivial ciphertext filled with zeros on the GPU. 
diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs index ed7fbe931..f2336f766 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs @@ -20,6 +20,7 @@ pub(crate) mod test_scalar_shift; pub(crate) mod test_scalar_sub; pub(crate) mod test_shift; pub(crate) mod test_sub; +pub(crate) mod test_trivium; pub(crate) mod test_vector_comparisons; pub(crate) mod test_vector_find; @@ -85,6 +86,47 @@ impl GpuFunctionExecutor { } } +impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext> + for GpuFunctionExecutor +where + F: Fn( + &CudaServerKey, + &CudaUnsignedRadixCiphertext, + &CudaUnsignedRadixCiphertext, + usize, + &CudaStreams, + ) -> CudaUnsignedRadixCiphertext, +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute( + &mut self, + input: (&'a RadixCiphertext, &'a RadixCiphertext, usize), + ) -> RadixCiphertext { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let d_ctxt_1 = + CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams); + let d_ctxt_2 = + CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams); + + let gpu_result = (self.func)( + &context.sks, + &d_ctxt_1, + &d_ctxt_2, + input.2, + &context.streams, + ); + + gpu_result.to_radix_ciphertext(&context.streams) + } +} + impl<'a, F> FunctionExecutor< (&'a RadixCiphertext, &'a RadixCiphertext, u128, usize, usize), diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_trivium.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_trivium.rs new file mode 100644 index 000000000..a5021f845 --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_trivium.rs @@ -0,0 +1,71 @@ +use 
crate::integer::gpu::server_key::radix::tests_unsigned::{ + create_gpu_parameterized_test, GpuFunctionExecutor, +}; +use crate::integer::gpu::CudaServerKey; +use crate::integer::server_key::radix_parallel::tests_unsigned::test_trivium::{ + trivium_comparison_test, trivium_test_vector_1_test, trivium_test_vector_2_test, + trivium_test_vector_3_test, trivium_test_vector_4_test, +}; +use crate::shortint::parameters::{ + TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, +}; + +create_gpu_parameterized_test!(integer_trivium_test_vector_1 { + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 +}); + +create_gpu_parameterized_test!(integer_trivium_test_vector_2 { + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 +}); + +create_gpu_parameterized_test!(integer_trivium_test_vector_3 { + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 +}); + +create_gpu_parameterized_test!(integer_trivium_test_vector_4 { + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 +}); + +create_gpu_parameterized_test!(integer_trivium_comparison { + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 +}); + +fn integer_trivium_test_vector_1

<P>(param: P) +where + P: Into<TestParameters>, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream); + trivium_test_vector_1_test(param, executor); +} + +fn integer_trivium_test_vector_2

<P>(param: P) +where + P: Into<TestParameters>, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream); + trivium_test_vector_2_test(param, executor); +} + +fn integer_trivium_test_vector_3

<P>(param: P) +where + P: Into<TestParameters>, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream); + trivium_test_vector_3_test(param, executor); +} + +fn integer_trivium_test_vector_4

<P>(param: P) +where + P: Into<TestParameters>, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream); + trivium_test_vector_4_test(param, executor); +} + +fn integer_trivium_comparison

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream); + trivium_comparison_test(param, executor); +} diff --git a/tfhe/src/integer/gpu/server_key/radix/trivium.rs b/tfhe/src/integer/gpu/server_key/radix/trivium.rs new file mode 100644 index 000000000..5fdfae2e8 --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/trivium.rs @@ -0,0 +1,176 @@ +use crate::core_crypto::gpu::CudaStreams; +use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext}; +use crate::integer::gpu::server_key::{ + CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey, +}; +use crate::integer::gpu::{cuda_backend_trivium_generate_keystream, LweBskGroupingFactor, PBSType}; +use crate::integer::{RadixCiphertext, RadixClientKey}; +use crate::shortint::Ciphertext; + +impl RadixClientKey { + /// Encrypts a stream of bits for homomorphic stream cipher evaluation (like Trivium). + /// + /// This function prepares a vector of bits (represented as u64s, e.g., keys or IVs) + /// for homomorphic processing by encrypting each bit individually into a single LWE block. + /// + /// The process is as follows: + /// ```text + /// // INPUT: A slice of bits (0 or 1) + /// Input bits: [1, 0, 1, 1, ...] + /// | + /// V + /// // 1. Iterate over each bit + /// | + /// V + /// // 2. Encrypt each bit individually + /// `self.encrypt(bit)` creates a ciphertext. + /// We extract the single LWE block representing this bit. + /// | + /// V + /// // 3. Collect the resulting LWE blocks + /// Blocks: [LWE(1), LWE(0), LWE(1), LWE(1), ...] + /// | + /// V + /// // 4. 
Group blocks into a single RadixCiphertext container + /// // OUTPUT: A RadixCiphertext where blocks[i] encrypts input[i] + /// ``` + pub fn encrypt_bits_for_trivium(&self, bits: &[u64]) -> RadixCiphertext { + let mut blocks: Vec = Vec::with_capacity(bits.len()); + for &bit in bits { + let mut ct = self.encrypt(bit); + let block = ct.blocks.pop().unwrap(); + blocks.push(block); + } + RadixCiphertext::from(blocks) + } + + /// Decrypts a `RadixCiphertext` containing a stream of encrypted bits + /// (e.g. the output keystream of Trivium). + /// + /// This function reverses the encryption process by treating each block of the + /// `RadixCiphertext` as an independent bit, decrypting it, and collecting the results. + /// + /// The process is as follows: + /// ```text + /// // INPUT: RadixCiphertext containing N encrypted bits + /// Ciphertext blocks: [Block_0, Block_1, ..., Block_N] + /// | + /// V + /// // 1. Iterate over each block + /// | + /// V + /// // 2. Decrypt each block individually + /// Treat Block_i as a standalone RadixCiphertext -> decrypt -> u64 + /// | + /// V + /// // 3. Collect the plaintext bits + /// Plaintext bits: [1, 0, 1, 1, ...] + /// | + /// V + /// // OUTPUT: A vector of bits (u8) + /// ``` + pub fn decrypt_bits_from_trivium(&self, encrypted_stream: &RadixCiphertext) -> Vec { + let mut decrypted_bits = Vec::with_capacity(encrypted_stream.blocks.len()); + for block in &encrypted_stream.blocks { + let tmp_radix = RadixCiphertext::from(vec![block.clone()]); + let val: u64 = self.decrypt(&tmp_radix); + decrypted_bits.push(val as u8); + } + decrypted_bits + } +} + +impl CudaServerKey { + /// Generates a Trivium keystream homomorphically on the GPU. + /// + /// # Arguments + /// * `key` - The encrypted secret key (80 bits). + /// * `iv` - The encrypted initialization vector (80 bits). + /// * `num_steps` - The number of keystream bits to generate per input. + /// * `streams` - The CUDA streams to use for execution. 
+ pub fn trivium_generate_keystream( + &self, + key: &CudaUnsignedRadixCiphertext, + iv: &CudaUnsignedRadixCiphertext, + num_steps: usize, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext { + let num_key_bits = 80; + let num_iv_bits = 80; + + assert_eq!( + key.as_ref().d_blocks.lwe_ciphertext_count().0, + num_key_bits, + "Input key must contain {} encrypted bits, but contains {}", + num_key_bits, + key.as_ref().d_blocks.lwe_ciphertext_count().0 + ); + assert_eq!( + iv.as_ref().d_blocks.lwe_ciphertext_count().0, + num_iv_bits, + "Input IV must contain {} encrypted bits, but contains {}", + num_iv_bits, + iv.as_ref().d_blocks.lwe_ciphertext_count().0 + ); + + let num_output_bits = num_steps; + let mut keystream: CudaUnsignedRadixCiphertext = + self.create_trivial_zero_radix(num_output_bits, streams); + + let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else { + panic!("Only the standard atomic pattern is supported on GPU") + }; + + unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_trivium_generate_keystream( + streams, + keystream.as_mut(), + key.as_ref(), + iv.as_ref(), + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + d_bsk.input_lwe_dimension, + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + LweBskGroupingFactor(0), + PBSType::Classical, + d_bsk.ms_noise_reduction_configuration.as_ref(), + num_steps as u32, + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_trivium_generate_keystream( + streams, + keystream.as_mut(), + key.as_ref(), + iv.as_ref(), + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + d_multibit_bsk.input_lwe_dimension, + 
computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + d_multibit_bsk.grouping_factor, + PBSType::MultiBit, + None, + num_steps as u32, + ); + } + } + } + keystream + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs index 176686baa..fe6a24362 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs @@ -28,6 +28,7 @@ pub(crate) mod test_shift; pub(crate) mod test_slice; pub(crate) mod test_sub; pub(crate) mod test_sum; +pub(crate) mod test_trivium; pub(crate) mod test_vector_comparisons; pub(crate) mod test_vector_find; diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_trivium.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_trivium.rs new file mode 100644 index 000000000..12f19cf3c --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_trivium.rs @@ -0,0 +1,260 @@ +#![cfg(feature = "gpu")] + +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey}; +use crate::shortint::parameters::TestParameters; +use rand::Rng; +use std::sync::Arc; + +struct TriviumRef { + a: Vec, + b: Vec, + c: Vec, +} + +impl TriviumRef { + fn new(key: &[u8], iv: &[u8]) -> Self { + let mut a = vec![0u8; 93]; + let mut b = vec![0u8; 84]; + let mut c = vec![0u8; 111]; + + for i in 0..80 { + a[i] = key[79 - i]; + b[i] = iv[79 - i]; + } + + c[108] = 1; + c[109] = 1; + c[110] = 1; + + let mut triv = Self { a, b, c }; + for _ in 0..(18 * 64) { + triv.next(); + } + triv + } + + fn next(&mut self) -> u8 { + let t1 = self.a[65] ^ self.a[92]; + let t2 = self.b[68] ^ self.b[83]; + let t3 = 
self.c[65] ^ self.c[110]; + + let out = t1 ^ t2 ^ t3; + + let a_in = t3 ^ self.a[68] ^ (self.c[108] & self.c[109]); + let b_in = t1 ^ self.b[77] ^ (self.a[90] & self.a[91]); + let c_in = t2 ^ self.c[86] ^ (self.b[81] & self.b[82]); + + self.a.pop(); + self.a.insert(0, a_in); + self.b.pop(); + self.b.insert(0, b_in); + self.c.pop(); + self.c.insert(0, c_in); + + out + } +} + +fn get_hexadecimal_string_from_lsb_first_stream(a: &[u8]) -> String { + assert!(a.len().is_multiple_of(8)); + let mut hexadecimal = String::new(); + for test in a.chunks(8) { + let to_hex = |chunk: &[u8]| -> char { + let mut val = 0u8; + if chunk[0] == 1 { + val |= 1; + } + if chunk[1] == 1 { + val |= 2; + } + if chunk[2] == 1 { + val |= 4; + } + if chunk[3] == 1 { + val |= 8; + } + + match val { + 0..=9 => (val + b'0') as char, + 10..=15 => (val - 10 + b'A') as char, + _ => unreachable!(), + } + }; + + hexadecimal.push(to_hex(&test[4..8])); + hexadecimal.push(to_hex(&test[0..4])); + } + hexadecimal +} + +pub fn trivium_test_vector_1_test(param: P, mut executor: E) +where + P: Into, + E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>, +{ + let param = param.into(); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, 1)); + let sks = Arc::new(sks); + executor.setup(&cks, sks); + + let key = vec![0u64; 80]; + let iv = vec![0u64; 80]; + + let expected_output_0_63 = "FBE0BF265859051B517A2E4E239FC97F563203161907CF2DE7A8790FA1B2E9CDF75292030268B7382B4C1A759AA2599A285549986E74805903801A4CB5A5D4F2"; + + let cpu_key = cks.encrypt_bits_for_trivium(&key); + let cpu_iv = cks.encrypt_bits_for_trivium(&iv); + + let num_steps = 512; + let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps)); + + let decrypted_bits = cks.decrypt_bits_from_trivium(&output_radix); + let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits); + + assert_eq!(expected_output_0_63, 
&hex_string[0..64 * 2]); +} + +pub fn trivium_test_vector_2_test(param: P, mut executor: E) +where + P: Into, + E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>, +{ + let param = param.into(); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, 1)); + let sks = Arc::new(sks); + executor.setup(&cks, sks); + + let mut key = vec![0u64; 80]; + let iv = vec![0u64; 80]; + key[7] = 1; + + let expected_output_0_63 = "38EB86FF730D7A9CAF8DF13A4420540DBB7B651464C87501552041C249F29A64D2FBF515610921EBE06C8F92CECF7F8098FF20CCCC6A62B97BE8EF7454FC80F9"; + + let cpu_key = cks.encrypt_bits_for_trivium(&key); + let cpu_iv = cks.encrypt_bits_for_trivium(&iv); + + let num_steps = 512; + let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps)); + + let decrypted_bits = cks.decrypt_bits_from_trivium(&output_radix); + let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits); + + assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]); +} + +pub fn trivium_test_vector_3_test(param: P, mut executor: E) +where + P: Into, + E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>, +{ + let param = param.into(); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, 1)); + let sks = Arc::new(sks); + executor.setup(&cks, sks); + + let key = vec![0u64; 80]; + let mut iv = vec![0u64; 80]; + iv[7] = 1; + + let expected_output_0_63 = "F8901736640549E3BA7D42EA2D07B9F49233C18D773008BD755585B1A8CBAB86C1E9A9B91F1AD33483FD6EE3696D659C9374260456A36AAE11F033A519CBD5D7"; + + let cpu_key = cks.encrypt_bits_for_trivium(&key); + let cpu_iv = cks.encrypt_bits_for_trivium(&iv); + + let num_steps = 512; + let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps)); + + let decrypted_bits = cks.decrypt_bits_from_trivium(&output_radix); + let hex_string = 
get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits); + + assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]); +} + +pub fn trivium_test_vector_4_test(param: P, mut executor: E) +where + P: Into, + E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>, +{ + let param = param.into(); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, 1)); + let sks = Arc::new(sks); + executor.setup(&cks, sks); + + let key_string = "0053A6F94C9FF24598EB"; + let mut key = vec![0u64; 80]; + for i in (0..key_string.len()).step_by(2) { + let mut val = u8::from_str_radix(&key_string[i..i + 2], 16).unwrap(); + for j in 0..8 { + key[8 * (i >> 1) + j] = (val % 2) as u64; + val >>= 1; + } + } + + let iv_string = "0D74DB42A91077DE45AC"; + let mut iv = vec![0u64; 80]; + for i in (0..iv_string.len()).step_by(2) { + let mut val = u8::from_str_radix(&iv_string[i..i + 2], 16).unwrap(); + for j in 0..8 { + iv[8 * (i >> 1) + j] = (val % 2) as u64; + val >>= 1; + } + } + + let expected_output_0_63 = "F4CD954A717F26A7D6930830C4E7CF0819F80E03F25F342C64ADC66ABA7F8A8E6EAA49F23632AE3CD41A7BD290A0132F81C6D4043B6E397D7388F3A03B5FE358"; + + let cpu_key = cks.encrypt_bits_for_trivium(&key); + let cpu_iv = cks.encrypt_bits_for_trivium(&iv); + + let num_steps = 512; + let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps)); + + let decrypted_bits = cks.decrypt_bits_from_trivium(&output_radix); + let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits); + + assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]); +} + +pub fn trivium_comparison_test(param: P, mut executor: E) +where + P: Into, + E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>, +{ + let param = param.into(); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, 1)); + let sks = 
Arc::new(sks); + executor.setup(&cks, sks); + + let num_runs = 5; + let num_steps = 512; + + for i in 0..num_runs { + let mut rng = rand::thread_rng(); + let plain_key: Vec = (0..80).map(|_| rng.gen_range(0..=1)).collect(); + let plain_iv: Vec = (0..80).map(|_| rng.gen_range(0..=1)).collect(); + + let key_bits_u64: Vec = plain_key.iter().map(|&x| x as u64).collect(); + let iv_bits_u64: Vec = plain_iv.iter().map(|&x| x as u64).collect(); + + let cpu_key = cks.encrypt_bits_for_trivium(&key_bits_u64); + let cpu_iv = cks.encrypt_bits_for_trivium(&iv_bits_u64); + + let mut cpu_trivium = TriviumRef::new(&plain_key, &plain_iv); + let mut cpu_output = Vec::with_capacity(num_steps); + for _ in 0..num_steps { + cpu_output.push(cpu_trivium.next()); + } + + let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps)); + let fhe_output = cks.decrypt_bits_from_trivium(&output_radix); + + assert_eq!(cpu_output.len(), fhe_output.len()); + assert_eq!(cpu_output, fhe_output, "Mismatch at iteration {i}"); + } +}