feat(gpu): trivium

This commit is contained in:
Enzo Di Maria
2025-12-19 14:54:54 +01:00
parent 696f964ecf
commit de60417209
19 changed files with 1404 additions and 2 deletions

View File

@@ -1454,6 +1454,13 @@ bench_integer_aes256_gpu: install_rs_check_toolchain
--bench integer-aes256 \
--features=integer,internal-keycache,gpu, -p tfhe-benchmark --profile release_lto_off --
.PHONY: bench_integer_trivium_gpu # Run benchmarks for trivium on GPU backend
bench_integer_trivium_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--bench integer-trivium \
--features=integer,internal-keycache,gpu, -p tfhe-benchmark --profile release_lto_off --
.PHONY: bench_integer_multi_bit # Run benchmarks for unsigned integer using multi-bit parameters
bench_integer_multi_bit: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_PARAM_TYPE=MULTI_BIT __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \

View File

@@ -86,6 +86,7 @@ fn main() {
"cuda/include/integer/integer.h",
"cuda/include/integer/rerand.h",
"cuda/include/aes/aes.h",
"cuda/include/trivium/trivium.h",
"cuda/include/zk/zk.h",
"cuda/include/keyswitch/keyswitch.h",
"cuda/include/keyswitch/ks_enums.h",

View File

@@ -0,0 +1,24 @@
#ifndef TRIVIUM_H
#define TRIVIUM_H
#include "../integer/integer.h"
extern "C" {
uint64_t scratch_cuda_trivium_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs);
void cuda_trivium_generate_keystream_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output,
const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv,
uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks,
void *const *ksks);
void cleanup_cuda_trivium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void);
}
#endif

View File

@@ -0,0 +1,277 @@
#ifndef TRIVIUM_UTILITIES_H
#define TRIVIUM_UTILITIES_H
#include "../integer/integer_utilities.h"
template <typename Torus> struct int_trivium_lut_buffers {
int_radix_lut<Torus> *and_lut;
int_radix_lut<Torus> *flush_lut;
int_trivium_lut_buffers(CudaStreams streams, const int_radix_params &params,
bool allocate_gpu_memory, uint32_t num_trivium_inputs,
uint64_t &size_tracker) {
constexpr uint32_t BATCH_SIZE = 64;
constexpr uint32_t MAX_AND_PER_STEP = 3;
uint32_t total_lut_ops = num_trivium_inputs * BATCH_SIZE * MAX_AND_PER_STEP;
this->and_lut = new int_radix_lut<Torus>(streams, params, 1, total_lut_ops,
allocate_gpu_memory, size_tracker);
std::function<Torus(Torus, Torus)> and_lambda =
[](Torus a, Torus b) -> Torus { return a & b; };
generate_device_accumulator_bivariate<Torus>(
streams.stream(0), streams.gpu_index(0), this->and_lut->get_lut(0, 0),
this->and_lut->get_degree(0), this->and_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, and_lambda, allocate_gpu_memory);
auto active_streams_and =
streams.active_gpu_subset(total_lut_ops, params.pbs_type);
this->and_lut->broadcast_lut(active_streams_and);
this->and_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
uint32_t total_flush_ops = num_trivium_inputs * BATCH_SIZE * 4;
this->flush_lut = new int_radix_lut<Torus>(
streams, params, 1, total_flush_ops, allocate_gpu_memory, size_tracker);
std::function<Torus(Torus)> flush_lambda = [](Torus x) -> Torus {
return x & 1;
};
generate_device_accumulator(
streams.stream(0), streams.gpu_index(0), this->flush_lut->get_lut(0, 0),
this->flush_lut->get_degree(0), this->flush_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, flush_lambda, allocate_gpu_memory);
auto active_streams_flush =
streams.active_gpu_subset(total_flush_ops, params.pbs_type);
this->flush_lut->broadcast_lut(active_streams_flush);
this->flush_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
}
void release(CudaStreams streams) {
this->and_lut->release(streams);
delete this->and_lut;
this->and_lut = nullptr;
this->flush_lut->release(streams);
delete this->flush_lut;
this->flush_lut = nullptr;
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
}
};
template <typename Torus> struct int_trivium_state_workspaces {
CudaRadixCiphertextFFI *a_reg;
CudaRadixCiphertextFFI *b_reg;
CudaRadixCiphertextFFI *c_reg;
CudaRadixCiphertextFFI *shift_workspace;
CudaRadixCiphertextFFI *temp_t1;
CudaRadixCiphertextFFI *temp_t2;
CudaRadixCiphertextFFI *temp_t3;
CudaRadixCiphertextFFI *new_a;
CudaRadixCiphertextFFI *new_b;
CudaRadixCiphertextFFI *new_c;
CudaRadixCiphertextFFI *packed_pbs_lhs;
CudaRadixCiphertextFFI *packed_pbs_rhs;
CudaRadixCiphertextFFI *packed_pbs_out;
CudaRadixCiphertextFFI *packed_flush_in;
CudaRadixCiphertextFFI *packed_flush_out;
int_trivium_state_workspaces(CudaStreams streams,
const int_radix_params &params,
bool allocate_gpu_memory, uint32_t num_inputs,
uint64_t &size_tracker) {
this->a_reg = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->a_reg, 93 * num_inputs,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->b_reg = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->b_reg, 84 * num_inputs,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->c_reg = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->c_reg, 111 * num_inputs,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->shift_workspace = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->shift_workspace,
128 * num_inputs, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
uint32_t batch_blocks = 64 * num_inputs;
this->temp_t1 = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->temp_t1, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->temp_t2 = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->temp_t2, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->temp_t3 = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->temp_t3, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->new_a = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->new_a, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->new_b = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->new_b, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->new_c = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->new_c, batch_blocks,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
this->packed_pbs_lhs = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_pbs_lhs,
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
this->packed_pbs_rhs = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_pbs_rhs,
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
this->packed_pbs_out = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_pbs_out,
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
this->packed_flush_in = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_flush_in,
4 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
this->packed_flush_out = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->packed_flush_out,
4 * batch_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
}
void release(CudaStreams streams, bool allocate_gpu_memory) {
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->a_reg, allocate_gpu_memory);
delete this->a_reg;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->b_reg, allocate_gpu_memory);
delete this->b_reg;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->c_reg, allocate_gpu_memory);
delete this->c_reg;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->shift_workspace, allocate_gpu_memory);
delete this->shift_workspace;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->temp_t1, allocate_gpu_memory);
delete this->temp_t1;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->temp_t2, allocate_gpu_memory);
delete this->temp_t2;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->temp_t3, allocate_gpu_memory);
delete this->temp_t3;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->new_a, allocate_gpu_memory);
delete this->new_a;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->new_b, allocate_gpu_memory);
delete this->new_b;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->new_c, allocate_gpu_memory);
delete this->new_c;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->packed_pbs_lhs, allocate_gpu_memory);
delete this->packed_pbs_lhs;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->packed_pbs_rhs, allocate_gpu_memory);
delete this->packed_pbs_rhs;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->packed_pbs_out, allocate_gpu_memory);
delete this->packed_pbs_out;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->packed_flush_in, allocate_gpu_memory);
delete this->packed_flush_in;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->packed_flush_out, allocate_gpu_memory);
delete this->packed_flush_out;
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
}
};
template <typename Torus> struct int_trivium_buffer {
int_radix_params params;
bool allocate_gpu_memory;
uint32_t num_inputs;
int_trivium_lut_buffers<Torus> *luts;
int_trivium_state_workspaces<Torus> *state;
int_trivium_buffer(CudaStreams streams, const int_radix_params &params,
bool allocate_gpu_memory, uint32_t num_inputs,
uint64_t &size_tracker) {
this->params = params;
this->allocate_gpu_memory = allocate_gpu_memory;
this->num_inputs = num_inputs;
this->luts = new int_trivium_lut_buffers<Torus>(
streams, params, allocate_gpu_memory, num_inputs, size_tracker);
this->state = new int_trivium_state_workspaces<Torus>(
streams, params, allocate_gpu_memory, num_inputs, size_tracker);
}
void release(CudaStreams streams) {
luts->release(streams);
delete luts;
luts = nullptr;
state->release(streams, allocate_gpu_memory);
delete state;
state = nullptr;
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
}
};
#endif

