feat(gpu): kreyvium

This commit is contained in:
Enzo Di Maria
2025-12-29 11:14:16 +01:00
parent ef33b7555c
commit bcbdf1f5d7
10 changed files with 1017 additions and 0 deletions

View File

@@ -87,6 +87,7 @@ fn main() {
"cuda/include/integer/rerand.h",
"cuda/include/aes/aes.h",
"cuda/include/trivium/trivium.h",
"cuda/include/kreyvium/kreyvium.h",
"cuda/include/zk/zk.h",
"cuda/include/keyswitch/keyswitch.h",
"cuda/include/keyswitch/ks_enums.h",

View File

@@ -0,0 +1,24 @@
#ifndef KREYVIUM_H
#define KREYVIUM_H
#include "../integer/integer.h"
extern "C" {
uint64_t scratch_cuda_kreyvium_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs);
void cuda_kreyvium_generate_keystream_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output,
const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv,
uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks,
void *const *ksks);
void cleanup_cuda_kreyvium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void);
}
#endif

View File

@@ -0,0 +1,256 @@
#ifndef KREYVIUM_UTILITIES_H
#define KREYVIUM_UTILITIES_H
#include "../integer/integer_utilities.h"
template <typename Torus> struct int_kreyvium_lut_buffers {
int_radix_lut<Torus> *and_lut;
int_radix_lut<Torus> *flush_lut;
int_kreyvium_lut_buffers(CudaStreams streams, const int_radix_params &params,
bool allocate_gpu_memory, uint32_t num_inputs,
uint64_t &size_tracker) {
constexpr uint32_t BATCH_SIZE = 64;
constexpr uint32_t MAX_AND_PER_STEP = 3;
uint32_t total_lut_ops = num_inputs * BATCH_SIZE * MAX_AND_PER_STEP;
this->and_lut = new int_radix_lut<Torus>(streams, params, 1, total_lut_ops,
allocate_gpu_memory, size_tracker);
std::function<Torus(Torus, Torus)> and_lambda =
[](Torus a, Torus b) -> Torus { return a & b; };
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0), this->and_lut->get_lut(0, 0),
this->and_lut->get_degree(0), this->and_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, and_lambda, allocate_gpu_memory);
auto active_streams_and =
streams.active_gpu_subset(total_lut_ops, params.pbs_type);
this->and_lut->broadcast_lut(active_streams_and);
this->and_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
uint32_t total_flush_ops = num_inputs * BATCH_SIZE * 4;
this->flush_lut = new int_radix_lut<Torus>(
streams, params, 1, total_flush_ops, allocate_gpu_memory, size_tracker);
std::function<Torus(Torus)> flush_lambda = [](Torus x) -> Torus {
return x & 1;
};
generate_device_accumulator(
streams.stream(0), streams.gpu_index(0), this->flush_lut->get_lut(0, 0),
this->flush_lut->get_degree(0), this->flush_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, flush_lambda, allocate_gpu_memory);
auto active_streams_flush =
streams.active_gpu_subset(total_flush_ops, params.pbs_type);
this->flush_lut->broadcast_lut(active_streams_flush);
this->flush_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
}
void release(CudaStreams streams) {
this->and_lut->release(streams);
delete this->and_lut;
this->and_lut = nullptr;
this->flush_lut->release(streams);
delete this->flush_lut;
this->flush_lut = nullptr;
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
}
};
template <typename Torus> struct int_kreyvium_state_workspaces {
CudaRadixCiphertextFFI *a_reg;
CudaRadixCiphertextFFI *b_reg;
CudaRadixCiphertextFFI *c_reg;
CudaRadixCiphertextFFI *k_reg;
CudaRadixCiphertextFFI *iv_reg;
CudaRadixCiphertextFFI *shift_workspace;
CudaRadixCiphertextFFI *temp_t1;
CudaRadixCiphertextFFI *temp_t2;
CudaRadixCiphertextFFI *temp_t3;
CudaRadixCiphertextFFI *new_a;
CudaRadixCiphertextFFI *new_b;
CudaRadixCiphertextFFI *new_c;
CudaRadixCiphertextFFI *packed_pbs_lhs;
CudaRadixCiphertextFFI *packed_pbs_rhs;
CudaRadixCiphertextFFI *packed_pbs_out;
CudaRadixCiphertextFFI *packed_flush_in;
CudaRadixCiphertextFFI *packed_flush_out;
int_kreyvium_state_workspaces(CudaStreams streams,
const int_radix_params &params,
bool allocate_gpu_memory, uint32_t num_inputs,
uint64_t &size_tracker) {
this->a_reg = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->a_reg, 93 * num_inputs,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->b_reg = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->b_reg, 84 * num_inputs,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->c_reg = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->c_reg, 111 * num_inputs,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->k_reg = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->k_reg, 128 * num_inputs,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->iv_reg = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->iv_reg, 128 * num_inputs,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->shift_workspace = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->shift_workspace,
128 * num_inputs, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
uint32_t batch_blocks = 64 * num_inputs;
this->temp_t1 = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->temp_t1, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->temp_t2 = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->temp_t2, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->temp_t3 = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->temp_t3, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->new_a = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->new_a, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->new_b = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->new_b, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->new_c = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->new_c, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->packed_pbs_lhs = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_pbs_lhs,
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
this->packed_pbs_rhs = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_pbs_rhs,
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
this->packed_pbs_out = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_pbs_out,
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
this->packed_flush_in = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_flush_in,
4 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
this->packed_flush_out = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_flush_out,
4 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
}
void release(CudaStreams streams, bool allocate_gpu_memory) {
auto release_and_delete = [&](CudaRadixCiphertextFFI *&ptr) {
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
ptr, allocate_gpu_memory);
delete ptr;
ptr = nullptr;
};
release_and_delete(this->a_reg);
release_and_delete(this->b_reg);
release_and_delete(this->c_reg);
release_and_delete(this->k_reg);
release_and_delete(this->iv_reg);
release_and_delete(this->shift_workspace);
release_and_delete(this->temp_t1);
release_and_delete(this->temp_t2);
release_and_delete(this->temp_t3);
release_and_delete(this->new_a);
release_and_delete(this->new_b);
release_and_delete(this->new_c);
release_and_delete(this->packed_pbs_lhs);
release_and_delete(this->packed_pbs_rhs);
release_and_delete(this->packed_pbs_out);
release_and_delete(this->packed_flush_in);
release_and_delete(this->packed_flush_out);
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
}
};
template <typename Torus> struct int_kreyvium_buffer {
int_radix_params params;
bool allocate_gpu_memory;
uint32_t num_inputs;
int_kreyvium_lut_buffers<Torus> *luts;
int_kreyvium_state_workspaces<Torus> *state;
int_kreyvium_buffer(CudaStreams streams, const int_radix_params &params,
bool allocate_gpu_memory, uint32_t num_inputs,
uint64_t &size_tracker) {
this->params = params;
this->allocate_gpu_memory = allocate_gpu_memory;
this->num_inputs = num_inputs;
this->luts = new int_kreyvium_lut_buffers<Torus>(
streams, params, allocate_gpu_memory, num_inputs, size_tracker);
this->state = new int_kreyvium_state_workspaces<Torus>(
streams, params, allocate_gpu_memory, num_inputs, size_tracker);
}
void release(CudaStreams streams) {
luts->release(streams);
delete luts;
luts = nullptr;
state->release(streams, allocate_gpu_memory);
delete state;
state = nullptr;
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
}
};
#endif

