diff --git a/backends/tfhe-cuda-backend/build.rs b/backends/tfhe-cuda-backend/build.rs index 00676c109..5af46f36c 100644 --- a/backends/tfhe-cuda-backend/build.rs +++ b/backends/tfhe-cuda-backend/build.rs @@ -87,6 +87,7 @@ fn main() { "cuda/include/integer/rerand.h", "cuda/include/aes/aes.h", "cuda/include/trivium/trivium.h", + "cuda/include/kreyvium/kreyvium.h", "cuda/include/zk/zk.h", "cuda/include/keyswitch/keyswitch.h", "cuda/include/keyswitch/ks_enums.h", diff --git a/backends/tfhe-cuda-backend/cuda/include/kreyvium/kreyvium.h b/backends/tfhe-cuda-backend/cuda/include/kreyvium/kreyvium.h new file mode 100644 index 000000000..6edcac328 --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/include/kreyvium/kreyvium.h @@ -0,0 +1,24 @@ +#ifndef KREYVIUM_H +#define KREYVIUM_H + +#include "../integer/integer.h" + +/* C ABI entry points of the Kreyvium keystream generator; this patch defines them in cuda/src/kreyvium/kreyvium.cu and mirrors them in src/bindings.rs. */ extern "C" { +/* Builds int_radix_params from the scalar arguments, allocates the Kreyvium working buffer into *mem_ptr and returns the size tracker accumulated while allocating. */ uint64_t scratch_cuda_kreyvium_64( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs); + +/* Runs the cipher initialisation on key and iv, then writes num_steps / 64 batches of 64 keystream bits per input into keystream_output. mem_ptr must be the buffer produced by scratch_cuda_kreyvium_64. */ void cuda_kreyvium_generate_keystream_64( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output, + const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv, + uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks, + void *const *ksks); + +/* Releases and deletes the buffer behind *mem_ptr_void and resets the handle to nullptr. */ void cleanup_cuda_kreyvium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void); +} + +#endif diff --git a/backends/tfhe-cuda-backend/cuda/include/kreyvium/kreyvium_utilities.h b/backends/tfhe-cuda-backend/cuda/include/kreyvium/kreyvium_utilities.h new file mode 100644 index 000000000..6190727f7 --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/include/kreyvium/kreyvium_utilities.h @@ -0,0 +1,256 @@ +#ifndef KREYVIUM_UTILITIES_H +#define
KREYVIUM_UTILITIES_H +#include "../integer/integer_utilities.h" + +/* Owns the two lookup tables used by the Kreyvium host code: a bivariate table computing a & b and a univariate flush table computing x & 1. NOTE(review): template parameter lists and template arguments look stripped from this patch text (template struct, std::function and_lambda, const int_radix_params followed by a garbled params) - presumably template typename Torus and a reference named params; confirm against the repository before compiling. */ template struct int_kreyvium_lut_buffers { + int_radix_lut *and_lut; + int_radix_lut *flush_lut; + + int_kreyvium_lut_buffers(CudaStreams streams, const int_radix_params ¶ms, + bool allocate_gpu_memory, uint32_t num_inputs, + uint64_t &size_tracker) { + + /* The AND table is sized for three packed batches of 64 blocks per input, matching the three AND operand pairs packed by kreyvium_compute_64_steps. */ constexpr uint32_t BATCH_SIZE = 64; + constexpr uint32_t MAX_AND_PER_STEP = 3; + uint32_t total_lut_ops = num_inputs * BATCH_SIZE * MAX_AND_PER_STEP; + + this->and_lut = new int_radix_lut(streams, params, 1, total_lut_ops, + allocate_gpu_memory, size_tracker); + + std::function and_lambda = + [](Torus a, Torus b) -> Torus { return a & b; }; + + generate_device_accumulator_bivariate( + streams.stream(0), streams.gpu_index(0), this->and_lut->get_lut(0, 0), + this->and_lut->get_degree(0), this->and_lut->get_max_degree(0), + params.glwe_dimension, params.polynomial_size, params.message_modulus, + params.carry_modulus, and_lambda, allocate_gpu_memory); + + auto active_streams_and = + streams.active_gpu_subset(total_lut_ops, params.pbs_type); + this->and_lut->broadcast_lut(active_streams_and); + this->and_lut->setup_gemm_batch_ks_temp_buffers(size_tracker); + + /* The flush table is sized for four packed batches per input: the second flush in kreyvium_compute_64_steps packs up to four. */ uint32_t total_flush_ops = num_inputs * BATCH_SIZE * 4; + + this->flush_lut = new int_radix_lut( + streams, params, 1, total_flush_ops, allocate_gpu_memory, size_tracker); + + std::function flush_lambda = [](Torus x) -> Torus { + return x & 1; + }; + + generate_device_accumulator( + streams.stream(0), streams.gpu_index(0), this->flush_lut->get_lut(0, 0), + this->flush_lut->get_degree(0), this->flush_lut->get_max_degree(0), + params.glwe_dimension, params.polynomial_size, params.message_modulus, + params.carry_modulus, flush_lambda, allocate_gpu_memory); + + auto active_streams_flush = + streams.active_gpu_subset(total_flush_ops, params.pbs_type); + this->flush_lut->broadcast_lut(active_streams_flush); + this->flush_lut->setup_gemm_batch_ks_temp_buffers(size_tracker); + } + + void
release(CudaStreams streams) { + /* Releases and deletes both lookup tables, nulls the pointers, then synchronizes stream 0. */ this->and_lut->release(streams); + delete this->and_lut; + this->and_lut = nullptr; + + this->flush_lut->release(streams); + delete this->flush_lut; + this->flush_lut = nullptr; + + cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0)); + } +}; + +/* Registers and scratch ciphertexts of the bitsliced Kreyvium state. Every register is allocated with num_inputs blocks per bit position, and all allocations go through stream 0 / gpu_index 0. */ template struct int_kreyvium_state_workspaces { + CudaRadixCiphertextFFI *a_reg; + CudaRadixCiphertextFFI *b_reg; + CudaRadixCiphertextFFI *c_reg; + + CudaRadixCiphertextFFI *k_reg; + CudaRadixCiphertextFFI *iv_reg; + + CudaRadixCiphertextFFI *shift_workspace; + + CudaRadixCiphertextFFI *temp_t1; + CudaRadixCiphertextFFI *temp_t2; + CudaRadixCiphertextFFI *temp_t3; + CudaRadixCiphertextFFI *new_a; + CudaRadixCiphertextFFI *new_b; + CudaRadixCiphertextFFI *new_c; + + CudaRadixCiphertextFFI *packed_pbs_lhs; + CudaRadixCiphertextFFI *packed_pbs_rhs; + CudaRadixCiphertextFFI *packed_pbs_out; + + CudaRadixCiphertextFFI *packed_flush_in; + CudaRadixCiphertextFFI *packed_flush_out; + + int_kreyvium_state_workspaces(CudaStreams streams, + const int_radix_params ¶ms, + bool allocate_gpu_memory, uint32_t num_inputs, + uint64_t &size_tracker) { + + /* Register widths in bit positions: a = 93, b = 84, c = 111, k = 128, iv = 128. */ this->a_reg = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->a_reg, 93 * num_inputs, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->b_reg = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->b_reg, 84 * num_inputs, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->c_reg = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->c_reg, 111 * num_inputs, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->k_reg = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->k_reg, 128 * num_inputs, + params.big_lwe_dimension,
size_tracker, allocate_gpu_memory); + + this->iv_reg = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->iv_reg, 128 * num_inputs, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + /* 128 bit positions per input: large enough to stage any register, and used as the temporary by the reverse, rotate and shift-and-insert helpers. */ this->shift_workspace = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->shift_workspace, + 128 * num_inputs, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + /* One batch is 64 bit positions per input; t1..t3 and new_a..new_c each hold exactly one batch. */ uint32_t batch_blocks = 64 * num_inputs; + + this->temp_t1 = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->temp_t1, batch_blocks, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->temp_t2 = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->temp_t2, batch_blocks, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->temp_t3 = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->temp_t3, batch_blocks, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->new_a = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->new_a, batch_blocks, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->new_b = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->new_b, batch_blocks, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + this->new_c = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->new_c, batch_blocks, + params.big_lwe_dimension, size_tracker, allocate_gpu_memory); + + /* Packed operands and results for the bivariate AND call: three batches each. */ this->packed_pbs_lhs = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0),
streams.gpu_index(0), this->packed_pbs_lhs, + 3 * batch_blocks, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + this->packed_pbs_rhs = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->packed_pbs_rhs, + 3 * batch_blocks, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + this->packed_pbs_out = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->packed_pbs_out, + 3 * batch_blocks, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + /* Packed input and output of the flush table: four batches each, the fourth being used only when a keystream batch is produced. */ this->packed_flush_in = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->packed_flush_in, + 4 * batch_blocks, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + + this->packed_flush_out = new CudaRadixCiphertextFFI; + create_zero_radix_ciphertext_async( + streams.stream(0), streams.gpu_index(0), this->packed_flush_out, + 4 * batch_blocks, params.big_lwe_dimension, size_tracker, + allocate_gpu_memory); + } + + /* Releases every ciphertext created by the constructor. The helper lambda takes the pointer by reference so it can release, delete and null it in one place. */ void release(CudaStreams streams, bool allocate_gpu_memory) { + auto release_and_delete = [&](CudaRadixCiphertextFFI *&ptr) { + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + ptr, allocate_gpu_memory); + delete ptr; + ptr = nullptr; + }; + + release_and_delete(this->a_reg); + release_and_delete(this->b_reg); + release_and_delete(this->c_reg); + release_and_delete(this->k_reg); + release_and_delete(this->iv_reg); + release_and_delete(this->shift_workspace); + release_and_delete(this->temp_t1); + release_and_delete(this->temp_t2); + release_and_delete(this->temp_t3); + release_and_delete(this->new_a); + release_and_delete(this->new_b); + release_and_delete(this->new_c); + release_and_delete(this->packed_pbs_lhs); + release_and_delete(this->packed_pbs_rhs); + release_and_delete(this->packed_pbs_out); + release_and_delete(this->packed_flush_in); +
release_and_delete(this->packed_flush_out); + + cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0)); + } +}; + +/* Top-level scratch object that crosses the FFI as an opaque int8_t pointer. It stores the parameters, the allocation flag and num_inputs, and owns the lookup tables and the state workspaces. */ template struct int_kreyvium_buffer { + int_radix_params params; + bool allocate_gpu_memory; + uint32_t num_inputs; + + int_kreyvium_lut_buffers *luts; + int_kreyvium_state_workspaces *state; + + int_kreyvium_buffer(CudaStreams streams, const int_radix_params ¶ms, + bool allocate_gpu_memory, uint32_t num_inputs, + uint64_t &size_tracker) { + this->params = params; + this->allocate_gpu_memory = allocate_gpu_memory; + this->num_inputs = num_inputs; + + this->luts = new int_kreyvium_lut_buffers( + streams, params, allocate_gpu_memory, num_inputs, size_tracker); + + this->state = new int_kreyvium_state_workspaces( + streams, params, allocate_gpu_memory, num_inputs, size_tracker); + } + + /* Releases the tables first and the state second, reusing the allocate_gpu_memory flag captured at construction, then synchronizes stream 0. */ void release(CudaStreams streams) { + luts->release(streams); + delete luts; + luts = nullptr; + + state->release(streams, allocate_gpu_memory); + delete state; + state = nullptr; + + cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0)); + } +}; + +#endif diff --git a/backends/tfhe-cuda-backend/cuda/src/kreyvium/kreyvium.cu b/backends/tfhe-cuda-backend/cuda/src/kreyvium/kreyvium.cu new file mode 100644 index 000000000..283339645 --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/src/kreyvium/kreyvium.cu @@ -0,0 +1,45 @@ +#include "../../include/kreyvium/kreyvium.h" +#include "kreyvium.cuh" + +/* FFI entry point: packs the scalar arguments into int_radix_params (the big LWE dimension is passed as glwe_dimension * polynomial_size) and forwards to scratch_cuda_kreyvium_encrypt. */ uint64_t scratch_cuda_kreyvium_64( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + glwe_dimension * polynomial_size, lwe_dimension, + ks_level,
ks_base_log, pbs_level, pbs_base_log, + grouping_factor, message_modulus, carry_modulus, + noise_reduction_type); + + return scratch_cuda_kreyvium_encrypt( + CudaStreams(streams), (int_kreyvium_buffer **)mem_ptr, params, + allocate_gpu_memory, num_inputs); +} + +/* FFI entry point: casts mem_ptr back to the Kreyvium buffer and forwards to host_kreyvium_generate_keystream, reinterpreting ksks as 64-bit key pointers. */ void cuda_kreyvium_generate_keystream_64( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output, + const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv, + uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks, + void *const *ksks) { + + auto buffer = (int_kreyvium_buffer *)mem_ptr; + + host_kreyvium_generate_keystream( + CudaStreams(streams), keystream_output, key, iv, num_inputs, num_steps, + buffer, bsks, (uint64_t *const *)ksks); +} + +/* FFI entry point: releases the buffer, deletes it and nulls the caller handle. NOTE(review): *mem_ptr_void is dereferenced without a null check, so calling this twice on the same handle would call release on a null pointer - confirm callers never do that. */ void cleanup_cuda_kreyvium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void) { + + int_kreyvium_buffer *mem_ptr = + (int_kreyvium_buffer *)(*mem_ptr_void); + + mem_ptr->release(CudaStreams(streams)); + + delete mem_ptr; + *mem_ptr_void = nullptr; +} diff --git a/backends/tfhe-cuda-backend/cuda/src/kreyvium/kreyvium.cuh b/backends/tfhe-cuda-backend/cuda/src/kreyvium/kreyvium.cuh new file mode 100644 index 000000000..b0238e754 --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/src/kreyvium/kreyvium.cuh @@ -0,0 +1,376 @@ +#ifndef KREYVIUM_CUH +#define KREYVIUM_CUH + +#include "../../include/kreyvium/kreyvium_utilities.h" +#include "../integer/bitwise_ops.cuh" +#include "../integer/integer.cuh" +#include "../integer/radix_ciphertext.cuh" +#include "../integer/scalar_addition.cuh" +#include "../linearalgebra/addition.cuh" + +/* Reverses the order of the num_bits_in_reg bit positions of radix, each position being a group of num_inputs consecutive blocks. The reversed image is built in shift_workspace and copied back, so num_bits_in_reg must not exceed the 128 positions allocated for that workspace. */ template +void reverse_bitsliced_radix_inplace_kreyvium(CudaStreams streams, + int_kreyvium_buffer *mem, + CudaRadixCiphertextFFI *radix, + uint32_t num_bits_in_reg) { + uint32_t N = mem->num_inputs; + CudaRadixCiphertextFFI *temp = mem->state->shift_workspace; + + for (uint32_t i = 0; i < num_bits_in_reg; i++) { + uint32_t src_start = i * N; + uint32_t src_end = (i + 1) * N; + + uint32_t dest_start =
(num_bits_in_reg - 1 - i) * N; + uint32_t dest_end = (num_bits_in_reg - i) * N; + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), temp, dest_start, dest_end, + radix, src_start, src_end); + } + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), radix, 0, num_bits_in_reg * N, + temp, 0, num_bits_in_reg * N); +} + +/* XOR of two bit batches, implemented as a plain host_addition over out->num_radix_blocks blocks. The sums are reduced later by the x & 1 flush table applied in kreyvium_compute_64_steps. */ template +__host__ __forceinline__ void +kreyvium_xor(CudaStreams streams, int_kreyvium_buffer *mem, + CudaRadixCiphertextFFI *out, const CudaRadixCiphertextFFI *lhs, + const CudaRadixCiphertextFFI *rhs) { + host_addition(streams.stream(0), streams.gpu_index(0), out, lhs, rhs, + out->num_radix_blocks, mem->params.message_modulus, + mem->params.carry_modulus); +} + +/* Fills slice through as_radix_ciphertext_slice with num_bits bit positions of reg starting at start_bit_idx; both bounds are scaled by num_inputs because every bit position spans num_inputs blocks. */ template +__host__ void slice_reg_batch_kreyvium(CudaRadixCiphertextFFI *slice, + const CudaRadixCiphertextFFI *reg, + uint32_t start_bit_idx, + uint32_t num_bits, uint32_t num_inputs) { + as_radix_ciphertext_slice(slice, reg, start_bit_idx * num_inputs, + (start_bit_idx + num_bits) * num_inputs); +} + +/* Swaps the two 64-position halves of a 128-position register: the first half is saved to shift_workspace, the second half is copied over it, and the saved half is written back into the second half. */ template +__host__ void rotate_128_batch(CudaStreams streams, + int_kreyvium_buffer *mem, + CudaRadixCiphertextFFI *reg) { + uint32_t N = mem->num_inputs; + uint32_t HALF_SIZE = 64 * N; + CudaRadixCiphertextFFI *temp = mem->state->shift_workspace; + + copy_radix_ciphertext_slice_async(streams.stream(0), + streams.gpu_index(0), temp, 0, + HALF_SIZE, reg, 0, HALF_SIZE); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), reg, 0, HALF_SIZE, reg, + HALF_SIZE, 2 * HALF_SIZE); + + copy_radix_ciphertext_slice_async(streams.stream(0), + streams.gpu_index(0), reg, HALF_SIZE, + 2 * HALF_SIZE, temp, 0, HALF_SIZE); +} + +/* Drops the first 64 bit positions of reg, moves the remaining reg_size - 64 positions to the front and appends the 64 positions of new_bits at the end. The new register image is assembled in shift_workspace and then copied back. */ template +__host__ void shift_and_insert_batch_kreyvium(CudaStreams streams, + int_kreyvium_buffer *mem, + CudaRadixCiphertextFFI *reg, + CudaRadixCiphertextFFI *new_bits, + uint32_t reg_size, + uint32_t num_inputs) { + constexpr uint32_t BATCH = 64; + CudaRadixCiphertextFFI *temp =
mem->state->shift_workspace; + uint32_t num_blocks_to_keep = (reg_size - BATCH) * num_inputs; + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), temp, 0, num_blocks_to_keep, reg, + BATCH * num_inputs, reg_size * num_inputs); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), temp, num_blocks_to_keep, + reg_size * num_inputs, new_bits, 0, BATCH * num_inputs); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), reg, 0, reg_size * num_inputs, + temp, 0, reg_size * num_inputs); +} + +/* Advances the cipher state by one batch of 64 steps for all inputs at once. When output_dest is non-null it also receives the 64 flushed keystream positions of this batch; passing nullptr runs the state update only, as done during initialisation. All work is issued on stream 0 except the two lookup-table calls, which receive the whole streams object. */ template +__host__ void +kreyvium_compute_64_steps(CudaStreams streams, int_kreyvium_buffer *mem, + CudaRadixCiphertextFFI *output_dest, + void *const *bsks, uint64_t *const *ksks) { + + uint32_t N = mem->num_inputs; + constexpr uint32_t BATCH = 64; + uint32_t batch_size_blocks = BATCH * N; + auto s = mem->state; + + /* Each slice below is a 64-position window of a register. NOTE(review): the slice names look like Kreyvium tap indices while the offsets are positions in the stored register layout - presumably the registers are kept in reverse tap order; confirm against the cipher specification. */ CudaRadixCiphertextFFI a65_slice, a92_slice, a91_slice, a90_slice, a68_slice; + slice_reg_batch_kreyvium(&a65_slice, s->a_reg, 27, BATCH, N); + slice_reg_batch_kreyvium(&a92_slice, s->a_reg, 0, BATCH, N); + slice_reg_batch_kreyvium(&a91_slice, s->a_reg, 1, BATCH, N); + slice_reg_batch_kreyvium(&a90_slice, s->a_reg, 2, BATCH, N); + slice_reg_batch_kreyvium(&a68_slice, s->a_reg, 24, BATCH, N); + + CudaRadixCiphertextFFI b68_slice, b83_slice, b82_slice, b81_slice, b77_slice; + slice_reg_batch_kreyvium(&b68_slice, s->b_reg, 15, BATCH, N); + slice_reg_batch_kreyvium(&b83_slice, s->b_reg, 0, BATCH, N); + slice_reg_batch_kreyvium(&b82_slice, s->b_reg, 1, BATCH, N); + slice_reg_batch_kreyvium(&b81_slice, s->b_reg, 2, BATCH, N); + slice_reg_batch_kreyvium(&b77_slice, s->b_reg, 6, BATCH, N); + + CudaRadixCiphertextFFI c65_slice, c110_slice, c109_slice, c108_slice, + c86_slice; + slice_reg_batch_kreyvium(&c65_slice, s->c_reg, 45, BATCH, N); + slice_reg_batch_kreyvium(&c110_slice, s->c_reg, 0, BATCH, N); + slice_reg_batch_kreyvium(&c109_slice, s->c_reg, 1, BATCH, N); +
slice_reg_batch_kreyvium(&c108_slice, s->c_reg, 2, BATCH, N); + slice_reg_batch_kreyvium(&c86_slice, s->c_reg, 24, BATCH, N); + + CudaRadixCiphertextFFI k_slice, iv_slice; + slice_reg_batch_kreyvium(&k_slice, s->k_reg, 0, BATCH, N); + slice_reg_batch_kreyvium(&iv_slice, s->iv_reg, 0, BATCH, N); + + /* t1, t2 and t3 are unreduced sums at this point; t3 additionally absorbs the current key slice. */ kreyvium_xor(streams, mem, s->temp_t1, &a65_slice, &a92_slice); + kreyvium_xor(streams, mem, s->temp_t2, &b68_slice, &b83_slice); + kreyvium_xor(streams, mem, s->temp_t3, &c65_slice, &c110_slice); + kreyvium_xor(streams, mem, s->temp_t3, s->temp_t3, &k_slice); + + /* Pack the three AND operand pairs (c109 with c108, a91 with a90, b82 with b81) side by side so that a single bivariate lookup-table call evaluates all of them. */ copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs, 0, + batch_size_blocks, &c109_slice, 0, batch_size_blocks); + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs, + batch_size_blocks, 2 * batch_size_blocks, &a91_slice, 0, + batch_size_blocks); + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs, + 2 * batch_size_blocks, 3 * batch_size_blocks, &b82_slice, 0, + batch_size_blocks); + + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs, 0, + batch_size_blocks, &c108_slice, 0, batch_size_blocks); + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs, + batch_size_blocks, 2 * batch_size_blocks, &a90_slice, 0, + batch_size_blocks); + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs, + 2 * batch_size_blocks, 3 * batch_size_blocks, &b81_slice, 0, + batch_size_blocks); + + integer_radix_apply_bivariate_lookup_table( + streams, s->packed_pbs_out, s->packed_pbs_lhs, s->packed_pbs_rhs, bsks, + ksks, mem->luts->and_lut, 3 * batch_size_blocks, + mem->params.message_modulus); + + /* The three AND results are read back as slices of the packed output, in the same order as they were packed. */ CudaRadixCiphertextFFI and_res_a, and_res_b, and_res_c; + as_radix_ciphertext_slice(&and_res_a, s->packed_pbs_out, 0, + batch_size_blocks); + as_radix_ciphertext_slice(&and_res_b,
s->packed_pbs_out, + batch_size_blocks, 2 * batch_size_blocks); + as_radix_ciphertext_slice(&and_res_c, s->packed_pbs_out, + 2 * batch_size_blocks, + 3 * batch_size_blocks); + + /* Combine the AND results with the linear taps; new_b and new_c also add the unreduced t1 and t2. */ kreyvium_xor(streams, mem, s->new_a, &and_res_a, &a68_slice); + kreyvium_xor(streams, mem, s->new_b, &and_res_b, &b77_slice); + kreyvium_xor(streams, mem, s->new_b, s->new_b, s->temp_t1); + kreyvium_xor(streams, mem, s->new_c, &and_res_c, &c86_slice); + kreyvium_xor(streams, mem, s->new_c, s->new_c, s->temp_t2); + + /* First flush: new_a, new_b and t3 are packed and reduced with the x & 1 table before more terms are added to them. new_c is not part of this call; it is flushed in the second call below. */ copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_flush_in, 0, + batch_size_blocks, s->new_a, 0, batch_size_blocks); + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_flush_in, + batch_size_blocks, 2 * batch_size_blocks, s->new_b, 0, batch_size_blocks); + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_flush_in, + 2 * batch_size_blocks, 3 * batch_size_blocks, s->temp_t3, 0, + batch_size_blocks); + + integer_radix_apply_univariate_lookup_table( + streams, s->packed_flush_out, s->packed_flush_in, bsks, ksks, + mem->luts->flush_lut, 3 * batch_size_blocks); + + CudaRadixCiphertextFFI flushed_new_a_pre, flushed_new_b_pre, flushed_temp_t3; + as_radix_ciphertext_slice(&flushed_new_a_pre, s->packed_flush_out, 0, + batch_size_blocks); + as_radix_ciphertext_slice(&flushed_new_b_pre, s->packed_flush_out, + batch_size_blocks, 2 * batch_size_blocks); + as_radix_ciphertext_slice(&flushed_temp_t3, s->packed_flush_out, + 2 * batch_size_blocks, + 3 * batch_size_blocks); + + kreyvium_xor(streams, mem, s->new_a, &flushed_new_a_pre, &flushed_temp_t3); + kreyvium_xor(streams, mem, s->new_b, &flushed_new_b_pre, &iv_slice); + + /* Keystream batch before its flush: t1 + t2 + flushed t3, accumulated directly into output_dest. */ if (output_dest != nullptr) { + host_addition(streams.stream(0), streams.gpu_index(0), output_dest, + s->temp_t1, s->temp_t2, output_dest->num_radix_blocks, + mem->params.message_modulus, + mem->params.carry_modulus); + host_addition( + streams.stream(0),
streams.gpu_index(0), output_dest, output_dest, + &flushed_temp_t3, output_dest->num_radix_blocks, + mem->params.message_modulus, mem->params.carry_modulus); + } + + /* Second flush: new_a, new_b and new_c are packed again; the keystream batch is appended as a fourth segment only when an output is requested, which is why the block count passed to the table call varies. */ copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_flush_in, 0, + batch_size_blocks, s->new_a, 0, batch_size_blocks); + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_flush_in, + batch_size_blocks, 2 * batch_size_blocks, s->new_b, 0, batch_size_blocks); + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_flush_in, + 2 * batch_size_blocks, 3 * batch_size_blocks, s->new_c, 0, + batch_size_blocks); + + uint32_t total_flush_blocks = 3 * batch_size_blocks; + + if (output_dest != nullptr) { + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), s->packed_flush_in, + 3 * batch_size_blocks, 4 * batch_size_blocks, output_dest, 0, + batch_size_blocks); + total_flush_blocks += batch_size_blocks; + } + + integer_radix_apply_univariate_lookup_table( + streams, s->packed_flush_out, s->packed_flush_in, bsks, ksks, + mem->luts->flush_lut, total_flush_blocks); + + CudaRadixCiphertextFFI flushed_a, flushed_b, flushed_c; + as_radix_ciphertext_slice(&flushed_a, s->packed_flush_out, 0, + batch_size_blocks); + as_radix_ciphertext_slice(&flushed_b, s->packed_flush_out, + batch_size_blocks, 2 * batch_size_blocks); + as_radix_ciphertext_slice(&flushed_c, s->packed_flush_out, + 2 * batch_size_blocks, + 3 * batch_size_blocks); + + /* Commit the flushed batches to the registers, then swap the halves of the key and iv registers so the next batch reads the other 64 positions. The register slices taken at the top of the function must not be used after this point because the registers are rewritten here. */ shift_and_insert_batch_kreyvium(streams, mem, s->a_reg, &flushed_a, 93, N); + shift_and_insert_batch_kreyvium(streams, mem, s->b_reg, &flushed_b, 84, N); + shift_and_insert_batch_kreyvium(streams, mem, s->c_reg, &flushed_c, 111, N); + + rotate_128_batch(streams, mem, s->k_reg); + rotate_128_batch(streams, mem, s->iv_reg); + + if (output_dest != nullptr) { + CudaRadixCiphertextFFI flushed_out; + as_radix_ciphertext_slice(&flushed_out, s->packed_flush_out, + 3 *
batch_size_blocks, + 4 * batch_size_blocks); + copy_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), output_dest, 0, + batch_size_blocks, &flushed_out, 0, batch_size_blocks); + } +} + +/* Loads the state from the encrypted key and iv. The key is copied into k_reg, positions 35..127 of it seed a_reg, and k_reg is then reversed. The iv is copied into iv_reg, positions 44..127 seed b_reg, positions 0..43 fill c_reg from position 67, and iv_reg is then reversed. Positions 1..66 of c_reg get the constant one, and 18 batches of 64 steps are run without output. key_bitsliced and iv_bitsliced are read as 128 bit positions of num_inputs blocks each. */ template +__host__ void kreyvium_init(CudaStreams streams, + int_kreyvium_buffer *mem, + CudaRadixCiphertextFFI const *key_bitsliced, + CudaRadixCiphertextFFI const *iv_bitsliced, + void *const *bsks, uint64_t *const *ksks) { + uint32_t N = mem->num_inputs; + auto s = mem->state; + + CudaRadixCiphertextFFI src_key_slice; + slice_reg_batch_kreyvium(&src_key_slice, key_bitsliced, 0, 128, N); + + CudaRadixCiphertextFFI dest_k_reg_slice; + slice_reg_batch_kreyvium(&dest_k_reg_slice, s->k_reg, 0, 128, N); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &dest_k_reg_slice, &src_key_slice); + + CudaRadixCiphertextFFI k_source_for_a; + slice_reg_batch_kreyvium(&k_source_for_a, s->k_reg, 35, 93, N); + + CudaRadixCiphertextFFI dest_a_slice; + slice_reg_batch_kreyvium(&dest_a_slice, s->a_reg, 0, 93, N); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &dest_a_slice, &k_source_for_a); + + /* The reversal happens only after a_reg has been seeded from the unreversed key layout. */ reverse_bitsliced_radix_inplace_kreyvium(streams, mem, s->k_reg, 128); + + CudaRadixCiphertextFFI src_iv_slice; + slice_reg_batch_kreyvium(&src_iv_slice, iv_bitsliced, 0, 128, N); + + CudaRadixCiphertextFFI dest_iv_reg_slice; + slice_reg_batch_kreyvium(&dest_iv_reg_slice, s->iv_reg, 0, 128, N); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &dest_iv_reg_slice, &src_iv_slice); + + CudaRadixCiphertextFFI iv_source_for_b; + slice_reg_batch_kreyvium(&iv_source_for_b, s->iv_reg, 44, 84, N); + + CudaRadixCiphertextFFI dest_b_slice; + slice_reg_batch_kreyvium(&dest_b_slice, s->b_reg, 0, 84, N); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &dest_b_slice, &iv_source_for_b); + + CudaRadixCiphertextFFI iv_source_for_c; + slice_reg_batch_kreyvium(&iv_source_for_c, s->iv_reg, 0, 44, N); + +
CudaRadixCiphertextFFI dest_c_iv_part; + slice_reg_batch_kreyvium(&dest_c_iv_part, s->c_reg, 67, 44, N); + copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + &dest_c_iv_part, &iv_source_for_c); + + reverse_bitsliced_radix_inplace_kreyvium(streams, mem, s->iv_reg, 128); + + /* Adds the scalar one to positions 1..66 of c_reg and flushes them with the x & 1 table; position 0 is left as it is. NOTE(review): this yields the constant one only if those positions hold zero on entry, which presumably is the case for a freshly created buffer - confirm the buffer is never reused for a second keystream without being reset. */ CudaRadixCiphertextFFI dest_c_ones; + slice_reg_batch_kreyvium(&dest_c_ones, s->c_reg, 1, 66, N); + host_add_scalar_one_inplace(streams, &dest_c_ones, + mem->params.message_modulus, + mem->params.carry_modulus); + + integer_radix_apply_univariate_lookup_table( + streams, &dest_c_ones, &dest_c_ones, bsks, ksks, mem->luts->flush_lut, + dest_c_ones.num_radix_blocks); + + /* Warm-up: 18 batches of 64 steps with no output destination. */ for (int i = 0; i < 18; i++) { + kreyvium_compute_64_steps(streams, mem, nullptr, bsks, ksks); + } +} + +/* Initialises the cipher, then produces num_steps / 64 batches, batch i being written to bit positions i * 64 .. i * 64 + 63 of keystream_output. NOTE(review): the division truncates, so a num_steps that is not a multiple of 64 leaves the trailing positions of keystream_output unwritten without any error - confirm that callers guarantee a multiple of 64, or reject other values here. */ template +__host__ void host_kreyvium_generate_keystream( + CudaStreams streams, CudaRadixCiphertextFFI *keystream_output, + CudaRadixCiphertextFFI const *key_bitsliced, + CudaRadixCiphertextFFI const *iv_bitsliced, uint32_t num_inputs, + uint32_t num_steps, int_kreyvium_buffer *mem, void *const *bsks, + uint64_t *const *ksks) { + + kreyvium_init(streams, mem, key_bitsliced, iv_bitsliced, bsks, ksks); + + uint32_t num_batches = num_steps / 64; + for (uint32_t i = 0; i < num_batches; i++) { + CudaRadixCiphertextFFI batch_out_slice; + slice_reg_batch_kreyvium(&batch_out_slice, keystream_output, i * 64, + 64, num_inputs); + kreyvium_compute_64_steps(streams, mem, &batch_out_slice, bsks, ksks); + } +} + +/* Allocates the Kreyvium buffer into *mem_ptr and returns the size tracker filled in by the buffer constructors. */ template +uint64_t scratch_cuda_kreyvium_encrypt(CudaStreams streams, + int_kreyvium_buffer **mem_ptr, + int_radix_params params, + bool allocate_gpu_memory, + uint32_t num_inputs) { + + uint64_t size_tracker = 0; + *mem_ptr = new int_kreyvium_buffer( + streams, params, allocate_gpu_memory, num_inputs, size_tracker); + return size_tracker; +} + +#endif diff --git a/backends/tfhe-cuda-backend/src/bindings.rs b/backends/tfhe-cuda-backend/src/bindings.rs index f858a9d2e..f54ceff1e 100644 ---
a/backends/tfhe-cuda-backend/src/bindings.rs +++ b/backends/tfhe-cuda-backend/src/bindings.rs @@ -2531,6 +2531,42 @@ unsafe extern "C" { unsafe extern "C" { pub fn cleanup_cuda_trivium_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8); } +/* Rust declarations of the three C entry points declared in cuda/include/kreyvium/kreyvium.h; parameter order and widths follow the C prototypes one to one. */ unsafe extern "C" { + pub fn scratch_cuda_kreyvium_64( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + num_inputs: u32, + ) -> u64; +} +unsafe extern "C" { + pub fn cuda_kreyvium_generate_keystream_64( + streams: CudaStreamsFFI, + keystream_output: *mut CudaRadixCiphertextFFI, + key: *const CudaRadixCiphertextFFI, + iv: *const CudaRadixCiphertextFFI, + num_inputs: u32, + num_steps: u32, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} +unsafe extern "C" { + pub fn cleanup_cuda_kreyvium_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8); +} pub const KS_TYPE_BIG_TO_SMALL: KS_TYPE = 0; pub const KS_TYPE_SMALL_TO_BIG: KS_TYPE = 1; pub type KS_TYPE = ffi::c_uint; diff --git a/backends/tfhe-cuda-backend/wrapper.h b/backends/tfhe-cuda-backend/wrapper.h index d4f6f9fc8..6c24c306c 100644 --- a/backends/tfhe-cuda-backend/wrapper.h +++ b/backends/tfhe-cuda-backend/wrapper.h @@ -5,6 +5,7 @@ #include "cuda/include/integer/rerand.h" #include "cuda/include/aes/aes.h" #include "cuda/include/trivium/trivium.h" +#include "cuda/include/kreyvium/kreyvium.h" #include "cuda/include/zk/zk.h" #include "cuda/include/keyswitch/keyswitch.h" #include "cuda/include/keyswitch/ks_enums.h" diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index 3292412aa..f75c0ec52 100644 --- a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -10518,3 +10518,93 @@ pub(crate)
unsafe fn cuda_backend_trivium_generate_keystream( + streams: &CudaStreams, + keystream_output: &mut CudaRadixCiphertext, + key: &CudaRadixCiphertext, + iv: &CudaRadixCiphertext, + bootstrapping_key: &CudaVec, + keyswitch_key: &CudaVec, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + lwe_dimension: LweDimension, + ks_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + pbs_base_log: DecompositionBaseLog, + grouping_factor: LweBskGroupingFactor, + pbs_type: PBSType, + ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>, + num_steps: u32, +) { + let mut keystream_degrees = keystream_output + .info + .blocks + .iter() + .map(|b| b.degree.0) + .collect(); + let mut keystream_noise_levels = keystream_output + .info + .blocks + .iter() + .map(|b| b.noise_level.0) + .collect(); + let mut cuda_ffi_keystream = prepare_cuda_radix_ffi( + keystream_output, + &mut keystream_degrees, + &mut keystream_noise_levels, + ); + + let mut key_degrees = key.info.blocks.iter().map(|b| b.degree.0).collect(); + let mut key_noise_levels = key.info.blocks.iter().map(|b| b.noise_level.0).collect(); + let cuda_ffi_key = prepare_cuda_radix_ffi(key, &mut key_degrees, &mut key_noise_levels); + + let mut iv_degrees = iv.info.blocks.iter().map(|b| b.degree.0).collect(); + let mut iv_noise_levels = iv.info.blocks.iter().map(|b| b.noise_level.0).collect(); + let cuda_ffi_iv = prepare_cuda_radix_ffi(iv, &mut iv_degrees, &mut iv_noise_levels); + + let num_inputs = (key.info.blocks.len() / 128) as u32; + + let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration); + + let mut mem_ptr: *mut i8 = std::ptr::null_mut(); + + scratch_cuda_kreyvium_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + lwe_dimension.0 as u32, + ks_level.0 as 
u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + noise_reduction_type as u32, + num_inputs, + ); + + /* The scratch buffer allocated just above is used for this single call and freed right after it, so no state is carried over between keystream generations. */ cuda_kreyvium_generate_keystream_64( + streams.ffi(), + &raw mut cuda_ffi_keystream, + &raw const cuda_ffi_key, + &raw const cuda_ffi_iv, + num_inputs, + num_steps, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + + cleanup_cuda_kreyvium_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); + + update_noise_degree(keystream_output, &cuda_ffi_keystream); +} diff --git a/tfhe/src/integer/gpu/server_key/radix/kreyvium.rs b/tfhe/src/integer/gpu/server_key/radix/kreyvium.rs new file mode 100644 index 000000000..06b593779 --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/kreyvium.rs @@ -0,0 +1,187 @@ +use crate::core_crypto::gpu::CudaStreams; +use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext}; +use crate::integer::gpu::server_key::{ + CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey, +}; +use crate::integer::gpu::{ + cuda_backend_kreyvium_generate_keystream, LweBskGroupingFactor, PBSType, +}; +use crate::integer::{RadixCiphertext, RadixClientKey}; +use crate::shortint::Ciphertext; + +/* Client-side helpers that move between plain bits and a radix ciphertext holding one block per bit. NOTE(review): generic arguments appear stripped from this patch text (Vec without an element type) - confirm against the repository. */ impl RadixClientKey { + /* Encrypts every input value on its own and keeps the block popped from the end of the resulting ciphertext, so the output has exactly bits.len() blocks. */ pub fn encrypt_bits_for_kreyvium(&self, bits: &[u64]) -> RadixCiphertext { + let mut blocks: Vec = Vec::with_capacity(bits.len()); + for &bit in bits { + let mut ct = self.encrypt(bit); + let block = ct.blocks.pop().unwrap(); + blocks.push(block); + } + RadixCiphertext::from(blocks) + } + + /* Decrypts each block on its own by wrapping it in a one-block ciphertext, and narrows every decrypted value to u8. */ pub fn decrypt_bits_from_kreyvium(&self, encrypted_stream: &RadixCiphertext) -> Vec { + let mut decrypted_bits = Vec::with_capacity(encrypted_stream.blocks.len()); + for block in &encrypted_stream.blocks { + let tmp_radix = RadixCiphertext::from(vec![block.clone()]); + let val: u64 = self.decrypt(&tmp_radix); + decrypted_bits.push(val as u8); + } +
decrypted_bits + } +} + +impl CudaServerKey { + pub fn kreyvium_generate_keystream( + &self, + key: &CudaUnsignedRadixCiphertext, + iv: &CudaUnsignedRadixCiphertext, + num_steps: usize, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext { + let num_key_bits = 128; + let num_iv_bits = 128; + + assert_eq!( + key.as_ref().d_blocks.lwe_ciphertext_count().0, + num_key_bits, + "Input key must contain {} encrypted bits, but contains {}", + num_key_bits, + key.as_ref().d_blocks.lwe_ciphertext_count().0 + ); + assert_eq!( + iv.as_ref().d_blocks.lwe_ciphertext_count().0, + num_iv_bits, + "Input IV must contain {} encrypted bits, but contains {}", + num_iv_bits, + iv.as_ref().d_blocks.lwe_ciphertext_count().0 + ); + + let num_output_bits = num_steps; + let mut keystream: CudaUnsignedRadixCiphertext = + self.create_trivial_zero_radix(num_output_bits, streams); + + let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else { + panic!("Only the standard atomic pattern is supported on GPU") + }; + + unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_kreyvium_generate_keystream( + streams, + keystream.as_mut(), + key.as_ref(), + iv.as_ref(), + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + d_bsk.input_lwe_dimension, + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + LweBskGroupingFactor(0), + PBSType::Classical, + d_bsk.ms_noise_reduction_configuration.as_ref(), + num_steps as u32, + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_kreyvium_generate_keystream( + streams, + keystream.as_mut(), + key.as_ref(), + iv.as_ref(), + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + 
d_multibit_bsk.polynomial_size, + d_multibit_bsk.input_lwe_dimension, + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + d_multibit_bsk.grouping_factor, + PBSType::MultiBit, + None, + num_steps as u32, + ); + } + } + } + keystream + } +} + +#[cfg(test)] +mod tests { + use crate::core_crypto::gpu::CudaStreams; + use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; + use crate::integer::gpu::CudaServerKey; + use crate::integer::keycache::KEY_CACHE; + use crate::integer::{IntegerKeyKind, RadixClientKey}; + use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + + #[test] + fn test_gpu_kreyvium_correctness() { + let streams = CudaStreams::new_multi_gpu(); + + let param = PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + + let (raw_cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cpu_cks = RadixClientKey::from((raw_cks, 1)); + + let sks = CudaServerKey::new(&cpu_cks, &streams); + + let key_hex = "0053A6F94C9FF24598EB000000000000"; + let iv_hex = "0D74DB42A91077DE45AC000000000000"; + let expected_out_hex = "D1F0303482061111"; + + let parse_hex = |s: &str| -> Vec { + let mut bits = Vec::new(); + for i in (0..s.len()).step_by(2) { + let byte = u8::from_str_radix(&s[i..i + 2], 16).unwrap(); + for j in 0..8 { + bits.push(((byte >> j) & 1) as u64); + } + } + bits + }; + + let key_bits = parse_hex(key_hex); + let iv_bits = parse_hex(iv_hex); + + assert_eq!(key_bits.len(), 128); + assert_eq!(iv_bits.len(), 128); + + let encrypted_key = cpu_cks.encrypt_bits_for_kreyvium(&key_bits); + let encrypted_iv = cpu_cks.encrypt_bits_for_kreyvium(&iv_bits); + + let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&encrypted_key, &streams); + let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&encrypted_iv, &streams); + + let d_keystream = 
sks.kreyvium_generate_keystream(&d_key, &d_iv, 64, &streams); + + let keystream = d_keystream.to_radix_ciphertext(&streams); + let decrypted_bits = cpu_cks.decrypt_bits_from_kreyvium(&keystream); + + let mut result_hex = String::new(); + for chunk in decrypted_bits.chunks(8) { + let mut byte = 0u8; + for (j, &b) in chunk.iter().enumerate() { + if b == 1 { + byte |= 1 << j; + } + } + result_hex.push_str(&format!("{:02X}", byte)); + } + + assert_eq!(result_hex, expected_out_hex); + } +} diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs index 2ce07d824..1eaa8240f 100644 --- a/tfhe/src/integer/gpu/server_key/radix/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs @@ -58,6 +58,7 @@ mod vector_find; mod aes; mod aes256; +mod kreyvium; #[cfg(test)] mod tests_long_run; #[cfg(test)]