View File

@@ -0,0 +1,45 @@
#include "../../include/trivium/trivium.h"
#include "trivium.cuh"
uint64_t scratch_cuda_trivium_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs) {
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
glwe_dimension * polynomial_size, lwe_dimension,
ks_level, ks_base_log, pbs_level, pbs_base_log,
grouping_factor, message_modulus, carry_modulus,
noise_reduction_type);
return scratch_cuda_trivium_encrypt<uint64_t>(
CudaStreams(streams), (int_trivium_buffer<uint64_t> **)mem_ptr, params,
allocate_gpu_memory, num_inputs);
}
void cuda_trivium_generate_keystream_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output,
const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv,
uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks,
void *const *ksks) {
auto buffer = (int_trivium_buffer<uint64_t> *)mem_ptr;
host_trivium_generate_keystream<uint64_t>(
CudaStreams(streams), keystream_output, key, iv, num_inputs, num_steps,
buffer, bsks, (uint64_t *const *)ksks);
}
void cleanup_cuda_trivium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void) {
int_trivium_buffer<uint64_t> *mem_ptr =
(int_trivium_buffer<uint64_t> *)(*mem_ptr_void);
mem_ptr->release(CudaStreams(streams));
delete mem_ptr;
*mem_ptr_void = nullptr;
}

View File