View File

@@ -0,0 +1,45 @@
#include "../../include/kreyvium/kreyvium.h"
#include "kreyvium.cuh"
uint64_t scratch_cuda_kreyvium_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs) {
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
glwe_dimension * polynomial_size, lwe_dimension,
ks_level, ks_base_log, pbs_level, pbs_base_log,
grouping_factor, message_modulus, carry_modulus,
noise_reduction_type);
return scratch_cuda_kreyvium_encrypt<uint64_t>(
CudaStreams(streams), (int_kreyvium_buffer<uint64_t> **)mem_ptr, params,
allocate_gpu_memory, num_inputs);
}
void cuda_kreyvium_generate_keystream_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output,
const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv,
uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks,
void *const *ksks) {
auto buffer = (int_kreyvium_buffer<uint64_t> *)mem_ptr;
host_kreyvium_generate_keystream<uint64_t>(
CudaStreams(streams), keystream_output, key, iv, num_inputs, num_steps,
buffer, bsks, (uint64_t *const *)ksks);
}
void cleanup_cuda_kreyvium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void) {
int_kreyvium_buffer<uint64_t> *mem_ptr =
(int_kreyvium_buffer<uint64_t> *)(*mem_ptr_void);
mem_ptr->release(CudaStreams(streams));
delete mem_ptr;
*mem_ptr_void = nullptr;
}