@@ -0,0 +1,311 @@
#ifndef TRIVIUM_CUH
#define TRIVIUM_CUH
#include "../../include/trivium/trivium_utilities.h"
#include "../integer/integer.cuh"
#include "../integer/radix_ciphertext.cuh"
#include "../integer/scalar_addition.cuh"
#include "../linearalgebra/addition.cuh"
template <typename Torus>
void reverse_bitsliced_radix_inplace(CudaStreams streams,
int_trivium_buffer<Torus> *mem,
CudaRadixCiphertextFFI *radix,
uint32_t num_bits_in_reg) {
uint32_t N = mem->num_inputs;
CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;
for (uint32_t i = 0; i < num_bits_in_reg; i++) {
uint32_t src_start = i * N;
uint32_t src_end = (i + 1) * N;
uint32_t dest_start = (num_bits_in_reg - 1 - i) * N;
uint32_t dest_end = (num_bits_in_reg - i) * N;
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), temp, dest_start, dest_end,
radix, src_start, src_end);
}
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), radix, 0, num_bits_in_reg * N,
temp, 0, num_bits_in_reg * N);
}
template <typename Torus>
__host__ __forceinline__ void
trivium_xor(CudaStreams streams, int_trivium_buffer<Torus> *mem,
CudaRadixCiphertextFFI *out, const CudaRadixCiphertextFFI *lhs,
const CudaRadixCiphertextFFI *rhs) {
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), out, lhs, rhs,
out->num_radix_blocks, mem->params.message_modulus,
mem->params.carry_modulus);
}
template <typename Torus>
__host__ __forceinline__ void
trivium_flush(CudaStreams streams, int_trivium_buffer<Torus> *mem,
CudaRadixCiphertextFFI *target, void *const *bsks,
uint64_t *const *ksks) {
integer_radix_apply_univariate_lookup_table<Torus>(
streams, target, target, bsks, ksks, mem->luts->flush_lut,
target->num_radix_blocks);
}
template <typename Torus>
__host__ void slice_reg_batch(CudaRadixCiphertextFFI *slice,
const CudaRadixCiphertextFFI *reg,
uint32_t start_bit_idx, uint32_t num_bits,
uint32_t num_inputs) {
as_radix_ciphertext_slice<Torus>(slice, reg, start_bit_idx * num_inputs,
(start_bit_idx + num_bits) * num_inputs);
}
template <typename Torus>
__host__ void shift_and_insert_batch(CudaStreams streams,
int_trivium_buffer<Torus> *mem,
CudaRadixCiphertextFFI *reg,
CudaRadixCiphertextFFI *new_bits,
uint32_t reg_size, uint32_t num_inputs) {
constexpr uint32_t BATCH = 64;
CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;
uint32_t num_blocks_to_keep = (reg_size - BATCH) * num_inputs;
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), temp, 0, BATCH * num_inputs,
new_bits, 0, BATCH * num_inputs);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), temp, BATCH * num_inputs,
reg_size * num_inputs, reg, 0, num_blocks_to_keep);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), reg, 0, reg_size * num_inputs,
temp, 0, reg_size * num_inputs);
}
template <typename Torus>
__host__ void
trivium_compute_64_steps(CudaStreams streams, int_trivium_buffer<Torus> *mem,
CudaRadixCiphertextFFI *output_dest, void *const *bsks,
uint64_t *const *ksks) {
uint32_t N = mem->num_inputs;
constexpr uint32_t BATCH = 64;
uint32_t batch_size_blocks = BATCH * N;
auto s = mem->state;
CudaRadixCiphertextFFI a65_slice, a92_slice, a91_slice, a90_slice, a68_slice;
slice_reg_batch<Torus>(&a65_slice, s->a_reg, 2, BATCH, N);
slice_reg_batch<Torus>(&a92_slice, s->a_reg, 29, BATCH, N);
slice_reg_batch<Torus>(&a91_slice, s->a_reg, 28, BATCH, N);
slice_reg_batch<Torus>(&a90_slice, s->a_reg, 27, BATCH, N);
slice_reg_batch<Torus>(&a68_slice, s->a_reg, 5, BATCH, N);
CudaRadixCiphertextFFI b68_slice, b83_slice, b82_slice, b81_slice, b77_slice;
slice_reg_batch<Torus>(&b68_slice, s->b_reg, 5, BATCH, N);
slice_reg_batch<Torus>(&b83_slice, s->b_reg, 20, BATCH, N);
slice_reg_batch<Torus>(&b82_slice, s->b_reg, 19, BATCH, N);
slice_reg_batch<Torus>(&b81_slice, s->b_reg, 18, BATCH, N);
slice_reg_batch<Torus>(&b77_slice, s->b_reg, 14, BATCH, N);
CudaRadixCiphertextFFI c65_slice, c110_slice, c109_slice, c108_slice,
c86_slice;
slice_reg_batch<Torus>(&c65_slice, s->c_reg, 2, BATCH, N);
slice_reg_batch<Torus>(&c110_slice, s->c_reg, 47, BATCH, N);
slice_reg_batch<Torus>(&c109_slice, s->c_reg, 46, BATCH, N);
slice_reg_batch<Torus>(&c108_slice, s->c_reg, 45, BATCH, N);
slice_reg_batch<Torus>(&c86_slice, s->c_reg, 23, BATCH, N);
trivium_xor(streams, mem, s->temp_t1, &a65_slice, &a92_slice);
trivium_xor(streams, mem, s->temp_t2, &b68_slice, &b83_slice);
trivium_xor(streams, mem, s->temp_t3, &c65_slice, &c110_slice);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs, 0,
batch_size_blocks, &c109_slice, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs,
batch_size_blocks, 2 * batch_size_blocks, &a91_slice, 0,
batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs,
2 * batch_size_blocks, 3 * batch_size_blocks, &b82_slice, 0,
batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs, 0,
batch_size_blocks, &c108_slice, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs,
batch_size_blocks, 2 * batch_size_blocks, &a90_slice, 0,
batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs,
2 * batch_size_blocks, 3 * batch_size_blocks, &b81_slice, 0,
batch_size_blocks);
integer_radix_apply_bivariate_lookup_table<Torus>(
streams, s->packed_pbs_out, s->packed_pbs_lhs, s->packed_pbs_rhs, bsks,
ksks, mem->luts->and_lut, 3 * batch_size_blocks,
mem->params.message_modulus);
CudaRadixCiphertextFFI and_res_a, and_res_b, and_res_c;
as_radix_ciphertext_slice<Torus>(&and_res_a, s->packed_pbs_out, 0,
batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&and_res_b, s->packed_pbs_out,
batch_size_blocks, 2 * batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&and_res_c, s->packed_pbs_out,
2 * batch_size_blocks,
3 * batch_size_blocks);
trivium_xor(streams, mem, s->new_a, s->temp_t3, &a68_slice);
trivium_xor(streams, mem, s->new_a, s->new_a, &and_res_a);
trivium_xor(streams, mem, s->new_b, s->temp_t1, &b77_slice);
trivium_xor(streams, mem, s->new_b, s->new_b, &and_res_b);
trivium_xor(streams, mem, s->new_c, s->temp_t2, &c86_slice);
trivium_xor(streams, mem, s->new_c, s->new_c, &and_res_c);
if (output_dest != nullptr) {
trivium_xor(streams, mem, output_dest, s->temp_t1, s->temp_t2);
trivium_xor(streams, mem, output_dest, output_dest, s->temp_t3);
}
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in, 0,
batch_size_blocks, s->new_a, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
batch_size_blocks, 2 * batch_size_blocks, s->new_b, 0, batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
2 * batch_size_blocks, 3 * batch_size_blocks, s->new_c, 0,
batch_size_blocks);
uint32_t total_flush_blocks = 3 * batch_size_blocks;
if (output_dest != nullptr) {
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
3 * batch_size_blocks, 4 * batch_size_blocks, output_dest, 0,
batch_size_blocks);
total_flush_blocks += batch_size_blocks;
}
integer_radix_apply_univariate_lookup_table<Torus>(
streams, s->packed_flush_out, s->packed_flush_in, bsks, ksks,
mem->luts->flush_lut, total_flush_blocks);
CudaRadixCiphertextFFI flushed_a, flushed_b, flushed_c;
as_radix_ciphertext_slice<Torus>(&flushed_a, s->packed_flush_out, 0,
batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&flushed_b, s->packed_flush_out,
batch_size_blocks, 2 * batch_size_blocks);
as_radix_ciphertext_slice<Torus>(&flushed_c, s->packed_flush_out,
2 * batch_size_blocks,
3 * batch_size_blocks);
shift_and_insert_batch(streams, mem, s->a_reg, &flushed_a, 93, N);
shift_and_insert_batch(streams, mem, s->b_reg, &flushed_b, 84, N);
shift_and_insert_batch(streams, mem, s->c_reg, &flushed_c, 111, N);
if (output_dest != nullptr) {
CudaRadixCiphertextFFI flushed_out;
as_radix_ciphertext_slice<Torus>(&flushed_out, s->packed_flush_out,
3 * batch_size_blocks,
4 * batch_size_blocks);
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), output_dest, 0,
batch_size_blocks, &flushed_out, 0, batch_size_blocks);
reverse_bitsliced_radix_inplace<Torus>(streams, mem, output_dest, 64);
}
}
template <typename Torus>
__host__ void trivium_init(CudaStreams streams, int_trivium_buffer<Torus> *mem,
CudaRadixCiphertextFFI const *key_bitsliced,
CudaRadixCiphertextFFI const *iv_bitsliced,
void *const *bsks, uint64_t *const *ksks) {
uint32_t N = mem->num_inputs;
auto s = mem->state;
CudaRadixCiphertextFFI src_key_slice;
slice_reg_batch<Torus>(&src_key_slice, key_bitsliced, 0, 80, N);
CudaRadixCiphertextFFI dest_a_slice;
slice_reg_batch<Torus>(&dest_a_slice, s->a_reg, 0, 80, N);
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
&dest_a_slice, &src_key_slice);
reverse_bitsliced_radix_inplace<Torus>(streams, mem, s->a_reg, 80);
CudaRadixCiphertextFFI src_iv_slice;
slice_reg_batch<Torus>(&src_iv_slice, iv_bitsliced, 0, 80, N);
CudaRadixCiphertextFFI dest_b_slice;
slice_reg_batch<Torus>(&dest_b_slice, s->b_reg, 0, 80, N);
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
&dest_b_slice, &src_iv_slice);
reverse_bitsliced_radix_inplace<Torus>(streams, mem, s->b_reg, 80);
CudaRadixCiphertextFFI dest_c_ones;
slice_reg_batch<Torus>(&dest_c_ones, s->c_reg, 108, 3, N);
host_add_scalar_one_inplace<Torus>(streams, &dest_c_ones,
mem->params.message_modulus,
mem->params.carry_modulus);
trivium_flush(streams, mem, &dest_c_ones, bsks, ksks);
for (int i = 0; i < 18; i++) {
trivium_compute_64_steps(streams, mem, nullptr, bsks, ksks);
}
}
template <typename Torus>
__host__ void host_trivium_generate_keystream(
CudaStreams streams, CudaRadixCiphertextFFI *keystream_output,
CudaRadixCiphertextFFI const *key_bitsliced,
CudaRadixCiphertextFFI const *iv_bitsliced, uint32_t num_inputs,
uint32_t num_steps, int_trivium_buffer<Torus> *mem, void *const *bsks,
uint64_t *const *ksks) {
trivium_init(streams, mem, key_bitsliced, iv_bitsliced, bsks, ksks);
uint32_t num_batches = num_steps / 64;
for (uint32_t i = 0; i < num_batches; i++) {
CudaRadixCiphertextFFI batch_out_slice;
slice_reg_batch<Torus>(&batch_out_slice, keystream_output, i * 64, 64,
num_inputs);
trivium_compute_64_steps(streams, mem, &batch_out_slice, bsks, ksks);
}
}
template <typename Torus>
uint64_t scratch_cuda_trivium_encrypt(CudaStreams streams,
int_trivium_buffer<Torus> **mem_ptr,
int_radix_params params,
bool allocate_gpu_memory,
uint32_t num_inputs) {
uint64_t size_tracker = 0;
*mem_ptr = new int_trivium_buffer<Torus>(streams, params, allocate_gpu_memory,
num_inputs, size_tracker);
return size_tracker;
}
#endif