View File

@@ -0,0 +1,376 @@
#ifndef KREYVIUM_CUH
#define KREYVIUM_CUH
#include "../../include/kreyvium/kreyvium_utilities.h"
#include "../integer/bitwise_ops.cuh"
#include "../integer/integer.cuh"
#include "../integer/radix_ciphertext.cuh"
#include "../integer/scalar_addition.cuh"
#include "../linearalgebra/addition.cuh"
template <typename Torus>
void reverse_bitsliced_radix_inplace_kreyvium(CudaStreams streams,
int_kreyvium_buffer<Torus> *mem,
CudaRadixCiphertextFFI *radix,
uint32_t num_bits_in_reg) {
uint32_t N = mem->num_inputs;
CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;
for (uint32_t i = 0; i < num_bits_in_reg; i++) {
uint32_t src_start = i * N;
uint32_t src_end = (i + 1) * N;
uint32_t dest_start = (num_bits_in_reg - 1 - i) * N;
uint32_t dest_end = (num_bits_in_reg - i) * N;
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), temp, dest_start, dest_end,
radix, src_start, src_end);
}
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), radix, 0, num_bits_in_reg * N,
temp, 0, num_bits_in_reg * N);
}
template <typename Torus>
__host__ __forceinline__ void
kreyvium_xor(CudaStreams streams, int_kreyvium_buffer<Torus> *mem,
CudaRadixCiphertextFFI *out, const CudaRadixCiphertextFFI *lhs,
const CudaRadixCiphertextFFI *rhs) {
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), out, lhs, rhs,
out->num_radix_blocks, mem->params.message_modulus,
mem->params.carry_modulus);
}
template <typename Torus>
__host__ void slice_reg_batch_kreyvium(CudaRadixCiphertextFFI *slice,
const CudaRadixCiphertextFFI *reg,
uint32_t start_bit_idx,
uint32_t num_bits, uint32_t num_inputs) {
as_radix_ciphertext_slice<Torus>(slice, reg, start_bit_idx * num_inputs,
(start_bit_idx + num_bits) * num_inputs);
}
template <typename Torus>
__host__ void rotate_128_batch(CudaStreams streams,
int_kreyvium_buffer<Torus> *mem,
CudaRadixCiphertextFFI *reg) {
uint32_t N = mem->num_inputs;
uint32_t HALF_SIZE = 64 * N;
CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;
copy_radix_ciphertext_slice_async<Torus>(streams.stream(0),
streams.gpu_index(0), temp, 0,
HALF_SIZE, reg, 0, HALF_SIZE);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), reg, 0, HALF_SIZE, reg,
HALF_SIZE, 2 * HALF_SIZE);
copy_radix_ciphertext_slice_async<Torus>(streams.stream(0),
streams.gpu_index(0), reg, HALF_SIZE,
2 * HALF_SIZE, temp, 0, HALF_SIZE);
}
template <typename Torus>
__host__ void shift_and_insert_batch_kreyvium(CudaStreams streams,
int_kreyvium_buffer<Torus> *mem,
CudaRadixCiphertextFFI *reg,
CudaRadixCiphertextFFI *new_bits,
uint32_t reg_size,
uint32_t num_inputs) {
constexpr uint32_t BATCH = 64;
CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;
uint32_t num_blocks_to_keep = (reg_size - BATCH) * num_inputs;
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), temp, 0, num_blocks_to_keep, reg,
BATCH * num_inputs, reg_size * num_inputs);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), temp, num_blocks_to_keep,
reg_size * num_inputs, new_bits, 0, BATCH * num_inputs);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), reg, 0, reg_size * num_inputs,
temp, 0, reg_size * num_inputs);
}
template <typename Torus>
__host__ void
kreyvium_compute_64_steps(CudaStreams streams, int_kreyvium_buffer<Torus> *mem,
CudaRadixCiphertextFFI *output_dest,
void *const *bsks, uint64_t *const *ksks) {
uint32_t N = mem->num_inputs;
constexpr uint32_t BATCH = 64;
uint32_t batch_size_blocks = BATCH * N;
auto s = mem->state;
CudaRadixCiphertextFFI a65_slice, a92_slice, a91_slice, a90_slice, a68_slice;
slice_reg_batch_kreyvium<Torus>(&a65_slice, s->a_reg, 27, BATCH, N);
slice_reg_batch_kreyvium<Torus>(&a92_slice, s->a_reg, 0, BATCH, N);
slice_reg_batch_kreyvium<Torus>(&a91_slice, s->a_reg, 1, BATCH, N);
slice_reg_batch_kreyvium<Torus>(&a90_slice, s->a_reg, 2, BATCH, N);
slice_reg_batch_kreyvium<Torus>(&a68_slice, s->a_reg, 24, BATCH, N);
CudaRadixCiphertextFFI b68_slice, b83_slice, b82_slice, b81_slice, b77_slice;
slice_reg_batch_kreyvium<Torus>(&b68_slice, s->b_reg, 15, BATCH, N);
slice_reg_batch_kreyvium<Torus>(&b83_slice, s->b_reg, 0, BATCH, N);
slice_reg_batch_kreyvium<Torus>(&b82_slice, s->b_reg, 1, BATCH, N);
slice_reg_batch_kreyvium<Torus>(&b81_slice, s->b_reg, 2, BATCH, N);
slice_reg_batch_kreyvium<Torus>(&b77_slice, s->b_reg, 6, BATCH, N);
CudaRadixCiphertextFFI c65_slice, c110_slice, c109_slice, c108_slice,
c86_slice;
slice_reg_batch_kreyvium<Torus>(&c65_slice, s->c_reg, 45, BATCH, N);
slice_reg_batch_kreyvium<Torus>(&c110_slice, s->c_reg, 0, BATCH, N);
slice_reg_batch_kreyvium<Torus>(&c109_slice, s->c_reg, 1, BATCH, N);
slice_reg_batch_kreyvium<Torus>(&c108_slice, s->c_reg, 2, BATCH, N);
slice_reg_batch_kreyvium<Torus>(&c86_slice, s->c_reg, 24, BATCH, N);
CudaRadixCiphertextFFI k_slice, iv_slice;
slice_reg_batch_kreyvium<Torus>(&k_slice, s->k_reg, 0, BATCH, N);
slice_reg_batch_kreyvium<Torus>(&iv_slice, s->iv_reg, 0, BATCH, N);
kreyvium_xor(streams, mem, s->temp_t1, &a65_slice, &a92_slice);
kreyvium_xor(streams, mem, s->temp_t2, &b68_slice, &b83_slice);
kreyvium_xor(streams, mem, s->temp_t3, &c65_slice, &c110_slice);
kreyvium_xor(streams, mem, s->temp_t3, s->temp_t3, &k_slice);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs, 0,
batch_size_blocks, &c109_slice, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs,
batch_size_blocks, 2 * batch_size_blocks, &a91_slice, 0,
batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs,
2 * batch_size_blocks, 3 * batch_size_blocks, &b82_slice, 0,
batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs, 0,
batch_size_blocks, &c108_slice, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs,
batch_size_blocks, 2 * batch_size_blocks, &a90_slice, 0,
batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs,
2 * batch_size_blocks, 3 * batch_size_blocks, &b81_slice, 0,
batch_size_blocks);
integer_radix_apply_bivariate_lookup_table<Torus>(
streams, s->packed_pbs_out, s->packed_pbs_lhs, s->packed_pbs_rhs, bsks,
ksks, mem->luts->and_lut, 3 * batch_size_blocks,
mem->params.message_modulus);
CudaRadixCiphertextFFI and_res_a, and_res_b, and_res_c;
as_radix_ciphertext_slice<Torus>(&and_res_a, s->packed_pbs_out, 0,
batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&and_res_b, s->packed_pbs_out,
batch_size_blocks, 2 * batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&and_res_c, s->packed_pbs_out,
2 * batch_size_blocks,
3 * batch_size_blocks);
kreyvium_xor(streams, mem, s->new_a, &and_res_a, &a68_slice);
kreyvium_xor(streams, mem, s->new_b, &and_res_b, &b77_slice);
kreyvium_xor(streams, mem, s->new_b, s->new_b, s->temp_t1);
kreyvium_xor(streams, mem, s->new_c, &and_res_c, &c86_slice);
kreyvium_xor(streams, mem, s->new_c, s->new_c, s->temp_t2);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in, 0,
batch_size_blocks, s->new_a, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
batch_size_blocks, 2 * batch_size_blocks, s->new_b, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
2 * batch_size_blocks, 3 * batch_size_blocks, s->temp_t3, 0,
batch_size_blocks);
integer_radix_apply_univariate_lookup_table<Torus>(
streams, s->packed_flush_out, s->packed_flush_in, bsks, ksks,
mem->luts->flush_lut, 3 * batch_size_blocks);
CudaRadixCiphertextFFI flushed_new_a_pre, flushed_new_b_pre, flushed_temp_t3;
as_radix_ciphertext_slice<Torus>(&flushed_new_a_pre, s->packed_flush_out, 0,
batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&flushed_new_b_pre, s->packed_flush_out,
batch_size_blocks, 2 * batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&flushed_temp_t3, s->packed_flush_out,
2 * batch_size_blocks,
3 * batch_size_blocks);
kreyvium_xor(streams, mem, s->new_a, &flushed_new_a_pre, &flushed_temp_t3);
kreyvium_xor(streams, mem, s->new_b, &flushed_new_b_pre, &iv_slice);
if (output_dest != nullptr) {
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), output_dest,
s->temp_t1, s->temp_t2, output_dest->num_radix_blocks,
mem->params.message_modulus,
mem->params.carry_modulus);
host_addition<Torus>(
streams.stream(0), streams.gpu_index(0), output_dest, output_dest,
&flushed_temp_t3, output_dest->num_radix_blocks,
mem->params.message_modulus, mem->params.carry_modulus);
}
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in, 0,
batch_size_blocks, s->new_a, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
batch_size_blocks, 2 * batch_size_blocks, s->new_b, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
2 * batch_size_blocks, 3 * batch_size_blocks, s->new_c, 0,
batch_size_blocks);
uint32_t total_flush_blocks = 3 * batch_size_blocks;
if (output_dest != nullptr) {
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
3 * batch_size_blocks, 4 * batch_size_blocks, output_dest, 0,
batch_size_blocks);
total_flush_blocks += batch_size_blocks;
}
integer_radix_apply_univariate_lookup_table<Torus>(
streams, s->packed_flush_out, s->packed_flush_in, bsks, ksks,
mem->luts->flush_lut, total_flush_blocks);
CudaRadixCiphertextFFI flushed_a, flushed_b, flushed_c;
as_radix_ciphertext_slice<Torus>(&flushed_a, s->packed_flush_out, 0,
batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&flushed_b, s->packed_flush_out,
batch_size_blocks, 2 * batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&flushed_c, s->packed_flush_out,
2 * batch_size_blocks,
3 * batch_size_blocks);
shift_and_insert_batch_kreyvium(streams, mem, s->a_reg, &flushed_a, 93, N);
shift_and_insert_batch_kreyvium(streams, mem, s->b_reg, &flushed_b, 84, N);
shift_and_insert_batch_kreyvium(streams, mem, s->c_reg, &flushed_c, 111, N);
rotate_128_batch(streams, mem, s->k_reg);
rotate_128_batch(streams, mem, s->iv_reg);
if (output_dest != nullptr) {
CudaRadixCiphertextFFI flushed_out;
as_radix_ciphertext_slice<Torus>(&flushed_out, s->packed_flush_out,
3 * batch_size_blocks,
4 * batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), output_dest, 0,
batch_size_blocks, &flushed_out, 0, batch_size_blocks);
}
}
template <typename Torus>
__host__ void kreyvium_init(CudaStreams streams,
int_kreyvium_buffer<Torus> *mem,
CudaRadixCiphertextFFI const *key_bitsliced,
CudaRadixCiphertextFFI const *iv_bitsliced,
void *const *bsks, uint64_t *const *ksks) {
uint32_t N = mem->num_inputs;
auto s = mem->state;
CudaRadixCiphertextFFI src_key_slice;
slice_reg_batch_kreyvium<Torus>(&src_key_slice, key_bitsliced, 0, 128, N);
CudaRadixCiphertextFFI dest_k_reg_slice;
slice_reg_batch_kreyvium<Torus>(&dest_k_reg_slice, s->k_reg, 0, 128, N);
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
&dest_k_reg_slice, &src_key_slice);
CudaRadixCiphertextFFI k_source_for_a;
slice_reg_batch_kreyvium<Torus>(&k_source_for_a, s->k_reg, 35, 93, N);
CudaRadixCiphertextFFI dest_a_slice;
slice_reg_batch_kreyvium<Torus>(&dest_a_slice, s->a_reg, 0, 93, N);
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
&dest_a_slice, &k_source_for_a);
reverse_bitsliced_radix_inplace_kreyvium<Torus>(streams, mem, s->k_reg, 128);
CudaRadixCiphertextFFI src_iv_slice;
slice_reg_batch_kreyvium<Torus>(&src_iv_slice, iv_bitsliced, 0, 128, N);
CudaRadixCiphertextFFI dest_iv_reg_slice;
slice_reg_batch_kreyvium<Torus>(&dest_iv_reg_slice, s->iv_reg, 0, 128, N);
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
&dest_iv_reg_slice, &src_iv_slice);
CudaRadixCiphertextFFI iv_source_for_b;
slice_reg_batch_kreyvium<Torus>(&iv_source_for_b, s->iv_reg, 44, 84, N);
CudaRadixCiphertextFFI dest_b_slice;
slice_reg_batch_kreyvium<Torus>(&dest_b_slice, s->b_reg, 0, 84, N);
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
&dest_b_slice, &iv_source_for_b);
CudaRadixCiphertextFFI iv_source_for_c;
slice_reg_batch_kreyvium<Torus>(&iv_source_for_c, s->iv_reg, 0, 44, N);
CudaRadixCiphertextFFI dest_c_iv_part;
slice_reg_batch_kreyvium<Torus>(&dest_c_iv_part, s->c_reg, 67, 44, N);
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
&dest_c_iv_part, &iv_source_for_c);
reverse_bitsliced_radix_inplace_kreyvium<Torus>(streams, mem, s->iv_reg, 128);
CudaRadixCiphertextFFI dest_c_ones;
slice_reg_batch_kreyvium<Torus>(&dest_c_ones, s->c_reg, 1, 66, N);
host_add_scalar_one_inplace<Torus>(streams, &dest_c_ones,
mem->params.message_modulus,
mem->params.carry_modulus);
integer_radix_apply_univariate_lookup_table<Torus>(
streams, &dest_c_ones, &dest_c_ones, bsks, ksks, mem->luts->flush_lut,
dest_c_ones.num_radix_blocks);
for (int i = 0; i < 18; i++) {
kreyvium_compute_64_steps(streams, mem, nullptr, bsks, ksks);
}
}
template <typename Torus>
__host__ void host_kreyvium_generate_keystream(
CudaStreams streams, CudaRadixCiphertextFFI *keystream_output,
CudaRadixCiphertextFFI const *key_bitsliced,
CudaRadixCiphertextFFI const *iv_bitsliced, uint32_t num_inputs,
uint32_t num_steps, int_kreyvium_buffer<Torus> *mem, void *const *bsks,
uint64_t *const *ksks) {
kreyvium_init(streams, mem, key_bitsliced, iv_bitsliced, bsks, ksks);
uint32_t num_batches = num_steps / 64;
for (uint32_t i = 0; i < num_batches; i++) {
CudaRadixCiphertextFFI batch_out_slice;
slice_reg_batch_kreyvium<Torus>(&batch_out_slice, keystream_output, i * 64,
64, num_inputs);
kreyvium_compute_64_steps(streams, mem, &batch_out_slice, bsks, ksks);
}
}
template <typename Torus>
uint64_t scratch_cuda_kreyvium_encrypt(CudaStreams streams,
int_kreyvium_buffer<Torus> **mem_ptr,
int_radix_params params,
bool allocate_gpu_memory,
uint32_t num_inputs) {
uint64_t size_tracker = 0;
*mem_ptr = new int_kreyvium_buffer<Torus>(
streams, params, allocate_gpu_memory, num_inputs, size_tracker);
return size_tracker;
}
#endif

View File

@@ -2531,6 +2531,42 @@ unsafe extern "C" {
unsafe extern "C" {
pub fn cleanup_cuda_trivium_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}
unsafe extern "C" {
pub fn scratch_cuda_kreyvium_64(
streams: CudaStreamsFFI,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
num_inputs: u32,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_kreyvium_generate_keystream_64(
streams: CudaStreamsFFI,
keystream_output: *mut CudaRadixCiphertextFFI,
key: *const CudaRadixCiphertextFFI,
iv: *const CudaRadixCiphertextFFI,
num_inputs: u32,
num_steps: u32,
mem_ptr: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cleanup_cuda_kreyvium_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}
pub const KS_TYPE_BIG_TO_SMALL: KS_TYPE = 0;
pub const KS_TYPE_SMALL_TO_BIG: KS_TYPE = 1;
pub type KS_TYPE = ffi::c_uint;

View File

@@ -5,6 +5,7 @@
#include "cuda/include/integer/rerand.h"
#include "cuda/include/aes/aes.h"
#include "cuda/include/trivium/trivium.h"
#include "cuda/include/kreyvium/kreyvium.h"
#include "cuda/include/zk/zk.h"
#include "cuda/include/keyswitch/keyswitch.h"
#include "cuda/include/keyswitch/ks_enums.h"

View File

@@ -10518,3 +10518,93 @@ pub(crate) unsafe fn cuda_backend_trivium_generate_keystream<T: UnsignedInteger,
update_noise_degree(keystream_output, &cuda_ffi_keystream);
}
#[allow(clippy::too_many_arguments)]
pub(crate) unsafe fn cuda_backend_kreyvium_generate_keystream<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
keystream_output: &mut CudaRadixCiphertext,
key: &CudaRadixCiphertext,
iv: &CudaRadixCiphertext,
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
grouping_factor: LweBskGroupingFactor,
pbs_type: PBSType,
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
num_steps: u32,
) {
let mut keystream_degrees = keystream_output
.info
.blocks
.iter()
.map(|b| b.degree.0)
.collect();
let mut keystream_noise_levels = keystream_output
.info
.blocks
.iter()
.map(|b| b.noise_level.0)
.collect();
let mut cuda_ffi_keystream = prepare_cuda_radix_ffi(
keystream_output,
&mut keystream_degrees,
&mut keystream_noise_levels,
);
let mut key_degrees = key.info.blocks.iter().map(|b| b.degree.0).collect();
let mut key_noise_levels = key.info.blocks.iter().map(|b| b.noise_level.0).collect();
let cuda_ffi_key = prepare_cuda_radix_ffi(key, &mut key_degrees, &mut key_noise_levels);
let mut iv_degrees = iv.info.blocks.iter().map(|b| b.degree.0).collect();
let mut iv_noise_levels = iv.info.blocks.iter().map(|b| b.noise_level.0).collect();
let cuda_ffi_iv = prepare_cuda_radix_ffi(iv, &mut iv_degrees, &mut iv_noise_levels);
let num_inputs = (key.info.blocks.len() / 128) as u32;
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
scratch_cuda_kreyvium_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
noise_reduction_type as u32,
num_inputs,
);
cuda_kreyvium_generate_keystream_64(
streams.ffi(),
&raw mut cuda_ffi_keystream,
&raw const cuda_ffi_key,
&raw const cuda_ffi_iv,
num_inputs,
num_steps,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
cleanup_cuda_kreyvium_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
update_noise_degree(keystream_output, &cuda_ffi_keystream);
}

View File

@@ -0,0 +1,187 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{
cuda_backend_kreyvium_generate_keystream, LweBskGroupingFactor, PBSType,
};
use crate::integer::{RadixCiphertext, RadixClientKey};
use crate::shortint::Ciphertext;
impl RadixClientKey {
pub fn encrypt_bits_for_kreyvium(&self, bits: &[u64]) -> RadixCiphertext {
let mut blocks: Vec<Ciphertext> = Vec::with_capacity(bits.len());
for &bit in bits {
let mut ct = self.encrypt(bit);
let block = ct.blocks.pop().unwrap();
blocks.push(block);
}
RadixCiphertext::from(blocks)
}
pub fn decrypt_bits_from_kreyvium(&self, encrypted_stream: &RadixCiphertext) -> Vec<u8> {
let mut decrypted_bits = Vec::with_capacity(encrypted_stream.blocks.len());
for block in &encrypted_stream.blocks {
let tmp_radix = RadixCiphertext::from(vec![block.clone()]);
let val: u64 = self.decrypt(&tmp_radix);
decrypted_bits.push(val as u8);
}
decrypted_bits
}
}
impl CudaServerKey {
pub fn kreyvium_generate_keystream(
&self,
key: &CudaUnsignedRadixCiphertext,
iv: &CudaUnsignedRadixCiphertext,
num_steps: usize,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext {
let num_key_bits = 128;
let num_iv_bits = 128;
assert_eq!(
key.as_ref().d_blocks.lwe_ciphertext_count().0,
num_key_bits,
"Input key must contain {} encrypted bits, but contains {}",
num_key_bits,
key.as_ref().d_blocks.lwe_ciphertext_count().0
);
assert_eq!(
iv.as_ref().d_blocks.lwe_ciphertext_count().0,
num_iv_bits,
"Input IV must contain {} encrypted bits, but contains {}",
num_iv_bits,
iv.as_ref().d_blocks.lwe_ciphertext_count().0
);
let num_output_bits = num_steps;
let mut keystream: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(num_output_bits, streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_kreyvium_generate_keystream(
streams,
keystream.as_mut(),
key.as_ref(),
iv.as_ref(),
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
num_steps as u32,
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_kreyvium_generate_keystream(
streams,
keystream.as_mut(),
key.as_ref(),
iv.as_ref(),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
num_steps as u32,
);
}
}
}
keystream
}
}
#[cfg(test)]
mod tests {
use crate::core_crypto::gpu::CudaStreams;
use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use crate::integer::gpu::CudaServerKey;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::{IntegerKeyKind, RadixClientKey};
use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
#[test]
fn test_gpu_kreyvium_correctness() {
let streams = CudaStreams::new_multi_gpu();
let param = PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
let (raw_cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cpu_cks = RadixClientKey::from((raw_cks, 1));
let sks = CudaServerKey::new(&cpu_cks, &streams);
let key_hex = "0053A6F94C9FF24598EB000000000000";
let iv_hex = "0D74DB42A91077DE45AC000000000000";
let expected_out_hex = "D1F0303482061111";
let parse_hex = |s: &str| -> Vec<u64> {
let mut bits = Vec::new();
for i in (0..s.len()).step_by(2) {
let byte = u8::from_str_radix(&s[i..i + 2], 16).unwrap();
for j in 0..8 {
bits.push(((byte >> j) & 1) as u64);
}
}
bits
};
let key_bits = parse_hex(key_hex);
let iv_bits = parse_hex(iv_hex);
assert_eq!(key_bits.len(), 128);
assert_eq!(iv_bits.len(), 128);
let encrypted_key = cpu_cks.encrypt_bits_for_kreyvium(&key_bits);
let encrypted_iv = cpu_cks.encrypt_bits_for_kreyvium(&iv_bits);
let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&encrypted_key, &streams);
let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&encrypted_iv, &streams);
let d_keystream = sks.kreyvium_generate_keystream(&d_key, &d_iv, 64, &streams);
let keystream = d_keystream.to_radix_ciphertext(&streams);
let decrypted_bits = cpu_cks.decrypt_bits_from_kreyvium(&keystream);
let mut result_hex = String::new();
for chunk in decrypted_bits.chunks(8) {
let mut byte = 0u8;
for (j, &b) in chunk.iter().enumerate() {
if b == 1 {
byte |= 1 << j;
}
}
result_hex.push_str(&format!("{:02X}", byte));
}
assert_eq!(result_hex, expected_out_hex);
}
}

View File

@@ -58,6 +58,7 @@ mod vector_find;
mod aes;
mod aes256;
mod kreyvium;
#[cfg(test)]
mod tests_long_run;
#[cfg(test)]