View File

@@ -2495,6 +2495,42 @@ unsafe extern "C" {
mem_ptr_void: *mut *mut i8,
);
}
unsafe extern "C" {
pub fn scratch_cuda_trivium_64(
streams: CudaStreamsFFI,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
num_inputs: u32,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_trivium_generate_keystream_64(
streams: CudaStreamsFFI,
keystream_output: *mut CudaRadixCiphertextFFI,
key: *const CudaRadixCiphertextFFI,
iv: *const CudaRadixCiphertextFFI,
num_inputs: u32,
num_steps: u32,
mem_ptr: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cleanup_cuda_trivium_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}
pub const KS_TYPE_BIG_TO_SMALL: KS_TYPE = 0;
pub const KS_TYPE_SMALL_TO_BIG: KS_TYPE = 1;
pub type KS_TYPE = ffi::c_uint;

View File

@@ -4,6 +4,7 @@
#include "cuda/include/integer/integer.h"
#include "cuda/include/integer/rerand.h"
#include "cuda/include/aes/aes.h"
#include "cuda/include/trivium/trivium.h"
#include "cuda/include/zk/zk.h"
#include "cuda/include/keyswitch/keyswitch.h"
#include "cuda/include/keyswitch/ks_enums.h"

View File

@@ -133,6 +133,12 @@ path = "benches/integer/aes.rs"
harness = false
required-features = ["integer", "internal-keycache"]
[[bench]]
name = "integer-trivium"
path = "benches/integer/trivium.rs"
harness = false
required-features = ["integer", "internal-keycache"]
[[bench]]
name = "integer-aes256"
path = "benches/integer/aes256.rs"

View File

@@ -3,6 +3,7 @@
mod aes;
mod aes256;
mod oprf;
mod trivium;
mod vector_find;
mod rerand;

View File

@@ -0,0 +1,100 @@
use criterion::Criterion;
#[cfg(feature = "gpu")]
pub mod cuda {
use benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
use benchmark::utilities::{write_to_json, OperatorType};
use criterion::{black_box, criterion_group, Criterion};
use tfhe::core_crypto::gpu::CudaStreams;
use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use tfhe::integer::gpu::CudaServerKey;
use tfhe::integer::keycache::KEY_CACHE;
use tfhe::integer::{IntegerKeyKind, RadixClientKey};
use tfhe::keycache::NamedParam;
use tfhe::shortint::AtomicPatternParameters;
pub fn cuda_trivium(c: &mut Criterion) {
let bench_name = "integer::cuda::trivium";
let mut bench_group = c.benchmark_group(bench_name);
bench_group
.sample_size(15)
.measurement_time(std::time::Duration::from_secs(60))
.warm_up_time(std::time::Duration::from_secs(5));
let param = BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
let atomic_param: AtomicPatternParameters = param.into();
let key_bits = vec![0u64; 80];
let iv_bits = vec![0u64; 80];
let param_name = param.name();
let streams = CudaStreams::new_multi_gpu();
let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix);
let sks = CudaServerKey::new(&cpu_cks, &streams);
let cks = RadixClientKey::from((cpu_cks, 1));
let ct_key = cks.encrypt_bits_for_trivium(&key_bits);
let ct_iv = cks.encrypt_bits_for_trivium(&iv_bits);
let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
{
let num_steps = 64;
let bench_id = format!("{bench_name}::{param_name}::generate_{num_steps}_bits");
bench_group.bench_function(&bench_id, |b| {
b.iter(|| {
black_box(sks.trivium_generate_keystream(&d_key, &d_iv, num_steps, &streams));
})
});
write_to_json::<u64, _>(
&bench_id,
atomic_param,
param.name(),
"trivium_generation_64_bits",
&OperatorType::Atomic,
80,
vec![atomic_param.message_modulus().0.ilog2(); 80],
);
}
{
let num_steps = 512;
let bench_id = format!("{bench_name}::{param_name}::generate_{num_steps}_bits");
bench_group.bench_function(&bench_id, |b| {
b.iter(|| {
black_box(sks.trivium_generate_keystream(&d_key, &d_iv, num_steps, &streams));
})
});
write_to_json::<u64, _>(
&bench_id,
atomic_param,
param.name(),
"trivium_generation_512_bits",
&OperatorType::Atomic,
80,
vec![atomic_param.message_modulus().0.ilog2(); 80],
);
}
bench_group.finish();
}
criterion_group!(gpu_trivium, cuda_trivium);
}
#[cfg(feature = "gpu")]
use cuda::gpu_trivium;
fn main() {
#[cfg(feature = "gpu")]
gpu_trivium();
Criterion::default().configure_from_args().final_summary();
}

View File

@@ -5,8 +5,6 @@ use std::{env, fs};
#[cfg(feature = "gpu")]
use tfhe::core_crypto::gpu::{get_number_of_gpus, get_number_of_sms};
use tfhe::core_crypto::prelude::*;
#[cfg(feature = "integer")]
use tfhe::prelude::*;
#[cfg(feature = "boolean")]
pub mod boolean_utils {

View File

@@ -10423,3 +10423,98 @@ pub unsafe fn unchecked_small_scalar_mul_integer_async(
carry_modulus.0 as u32,
);
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_trivium_generate_keystream<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
keystream_output: &mut CudaRadixCiphertext,
key: &CudaRadixCiphertext,
iv: &CudaRadixCiphertext,
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
grouping_factor: LweBskGroupingFactor,
pbs_type: PBSType,
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
num_steps: u32,
) {
let mut keystream_degrees = keystream_output
.info
.blocks
.iter()
.map(|b| b.degree.0)
.collect();
let mut keystream_noise_levels = keystream_output
.info
.blocks
.iter()
.map(|b| b.noise_level.0)
.collect();
let mut cuda_ffi_keystream = prepare_cuda_radix_ffi(
keystream_output,
&mut keystream_degrees,
&mut keystream_noise_levels,
);
let mut key_degrees = key.info.blocks.iter().map(|b| b.degree.0).collect();
let mut key_noise_levels = key.info.blocks.iter().map(|b| b.noise_level.0).collect();
let cuda_ffi_key = prepare_cuda_radix_ffi(key, &mut key_degrees, &mut key_noise_levels);
let mut iv_degrees = iv.info.blocks.iter().map(|b| b.degree.0).collect();
let mut iv_noise_levels = iv.info.blocks.iter().map(|b| b.noise_level.0).collect();
let cuda_ffi_iv = prepare_cuda_radix_ffi(iv, &mut iv_degrees, &mut iv_noise_levels);
let num_inputs = (key.info.blocks.len() / 80) as u32;
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
scratch_cuda_trivium_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
noise_reduction_type as u32,
num_inputs,
);
cuda_trivium_generate_keystream_64(
streams.ffi(),
&raw mut cuda_ffi_keystream,
&raw const cuda_ffi_key,
&raw const cuda_ffi_iv,
num_inputs,
num_steps,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
cleanup_cuda_trivium_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
update_noise_degree(keystream_output, &cuda_ffi_keystream);
}

View File

@@ -66,6 +66,7 @@ mod tests_noise_distribution;
mod tests_signed;
#[cfg(test)]
mod tests_unsigned;
mod trivium;
impl CudaServerKey {
/// Create a trivial ciphertext filled with zeros on the GPU.

View File

@@ -20,6 +20,7 @@ pub(crate) mod test_scalar_shift;
pub(crate) mod test_scalar_sub;
pub(crate) mod test_shift;
pub(crate) mod test_sub;
pub(crate) mod test_trivium;
pub(crate) mod test_vector_comparisons;
pub(crate) mod test_vector_find;
@@ -85,6 +86,47 @@ impl<F> GpuFunctionExecutor<F> {
}
}
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>
for GpuFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
&CudaUnsignedRadixCiphertext,
usize,
&CudaStreams,
) -> CudaUnsignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (&'a RadixCiphertext, &'a RadixCiphertext, usize),
) -> RadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2 =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
let gpu_result = (self.func)(
&context.sks,
&d_ctxt_1,
&d_ctxt_2,
input.2,
&context.streams,
);
gpu_result.to_radix_ciphertext(&context.streams)
}
}
impl<'a, F>
FunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext, u128, usize, usize),

View File

@@ -0,0 +1,71 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parameterized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_unsigned::test_trivium::{
trivium_comparison_test, trivium_test_vector_1_test, trivium_test_vector_2_test,
trivium_test_vector_3_test, trivium_test_vector_4_test,
};
use crate::shortint::parameters::{
TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
};
create_gpu_parameterized_test!(integer_trivium_test_vector_1 {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
create_gpu_parameterized_test!(integer_trivium_test_vector_2 {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
create_gpu_parameterized_test!(integer_trivium_test_vector_3 {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
create_gpu_parameterized_test!(integer_trivium_test_vector_4 {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
create_gpu_parameterized_test!(integer_trivium_comparison {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
fn integer_trivium_test_vector_1<P>(param: P)
where
P: Into<TestParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
trivium_test_vector_1_test(param, executor);
}
fn integer_trivium_test_vector_2<P>(param: P)
where
P: Into<TestParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
trivium_test_vector_2_test(param, executor);
}
fn integer_trivium_test_vector_3<P>(param: P)
where
P: Into<TestParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
trivium_test_vector_3_test(param, executor);
}
fn integer_trivium_test_vector_4<P>(param: P)
where
P: Into<TestParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
trivium_test_vector_4_test(param, executor);
}
fn integer_trivium_comparison<P>(param: P)
where
P: Into<TestParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
trivium_comparison_test(param, executor);
}

View File

@@ -0,0 +1,125 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{cuda_backend_trivium_generate_keystream, LweBskGroupingFactor, PBSType};
use crate::integer::{RadixCiphertext, RadixClientKey};
use crate::shortint::Ciphertext;
impl RadixClientKey {
pub fn encrypt_bits_for_trivium(&self, bits: &[u64]) -> RadixCiphertext {
let mut blocks: Vec<Ciphertext> = Vec::with_capacity(bits.len());
for &bit in bits {
let mut ct = self.encrypt(bit);
let block = ct.blocks.pop().unwrap();
blocks.push(block);
}
RadixCiphertext::from(blocks)
}
pub fn decrypt_bits_from_trivium(&self, encrypted_stream: &RadixCiphertext) -> Vec<u8> {
let mut decrypted_bits = Vec::with_capacity(encrypted_stream.blocks.len());
for block in &encrypted_stream.blocks {
let tmp_radix = RadixCiphertext::from(vec![block.clone()]);
let val: u64 = self.decrypt(&tmp_radix);
decrypted_bits.push(val as u8);
}
decrypted_bits
}
}
impl CudaServerKey {
/// Generates a Trivium keystream homomorphically on the GPU.
///
/// # Arguments
/// * `key` - The encrypted secret key.
/// * `iv` - The encrypted initialization vector.
/// * `num_steps` - The number of keystream bits to generate per input.
/// * `streams` - The CUDA streams to use for execution.
pub fn trivium_generate_keystream(
&self,
key: &CudaUnsignedRadixCiphertext,
iv: &CudaUnsignedRadixCiphertext,
num_steps: usize,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext {
let num_key_bits = 80;
let num_iv_bits = 80;
assert_eq!(
key.as_ref().d_blocks.lwe_ciphertext_count().0,
num_key_bits,
"Input key must contain {} encrypted bits, but contains {}",
num_key_bits,
key.as_ref().d_blocks.lwe_ciphertext_count().0
);
assert_eq!(
iv.as_ref().d_blocks.lwe_ciphertext_count().0,
num_iv_bits,
"Input IV must contain {} encrypted bits, but contains {}",
num_iv_bits,
iv.as_ref().d_blocks.lwe_ciphertext_count().0
);
let num_output_bits = num_steps;
let mut keystream: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(num_output_bits, streams);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_trivium_generate_keystream(
streams,
keystream.as_mut(),
key.as_ref(),
iv.as_ref(),
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
num_steps as u32,
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_trivium_generate_keystream(
streams,
keystream.as_mut(),
key.as_ref(),
iv.as_ref(),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
num_steps as u32,
);
}
}
}
keystream
}
}

View File

@@ -28,6 +28,7 @@ pub(crate) mod test_shift;
pub(crate) mod test_slice;
pub(crate) mod test_sub;
pub(crate) mod test_sum;
pub(crate) mod test_trivium;
pub(crate) mod test_vector_comparisons;
pub(crate) mod test_vector_find;

View File

@@ -0,0 +1,260 @@
#![cfg(feature = "gpu")]
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey};
use crate::shortint::parameters::TestParameters;
use rand::Rng;
use std::sync::Arc;
struct TriviumRef {
a: Vec<u8>,
b: Vec<u8>,
c: Vec<u8>,
}
impl TriviumRef {
fn new(key: &[u8], iv: &[u8]) -> Self {
let mut a = vec![0u8; 93];
let mut b = vec![0u8; 84];
let mut c = vec![0u8; 111];
for i in 0..80 {
a[i] = key[79 - i];
b[i] = iv[79 - i];
}
c[108] = 1;
c[109] = 1;
c[110] = 1;
let mut triv = Self { a, b, c };
for _ in 0..(18 * 64) {
triv.next();
}
triv
}
fn next(&mut self) -> u8 {
let t1 = self.a[65] ^ self.a[92];
let t2 = self.b[68] ^ self.b[83];
let t3 = self.c[65] ^ self.c[110];
let out = t1 ^ t2 ^ t3;
let a_in = t3 ^ self.a[68] ^ (self.c[108] & self.c[109]);
let b_in = t1 ^ self.b[77] ^ (self.a[90] & self.a[91]);
let c_in = t2 ^ self.c[86] ^ (self.b[81] & self.b[82]);
self.a.pop();
self.a.insert(0, a_in);
self.b.pop();
self.b.insert(0, b_in);
self.c.pop();
self.c.insert(0, c_in);
out
}
}
fn get_hexadecimal_string_from_lsb_first_stream(a: &[u8]) -> String {
assert!(a.len().is_multiple_of(8));
let mut hexadecimal = String::new();
for test in a.chunks(8) {
let to_hex = |chunk: &[u8]| -> char {
let mut val = 0u8;
if chunk[0] == 1 {
val |= 1;
}
if chunk[1] == 1 {
val |= 2;
}
if chunk[2] == 1 {
val |= 4;
}
if chunk[3] == 1 {
val |= 8;
}
match val {
0..=9 => (val + b'0') as char,
10..=15 => (val - 10 + b'A') as char,
_ => unreachable!(),
}
};
hexadecimal.push(to_hex(&test[4..8]));
hexadecimal.push(to_hex(&test[0..4]));
}
hexadecimal
}
pub fn trivium_test_vector_1_test<P, E>(param: P, mut executor: E)
where
P: Into<TestParameters>,
E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>,
{
let param = param.into();
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, 1));
let sks = Arc::new(sks);
executor.setup(&cks, sks);
let key = vec![0u64; 80];
let iv = vec![0u64; 80];
let expected_output_0_63 = "FBE0BF265859051B517A2E4E239FC97F563203161907CF2DE7A8790FA1B2E9CDF75292030268B7382B4C1A759AA2599A285549986E74805903801A4CB5A5D4F2";
let cpu_key = cks.encrypt_bits_for_trivium(&key);
let cpu_iv = cks.encrypt_bits_for_trivium(&iv);
let num_steps = 512;
let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps));
let decrypted_bits = cks.decrypt_bits_from_trivium(&output_radix);
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
}
pub fn trivium_test_vector_2_test<P, E>(param: P, mut executor: E)
where
P: Into<TestParameters>,
E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>,
{
let param = param.into();
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, 1));
let sks = Arc::new(sks);
executor.setup(&cks, sks);
let mut key = vec![0u64; 80];
let iv = vec![0u64; 80];
key[7] = 1;
let expected_output_0_63 = "38EB86FF730D7A9CAF8DF13A4420540DBB7B651464C87501552041C249F29A64D2FBF515610921EBE06C8F92CECF7F8098FF20CCCC6A62B97BE8EF7454FC80F9";
let cpu_key = cks.encrypt_bits_for_trivium(&key);
let cpu_iv = cks.encrypt_bits_for_trivium(&iv);
let num_steps = 512;
let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps));
let decrypted_bits = cks.decrypt_bits_from_trivium(&output_radix);
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
}
pub fn trivium_test_vector_3_test<P, E>(param: P, mut executor: E)
where
P: Into<TestParameters>,
E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>,
{
let param = param.into();
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, 1));
let sks = Arc::new(sks);
executor.setup(&cks, sks);
let key = vec![0u64; 80];
let mut iv = vec![0u64; 80];
iv[7] = 1;
let expected_output_0_63 = "F8901736640549E3BA7D42EA2D07B9F49233C18D773008BD755585B1A8CBAB86C1E9A9B91F1AD33483FD6EE3696D659C9374260456A36AAE11F033A519CBD5D7";
let cpu_key = cks.encrypt_bits_for_trivium(&key);
let cpu_iv = cks.encrypt_bits_for_trivium(&iv);
let num_steps = 512;
let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps));
let decrypted_bits = cks.decrypt_bits_from_trivium(&output_radix);
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
}
pub fn trivium_test_vector_4_test<P, E>(param: P, mut executor: E)
where
P: Into<TestParameters>,
E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>,
{
let param = param.into();
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, 1));
let sks = Arc::new(sks);
executor.setup(&cks, sks);
let key_string = "0053A6F94C9FF24598EB";
let mut key = vec![0u64; 80];
for i in (0..key_string.len()).step_by(2) {
let mut val = u8::from_str_radix(&key_string[i..i + 2], 16).unwrap();
for j in 0..8 {
key[8 * (i >> 1) + j] = (val % 2) as u64;
val >>= 1;
}
}
let iv_string = "0D74DB42A91077DE45AC";
let mut iv = vec![0u64; 80];
for i in (0..iv_string.len()).step_by(2) {
let mut val = u8::from_str_radix(&iv_string[i..i + 2], 16).unwrap();
for j in 0..8 {
iv[8 * (i >> 1) + j] = (val % 2) as u64;
val >>= 1;
}
}
let expected_output_0_63 = "F4CD954A717F26A7D6930830C4E7CF0819F80E03F25F342C64ADC66ABA7F8A8E6EAA49F23632AE3CD41A7BD290A0132F81C6D4043B6E397D7388F3A03B5FE358";
let cpu_key = cks.encrypt_bits_for_trivium(&key);
let cpu_iv = cks.encrypt_bits_for_trivium(&iv);
let num_steps = 512;
let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps));
let decrypted_bits = cks.decrypt_bits_from_trivium(&output_radix);
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
}
pub fn trivium_comparison_test<P, E>(param: P, mut executor: E)
where
P: Into<TestParameters>,
E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>,
{
let param = param.into();
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, 1));
let sks = Arc::new(sks);
executor.setup(&cks, sks);
let num_runs = 5;
let num_steps = 512;
for i in 0..num_runs {
let mut rng = rand::thread_rng();
let plain_key: Vec<u8> = (0..80).map(|_| rng.gen_range(0..=1)).collect();
let plain_iv: Vec<u8> = (0..80).map(|_| rng.gen_range(0..=1)).collect();
let key_bits_u64: Vec<u64> = plain_key.iter().map(|&x| x as u64).collect();
let iv_bits_u64: Vec<u64> = plain_iv.iter().map(|&x| x as u64).collect();
let cpu_key = cks.encrypt_bits_for_trivium(&key_bits_u64);
let cpu_iv = cks.encrypt_bits_for_trivium(&iv_bits_u64);
let mut cpu_trivium = TriviumRef::new(&plain_key, &plain_iv);
let mut cpu_output = Vec::with_capacity(num_steps);
for _ in 0..num_steps {
cpu_output.push(cpu_trivium.next());
}
let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps));
let fhe_output = cks.decrypt_bits_from_trivium(&output_radix);
assert_eq!(cpu_output.len(), fhe_output.len());
assert_eq!(cpu_output, fhe_output, "Mismatch at iteration {i}");
}
}