feat(gpu): trivium
Makefile (7 lines changed)
@@ -1454,6 +1454,13 @@ bench_integer_aes256_gpu: install_rs_check_toolchain
|
||||
--bench integer-aes256 \
|
||||
--features=integer,internal-keycache,gpu, -p tfhe-benchmark --profile release_lto_off --
|
||||
|
||||
.PHONY: bench_integer_trivium_gpu # Run benchmarks for trivium on GPU backend
|
||||
bench_integer_trivium_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
|
||||
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench integer-trivium \
|
||||
--features=integer,internal-keycache,gpu, -p tfhe-benchmark --profile release_lto_off --
|
||||
|
||||
.PHONY: bench_integer_multi_bit # Run benchmarks for unsigned integer using multi-bit parameters
|
||||
bench_integer_multi_bit: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_PARAM_TYPE=MULTI_BIT __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
|
||||
|
||||
@@ -86,6 +86,7 @@ fn main() {
|
||||
"cuda/include/integer/integer.h",
|
||||
"cuda/include/integer/rerand.h",
|
||||
"cuda/include/aes/aes.h",
|
||||
"cuda/include/trivium/trivium.h",
|
||||
"cuda/include/zk/zk.h",
|
||||
"cuda/include/keyswitch/keyswitch.h",
|
||||
"cuda/include/keyswitch/ks_enums.h",
|
||||
|
||||
backends/tfhe-cuda-backend/cuda/include/trivium/trivium.h (new file, 24 lines)
@@ -0,0 +1,24 @@
|
||||
#ifndef TRIVIUM_H
|
||||
#define TRIVIUM_H
|
||||
|
||||
#include "../integer/integer.h"
|
||||
|
||||
extern "C" {
|
||||
uint64_t scratch_cuda_trivium_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
|
||||
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
|
||||
uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
|
||||
PBS_TYPE pbs_type, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs);
|
||||
|
||||
void cuda_trivium_generate_keystream_64(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output,
|
||||
const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv,
|
||||
uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks);
|
||||
|
||||
void cleanup_cuda_trivium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void);
|
||||
}
|
||||
|
||||
#endif
|
||||
backends/tfhe-cuda-backend/cuda/include/trivium/trivium_utilities.h (new file, 277 lines)
@@ -0,0 +1,277 @@
|
||||
#ifndef TRIVIUM_UTILITIES_H
|
||||
#define TRIVIUM_UTILITIES_H
|
||||
#include "../integer/integer_utilities.h"
|
||||
|
||||
template <typename Torus> struct int_trivium_lut_buffers {
|
||||
int_radix_lut<Torus> *and_lut;
|
||||
int_radix_lut<Torus> *flush_lut;
|
||||
|
||||
int_trivium_lut_buffers(CudaStreams streams, const int_radix_params ¶ms,
|
||||
bool allocate_gpu_memory, uint32_t num_trivium_inputs,
|
||||
uint64_t &size_tracker) {
|
||||
|
||||
constexpr uint32_t BATCH_SIZE = 64;
|
||||
constexpr uint32_t MAX_AND_PER_STEP = 3;
|
||||
uint32_t total_lut_ops = num_trivium_inputs * BATCH_SIZE * MAX_AND_PER_STEP;
|
||||
|
||||
this->and_lut = new int_radix_lut<Torus>(streams, params, 1, total_lut_ops,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
std::function<Torus(Torus, Torus)> and_lambda =
|
||||
[](Torus a, Torus b) -> Torus { return a & b; };
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->and_lut->get_lut(0, 0),
|
||||
this->and_lut->get_degree(0), this->and_lut->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, and_lambda, allocate_gpu_memory);
|
||||
|
||||
auto active_streams_and =
|
||||
streams.active_gpu_subset(total_lut_ops, params.pbs_type);
|
||||
this->and_lut->broadcast_lut(active_streams_and);
|
||||
this->and_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
|
||||
|
||||
uint32_t total_flush_ops = num_trivium_inputs * BATCH_SIZE * 4;
|
||||
|
||||
this->flush_lut = new int_radix_lut<Torus>(
|
||||
streams, params, 1, total_flush_ops, allocate_gpu_memory, size_tracker);
|
||||
|
||||
std::function<Torus(Torus)> flush_lambda = [](Torus x) -> Torus {
|
||||
return x & 1;
|
||||
};
|
||||
|
||||
generate_device_accumulator(
|
||||
streams.stream(0), streams.gpu_index(0), this->flush_lut->get_lut(0, 0),
|
||||
this->flush_lut->get_degree(0), this->flush_lut->get_max_degree(0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, flush_lambda, allocate_gpu_memory);
|
||||
|
||||
auto active_streams_flush =
|
||||
streams.active_gpu_subset(total_flush_ops, params.pbs_type);
|
||||
this->flush_lut->broadcast_lut(active_streams_flush);
|
||||
this->flush_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
this->and_lut->release(streams);
|
||||
delete this->and_lut;
|
||||
this->and_lut = nullptr;
|
||||
|
||||
this->flush_lut->release(streams);
|
||||
delete this->flush_lut;
|
||||
this->flush_lut = nullptr;
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
}
|
||||
};
|
||||
|
||||
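The two lookup tables above only ever see boolean block values: XOR is implemented as a plain LWE addition (see trivium_xor further down), so the bivariate table computes the AND of two encrypted bits while the univariate flush table keeps the least-significant bit and bootstraps away the accumulated carries. A minimal plain-Rust sketch of the two table functions (illustration only, not the CUDA code):

fn main() {
    // Bivariate AND lookup: both inputs are encrypted bits in {0, 1}.
    let and_lambda = |a: u64, b: u64| -> u64 { a & b };
    // Univariate flush lookup: XOR is done with plain additions, so a block
    // may hold 0..=3 before being bootstrapped back to a clean bit.
    let flush_lambda = |x: u64| -> u64 { x & 1 };

    assert_eq!(
        [(0u64, 0u64), (0, 1), (1, 0), (1, 1)].map(|(a, b)| and_lambda(a, b)),
        [0, 0, 0, 1]
    );
    assert_eq!((0..4u64).map(flush_lambda).collect::<Vec<_>>(), vec![0, 1, 0, 1]);
}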
template <typename Torus> struct int_trivium_state_workspaces {
|
||||
CudaRadixCiphertextFFI *a_reg;
|
||||
CudaRadixCiphertextFFI *b_reg;
|
||||
CudaRadixCiphertextFFI *c_reg;
|
||||
|
||||
CudaRadixCiphertextFFI *shift_workspace;
|
||||
|
||||
CudaRadixCiphertextFFI *temp_t1;
|
||||
CudaRadixCiphertextFFI *temp_t2;
|
||||
CudaRadixCiphertextFFI *temp_t3;
|
||||
CudaRadixCiphertextFFI *new_a;
|
||||
CudaRadixCiphertextFFI *new_b;
|
||||
CudaRadixCiphertextFFI *new_c;
|
||||
|
||||
CudaRadixCiphertextFFI *packed_pbs_lhs;
|
||||
CudaRadixCiphertextFFI *packed_pbs_rhs;
|
||||
CudaRadixCiphertextFFI *packed_pbs_out;
|
||||
|
||||
CudaRadixCiphertextFFI *packed_flush_in;
|
||||
CudaRadixCiphertextFFI *packed_flush_out;
|
||||
|
||||
int_trivium_state_workspaces(CudaStreams streams,
|
||||
const int_radix_params ¶ms,
|
||||
bool allocate_gpu_memory, uint32_t num_inputs,
|
||||
uint64_t &size_tracker) {
|
||||
|
||||
this->a_reg = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->a_reg, 93 * num_inputs,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->b_reg = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->b_reg, 84 * num_inputs,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->c_reg = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->c_reg, 111 * num_inputs,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->shift_workspace = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->shift_workspace,
|
||||
128 * num_inputs, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
uint32_t batch_blocks = 64 * num_inputs;
|
||||
|
||||
this->temp_t1 = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->temp_t1, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->temp_t2 = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->temp_t2, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->temp_t3 = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->temp_t3, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->new_a = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->new_a, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->new_b = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->new_b, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->new_c = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->new_c, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->packed_pbs_lhs = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_pbs_lhs,
|
||||
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->packed_pbs_rhs = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_pbs_rhs,
|
||||
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->packed_pbs_out = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_pbs_out,
|
||||
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->packed_flush_in = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_flush_in,
|
||||
4 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->packed_flush_out = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_flush_out,
|
||||
4 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams, bool allocate_gpu_memory) {
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->a_reg, allocate_gpu_memory);
|
||||
delete this->a_reg;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->b_reg, allocate_gpu_memory);
|
||||
delete this->b_reg;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->c_reg, allocate_gpu_memory);
|
||||
delete this->c_reg;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->shift_workspace, allocate_gpu_memory);
|
||||
delete this->shift_workspace;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->temp_t1, allocate_gpu_memory);
|
||||
delete this->temp_t1;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->temp_t2, allocate_gpu_memory);
|
||||
delete this->temp_t2;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->temp_t3, allocate_gpu_memory);
|
||||
delete this->temp_t3;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->new_a, allocate_gpu_memory);
|
||||
delete this->new_a;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->new_b, allocate_gpu_memory);
|
||||
delete this->new_b;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->new_c, allocate_gpu_memory);
|
||||
delete this->new_c;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->packed_pbs_lhs, allocate_gpu_memory);
|
||||
delete this->packed_pbs_lhs;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->packed_pbs_rhs, allocate_gpu_memory);
|
||||
delete this->packed_pbs_rhs;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->packed_pbs_out, allocate_gpu_memory);
|
||||
delete this->packed_pbs_out;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->packed_flush_in, allocate_gpu_memory);
|
||||
delete this->packed_flush_in;
|
||||
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->packed_flush_out, allocate_gpu_memory);
|
||||
delete this->packed_flush_out;
|
||||
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
}
|
||||
};
|
||||
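The three registers reproduce the standard Trivium split of 93 + 84 + 111 = 288 state bits, stored bit-sliced with num_inputs blocks per state bit; the remaining buffers hold one 64-step batch per register (64 * num_inputs blocks) or the packed operands of the batched PBS and flush calls. A small plain-Rust sizing sketch over the same constants (illustration only):

fn main() {
    let num_inputs: usize = 4; // illustrative number of independent Trivium instances
    let batch = 64 * num_inputs;

    // One radix block per state bit per input: 93 + 84 + 111 = 288 bits of Trivium state.
    let (a, b, c) = (93 * num_inputs, 84 * num_inputs, 111 * num_inputs);
    assert_eq!(a + b + c, 288 * num_inputs);

    // Scratch buffers mirrored from the constructor above.
    let shift_workspace = 128 * num_inputs; // covers the largest register (111 bits)
    let packed_pbs = 3 * batch;   // the three AND gates of a batch, evaluated in one PBS call
    let packed_flush = 4 * batch; // new_a, new_b, new_c and the optional keystream output
    assert!(shift_workspace >= a.max(b).max(c));
    println!("state={} pbs={} flush={}", a + b + c, packed_pbs, packed_flush);
}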
|
||||
template <typename Torus> struct int_trivium_buffer {
|
||||
int_radix_params params;
|
||||
bool allocate_gpu_memory;
|
||||
uint32_t num_inputs;
|
||||
|
||||
int_trivium_lut_buffers<Torus> *luts;
|
||||
int_trivium_state_workspaces<Torus> *state;
|
||||
|
||||
int_trivium_buffer(CudaStreams streams, const int_radix_params ¶ms,
|
||||
bool allocate_gpu_memory, uint32_t num_inputs,
|
||||
uint64_t &size_tracker) {
|
||||
this->params = params;
|
||||
this->allocate_gpu_memory = allocate_gpu_memory;
|
||||
this->num_inputs = num_inputs;
|
||||
|
||||
this->luts = new int_trivium_lut_buffers<Torus>(
|
||||
streams, params, allocate_gpu_memory, num_inputs, size_tracker);
|
||||
|
||||
this->state = new int_trivium_state_workspaces<Torus>(
|
||||
streams, params, allocate_gpu_memory, num_inputs, size_tracker);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
luts->release(streams);
|
||||
delete luts;
|
||||
luts = nullptr;
|
||||
|
||||
state->release(streams, allocate_gpu_memory);
|
||||
delete state;
|
||||
state = nullptr;
|
||||
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cu (new file, 45 lines)
@@ -0,0 +1,45 @@
|
||||
#include "../../include/trivium/trivium.h"
|
||||
#include "trivium.cuh"
|
||||
|
||||
uint64_t scratch_cuda_trivium_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
|
||||
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
|
||||
uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
|
||||
PBS_TYPE pbs_type, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs) {
|
||||
|
||||
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
|
||||
glwe_dimension * polynomial_size, lwe_dimension,
|
||||
ks_level, ks_base_log, pbs_level, pbs_base_log,
|
||||
grouping_factor, message_modulus, carry_modulus,
|
||||
noise_reduction_type);
|
||||
|
||||
return scratch_cuda_trivium_encrypt<uint64_t>(
|
||||
CudaStreams(streams), (int_trivium_buffer<uint64_t> **)mem_ptr, params,
|
||||
allocate_gpu_memory, num_inputs);
|
||||
}
|
||||
|
||||
void cuda_trivium_generate_keystream_64(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output,
|
||||
const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv,
|
||||
uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks,
|
||||
void *const *ksks) {
|
||||
|
||||
auto buffer = (int_trivium_buffer<uint64_t> *)mem_ptr;
|
||||
|
||||
host_trivium_generate_keystream<uint64_t>(
|
||||
CudaStreams(streams), keystream_output, key, iv, num_inputs, num_steps,
|
||||
buffer, bsks, (uint64_t *const *)ksks);
|
||||
}
|
||||
|
||||
void cleanup_cuda_trivium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void) {
|
||||
|
||||
int_trivium_buffer<uint64_t> *mem_ptr =
|
||||
(int_trivium_buffer<uint64_t> *)(*mem_ptr_void);
|
||||
|
||||
mem_ptr->release(CudaStreams(streams));
|
||||
|
||||
delete mem_ptr;
|
||||
*mem_ptr_void = nullptr;
|
||||
}
|
||||
backends/tfhe-cuda-backend/cuda/src/trivium/trivium.cuh (new file, 311 lines)
@@ -0,0 +1,311 @@
|
||||
#ifndef TRIVIUM_CUH
|
||||
#define TRIVIUM_CUH
|
||||
|
||||
#include "../../include/trivium/trivium_utilities.h"
|
||||
#include "../integer/integer.cuh"
|
||||
#include "../integer/radix_ciphertext.cuh"
|
||||
#include "../integer/scalar_addition.cuh"
|
||||
#include "../linearalgebra/addition.cuh"
|
||||
|
||||
template <typename Torus>
|
||||
void reverse_bitsliced_radix_inplace(CudaStreams streams,
|
||||
int_trivium_buffer<Torus> *mem,
|
||||
CudaRadixCiphertextFFI *radix,
|
||||
uint32_t num_bits_in_reg) {
|
||||
uint32_t N = mem->num_inputs;
|
||||
CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;
|
||||
|
||||
for (uint32_t i = 0; i < num_bits_in_reg; i++) {
|
||||
uint32_t src_start = i * N;
|
||||
uint32_t src_end = (i + 1) * N;
|
||||
|
||||
uint32_t dest_start = (num_bits_in_reg - 1 - i) * N;
|
||||
uint32_t dest_end = (num_bits_in_reg - i) * N;
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), temp, dest_start, dest_end,
|
||||
radix, src_start, src_end);
|
||||
}
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), radix, 0, num_bits_in_reg * N,
|
||||
temp, 0, num_bits_in_reg * N);
|
||||
}
|
||||
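Within a 64-step batch the keystream bit of step s lands at slice offset 63 - s, and the LSB-first key/IV layout is the mirror image of the register convention, so both cases are fixed with the same whole-slice reversal. A clear-data sketch of what this helper does on the bit-sliced layout, bit index major and input index minor (plain Rust, illustration only):

/// Reverse the order of the bit groups while keeping each group's per-input order.
fn reverse_bitsliced(radix: &mut [u8], num_bits: usize, num_inputs: usize) {
    let mut tmp = radix.to_vec();
    for i in 0..num_bits {
        let (src, dst) = (i * num_inputs, (num_bits - 1 - i) * num_inputs);
        tmp[dst..dst + num_inputs].copy_from_slice(&radix[src..src + num_inputs]);
    }
    radix.copy_from_slice(&tmp);
}

fn main() {
    let mut v = vec![0, 0, 1, 1, 2, 2, 3, 3]; // 4 bits, 2 inputs
    reverse_bitsliced(&mut v, 4, 2);
    assert_eq!(v, vec![3, 3, 2, 2, 1, 1, 0, 0]);
}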
|
||||
template <typename Torus>
|
||||
__host__ __forceinline__ void
|
||||
trivium_xor(CudaStreams streams, int_trivium_buffer<Torus> *mem,
|
||||
CudaRadixCiphertextFFI *out, const CudaRadixCiphertextFFI *lhs,
|
||||
const CudaRadixCiphertextFFI *rhs) {
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), out, lhs, rhs,
|
||||
out->num_radix_blocks, mem->params.message_modulus,
|
||||
mem->params.carry_modulus);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__host__ __forceinline__ void
|
||||
trivium_flush(CudaStreams streams, int_trivium_buffer<Torus> *mem,
|
||||
CudaRadixCiphertextFFI *target, void *const *bsks,
|
||||
uint64_t *const *ksks) {
|
||||
integer_radix_apply_univariate_lookup_table<Torus>(
|
||||
streams, target, target, bsks, ksks, mem->luts->flush_lut,
|
||||
target->num_radix_blocks);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__host__ void slice_reg_batch(CudaRadixCiphertextFFI *slice,
|
||||
const CudaRadixCiphertextFFI *reg,
|
||||
uint32_t start_bit_idx, uint32_t num_bits,
|
||||
uint32_t num_inputs) {
|
||||
as_radix_ciphertext_slice<Torus>(slice, reg, start_bit_idx * num_inputs,
|
||||
(start_bit_idx + num_bits) * num_inputs);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__host__ void shift_and_insert_batch(CudaStreams streams,
|
||||
int_trivium_buffer<Torus> *mem,
|
||||
CudaRadixCiphertextFFI *reg,
|
||||
CudaRadixCiphertextFFI *new_bits,
|
||||
uint32_t reg_size, uint32_t num_inputs) {
|
||||
|
||||
constexpr uint32_t BATCH = 64;
|
||||
CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;
|
||||
|
||||
uint32_t num_blocks_to_keep = (reg_size - BATCH) * num_inputs;
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), temp, 0, BATCH * num_inputs,
|
||||
new_bits, 0, BATCH * num_inputs);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), temp, BATCH * num_inputs,
|
||||
reg_size * num_inputs, reg, 0, num_blocks_to_keep);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), reg, 0, reg_size * num_inputs,
|
||||
temp, 0, reg_size * num_inputs);
|
||||
}
|
||||
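The helpers above rely on the bit-sliced layout used throughout this file: block bit_index * num_inputs + input_index holds the given state bit of the given instance, so a 64-bit tap slice for all instances is one contiguous block range, and shifting a register by a whole batch is two contiguous copies. A clear-data sketch of the same two operations (plain Rust, illustration only):

fn block_index(bit: usize, input: usize, num_inputs: usize) -> usize {
    // Bit-sliced layout: all inputs' copies of a given state bit are adjacent.
    bit * num_inputs + input
}

/// Clear-data model of `shift_and_insert_batch`: drop the 64 oldest bits of the
/// register and prepend the 64 freshly computed feedback bits (newest at index 0).
fn shift_and_insert(reg: &mut Vec<u8>, new_bits: &[u8], reg_size: usize, num_inputs: usize) {
    let batch = 64 * num_inputs;
    assert_eq!(new_bits.len(), batch);
    let kept = reg[..(reg_size * num_inputs - batch)].to_vec();
    reg.clear();
    reg.extend_from_slice(new_bits);
    reg.extend_from_slice(&kept);
}

fn main() {
    let num_inputs = 2;
    assert_eq!(block_index(5, 1, num_inputs), 11);

    let mut a_reg = vec![0u8; 93 * num_inputs];
    let fresh = vec![1u8; 64 * num_inputs];
    shift_and_insert(&mut a_reg, &fresh, 93, num_inputs);
    assert_eq!(a_reg.len(), 93 * num_inputs);
    assert_eq!(&a_reg[..64 * num_inputs], &fresh[..]);
}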
|
||||
template <typename Torus>
|
||||
__host__ void
|
||||
trivium_compute_64_steps(CudaStreams streams, int_trivium_buffer<Torus> *mem,
|
||||
CudaRadixCiphertextFFI *output_dest, void *const *bsks,
|
||||
uint64_t *const *ksks) {
|
||||
|
||||
uint32_t N = mem->num_inputs;
|
||||
constexpr uint32_t BATCH = 64;
|
||||
uint32_t batch_size_blocks = BATCH * N;
|
||||
auto s = mem->state;
|
||||
|
||||
CudaRadixCiphertextFFI a65_slice, a92_slice, a91_slice, a90_slice, a68_slice;
|
||||
slice_reg_batch<Torus>(&a65_slice, s->a_reg, 2, BATCH, N);
|
||||
slice_reg_batch<Torus>(&a92_slice, s->a_reg, 29, BATCH, N);
|
||||
slice_reg_batch<Torus>(&a91_slice, s->a_reg, 28, BATCH, N);
|
||||
slice_reg_batch<Torus>(&a90_slice, s->a_reg, 27, BATCH, N);
|
||||
slice_reg_batch<Torus>(&a68_slice, s->a_reg, 5, BATCH, N);
|
||||
|
||||
CudaRadixCiphertextFFI b68_slice, b83_slice, b82_slice, b81_slice, b77_slice;
|
||||
slice_reg_batch<Torus>(&b68_slice, s->b_reg, 5, BATCH, N);
|
||||
slice_reg_batch<Torus>(&b83_slice, s->b_reg, 20, BATCH, N);
|
||||
slice_reg_batch<Torus>(&b82_slice, s->b_reg, 19, BATCH, N);
|
||||
slice_reg_batch<Torus>(&b81_slice, s->b_reg, 18, BATCH, N);
|
||||
slice_reg_batch<Torus>(&b77_slice, s->b_reg, 14, BATCH, N);
|
||||
|
||||
CudaRadixCiphertextFFI c65_slice, c110_slice, c109_slice, c108_slice,
|
||||
c86_slice;
|
||||
slice_reg_batch<Torus>(&c65_slice, s->c_reg, 2, BATCH, N);
|
||||
slice_reg_batch<Torus>(&c110_slice, s->c_reg, 47, BATCH, N);
|
||||
slice_reg_batch<Torus>(&c109_slice, s->c_reg, 46, BATCH, N);
|
||||
slice_reg_batch<Torus>(&c108_slice, s->c_reg, 45, BATCH, N);
|
||||
slice_reg_batch<Torus>(&c86_slice, s->c_reg, 23, BATCH, N);
|
||||
|
||||
trivium_xor(streams, mem, s->temp_t1, &a65_slice, &a92_slice);
|
||||
trivium_xor(streams, mem, s->temp_t2, &b68_slice, &b83_slice);
|
||||
trivium_xor(streams, mem, s->temp_t3, &c65_slice, &c110_slice);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs, 0,
|
||||
batch_size_blocks, &c109_slice, 0, batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs,
|
||||
batch_size_blocks, 2 * batch_size_blocks, &a91_slice, 0,
|
||||
batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs,
|
||||
2 * batch_size_blocks, 3 * batch_size_blocks, &b82_slice, 0,
|
||||
batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs, 0,
|
||||
batch_size_blocks, &c108_slice, 0, batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs,
|
||||
batch_size_blocks, 2 * batch_size_blocks, &a90_slice, 0,
|
||||
batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs,
|
||||
2 * batch_size_blocks, 3 * batch_size_blocks, &b81_slice, 0,
|
||||
batch_size_blocks);
|
||||
|
||||
integer_radix_apply_bivariate_lookup_table<Torus>(
|
||||
streams, s->packed_pbs_out, s->packed_pbs_lhs, s->packed_pbs_rhs, bsks,
|
||||
ksks, mem->luts->and_lut, 3 * batch_size_blocks,
|
||||
mem->params.message_modulus);
|
||||
|
||||
CudaRadixCiphertextFFI and_res_a, and_res_b, and_res_c;
|
||||
as_radix_ciphertext_slice<Torus>(&and_res_a, s->packed_pbs_out, 0,
|
||||
batch_size_blocks);
|
||||
as_radix_ciphertext_slice<Torus>(&and_res_b, s->packed_pbs_out,
|
||||
batch_size_blocks, 2 * batch_size_blocks);
|
||||
as_radix_ciphertext_slice<Torus>(&and_res_c, s->packed_pbs_out,
|
||||
2 * batch_size_blocks,
|
||||
3 * batch_size_blocks);
|
||||
|
||||
trivium_xor(streams, mem, s->new_a, s->temp_t3, &a68_slice);
|
||||
trivium_xor(streams, mem, s->new_a, s->new_a, &and_res_a);
|
||||
|
||||
trivium_xor(streams, mem, s->new_b, s->temp_t1, &b77_slice);
|
||||
trivium_xor(streams, mem, s->new_b, s->new_b, &and_res_b);
|
||||
|
||||
trivium_xor(streams, mem, s->new_c, s->temp_t2, &c86_slice);
|
||||
trivium_xor(streams, mem, s->new_c, s->new_c, &and_res_c);
|
||||
|
||||
if (output_dest != nullptr) {
|
||||
trivium_xor(streams, mem, output_dest, s->temp_t1, s->temp_t2);
|
||||
trivium_xor(streams, mem, output_dest, output_dest, s->temp_t3);
|
||||
}
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_flush_in, 0,
|
||||
batch_size_blocks, s->new_a, 0, batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
|
||||
batch_size_blocks, 2 * batch_size_blocks, s->new_b, 0, batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
|
||||
2 * batch_size_blocks, 3 * batch_size_blocks, s->new_c, 0,
|
||||
batch_size_blocks);
|
||||
|
||||
uint32_t total_flush_blocks = 3 * batch_size_blocks;
|
||||
|
||||
if (output_dest != nullptr) {
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
|
||||
3 * batch_size_blocks, 4 * batch_size_blocks, output_dest, 0,
|
||||
batch_size_blocks);
|
||||
total_flush_blocks += batch_size_blocks;
|
||||
}
|
||||
|
||||
integer_radix_apply_univariate_lookup_table<Torus>(
|
||||
streams, s->packed_flush_out, s->packed_flush_in, bsks, ksks,
|
||||
mem->luts->flush_lut, total_flush_blocks);
|
||||
|
||||
CudaRadixCiphertextFFI flushed_a, flushed_b, flushed_c;
|
||||
as_radix_ciphertext_slice<Torus>(&flushed_a, s->packed_flush_out, 0,
|
||||
batch_size_blocks);
|
||||
as_radix_ciphertext_slice<Torus>(&flushed_b, s->packed_flush_out,
|
||||
batch_size_blocks, 2 * batch_size_blocks);
|
||||
as_radix_ciphertext_slice<Torus>(&flushed_c, s->packed_flush_out,
|
||||
2 * batch_size_blocks,
|
||||
3 * batch_size_blocks);
|
||||
|
||||
shift_and_insert_batch(streams, mem, s->a_reg, &flushed_a, 93, N);
|
||||
shift_and_insert_batch(streams, mem, s->b_reg, &flushed_b, 84, N);
|
||||
shift_and_insert_batch(streams, mem, s->c_reg, &flushed_c, 111, N);
|
||||
|
||||
if (output_dest != nullptr) {
|
||||
CudaRadixCiphertextFFI flushed_out;
|
||||
as_radix_ciphertext_slice<Torus>(&flushed_out, s->packed_flush_out,
|
||||
3 * batch_size_blocks,
|
||||
4 * batch_size_blocks);
|
||||
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), output_dest, 0,
|
||||
batch_size_blocks, &flushed_out, 0, batch_size_blocks);
|
||||
|
||||
reverse_bitsliced_radix_inplace<Torus>(streams, mem, output_dest, 64);
|
||||
}
|
||||
}
|
||||
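The whole batch reads only pre-batch state: every tap index is at least 65 while a batch prepends at most 64 new bits, so for step s (0-based) the tap at index t is the pre-batch bit at index t - s; this is also why the slice start offsets above are exactly tap - 63 (for example 65 - 63 = 2 and 92 - 63 = 29 for register A). A clear-data sketch of the batched update that the homomorphic code mirrors (plain Rust, illustration only, not the kernel):

/// One 64-step batch for a single Trivium instance, using the same tap indices
/// as the reference implementation in the tests (registers are newest-bit-first).
fn batch_64(a: &mut Vec<u8>, b: &mut Vec<u8>, c: &mut Vec<u8>) -> Vec<u8> {
    let (mut out, mut na, mut nb, mut nc) = (Vec::new(), Vec::new(), Vec::new(), Vec::new());
    for s in 0..64 {
        // Every tap index is >= 65, so `tap - s` always points at a pre-batch bit.
        let t1 = a[65 - s] ^ a[92 - s];
        let t2 = b[68 - s] ^ b[83 - s];
        let t3 = c[65 - s] ^ c[110 - s];
        out.push(t1 ^ t2 ^ t3);
        na.push(t3 ^ a[68 - s] ^ (c[108 - s] & c[109 - s]));
        nb.push(t1 ^ b[77 - s] ^ (a[90 - s] & a[91 - s]));
        nc.push(t2 ^ c[86 - s] ^ (b[81 - s] & b[82 - s]));
    }
    prepend_batch(a, &na);
    prepend_batch(b, &nb);
    prepend_batch(c, &nc);
    out
}

/// Drop the 64 oldest bits and prepend the fresh ones, newest at index 0
/// (the homomorphic `shift_and_insert_batch` performs the same move).
fn prepend_batch(reg: &mut Vec<u8>, fresh: &[u8]) {
    let mut updated: Vec<u8> = fresh.iter().rev().copied().collect();
    updated.extend_from_slice(&reg[..reg.len() - fresh.len()]);
    *reg = updated;
}

fn main() {
    // All-zero key/IV, c[108..=110] = 1, then 18 * 64 = 1152 warm-up steps,
    // exactly as in `trivium_init` and the reference implementation in the tests.
    let (mut a, mut b, mut c) = (vec![0u8; 93], vec![0u8; 84], vec![0u8; 111]);
    for bit in &mut c[108..] {
        *bit = 1;
    }
    for _ in 0..18 {
        batch_64(&mut a, &mut b, &mut c);
    }
    let keystream = batch_64(&mut a, &mut b, &mut c);
    assert_eq!(keystream.len(), 64);
}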
|
||||
template <typename Torus>
|
||||
__host__ void trivium_init(CudaStreams streams, int_trivium_buffer<Torus> *mem,
|
||||
CudaRadixCiphertextFFI const *key_bitsliced,
|
||||
CudaRadixCiphertextFFI const *iv_bitsliced,
|
||||
void *const *bsks, uint64_t *const *ksks) {
|
||||
uint32_t N = mem->num_inputs;
|
||||
auto s = mem->state;
|
||||
|
||||
CudaRadixCiphertextFFI src_key_slice;
|
||||
slice_reg_batch<Torus>(&src_key_slice, key_bitsliced, 0, 80, N);
|
||||
|
||||
CudaRadixCiphertextFFI dest_a_slice;
|
||||
slice_reg_batch<Torus>(&dest_a_slice, s->a_reg, 0, 80, N);
|
||||
|
||||
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
|
||||
&dest_a_slice, &src_key_slice);
|
||||
|
||||
reverse_bitsliced_radix_inplace<Torus>(streams, mem, s->a_reg, 80);
|
||||
|
||||
CudaRadixCiphertextFFI src_iv_slice;
|
||||
slice_reg_batch<Torus>(&src_iv_slice, iv_bitsliced, 0, 80, N);
|
||||
|
||||
CudaRadixCiphertextFFI dest_b_slice;
|
||||
slice_reg_batch<Torus>(&dest_b_slice, s->b_reg, 0, 80, N);
|
||||
|
||||
copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
|
||||
&dest_b_slice, &src_iv_slice);
|
||||
|
||||
reverse_bitsliced_radix_inplace<Torus>(streams, mem, s->b_reg, 80);
|
||||
|
||||
CudaRadixCiphertextFFI dest_c_ones;
|
||||
slice_reg_batch<Torus>(&dest_c_ones, s->c_reg, 108, 3, N);
|
||||
|
||||
host_add_scalar_one_inplace<Torus>(streams, &dest_c_ones,
|
||||
mem->params.message_modulus,
|
||||
mem->params.carry_modulus);
|
||||
trivium_flush(streams, mem, &dest_c_ones, bsks, ksks);
|
||||
|
||||
for (int i = 0; i < 18; i++) {
|
||||
trivium_compute_64_steps(streams, mem, nullptr, bsks, ksks);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__host__ void host_trivium_generate_keystream(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *keystream_output,
|
||||
CudaRadixCiphertextFFI const *key_bitsliced,
|
||||
CudaRadixCiphertextFFI const *iv_bitsliced, uint32_t num_inputs,
|
||||
uint32_t num_steps, int_trivium_buffer<Torus> *mem, void *const *bsks,
|
||||
uint64_t *const *ksks) {
|
||||
|
||||
trivium_init(streams, mem, key_bitsliced, iv_bitsliced, bsks, ksks);
|
||||
|
||||
uint32_t num_batches = num_steps / 64;
|
||||
for (uint32_t i = 0; i < num_batches; i++) {
|
||||
CudaRadixCiphertextFFI batch_out_slice;
|
||||
slice_reg_batch<Torus>(&batch_out_slice, keystream_output, i * 64, 64,
|
||||
num_inputs);
|
||||
trivium_compute_64_steps(streams, mem, &batch_out_slice, bsks, ksks);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
uint64_t scratch_cuda_trivium_encrypt(CudaStreams streams,
|
||||
int_trivium_buffer<Torus> **mem_ptr,
|
||||
int_radix_params params,
|
||||
bool allocate_gpu_memory,
|
||||
uint32_t num_inputs) {
|
||||
|
||||
uint64_t size_tracker = 0;
|
||||
*mem_ptr = new int_trivium_buffer<Torus>(streams, params, allocate_gpu_memory,
|
||||
num_inputs, size_tracker);
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -2495,6 +2495,42 @@ unsafe extern "C" {
|
||||
mem_ptr_void: *mut *mut i8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_trivium_64(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr: *mut *mut i8,
|
||||
glwe_dimension: u32,
|
||||
polynomial_size: u32,
|
||||
lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
pbs_level: u32,
|
||||
pbs_base_log: u32,
|
||||
grouping_factor: u32,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
pbs_type: PBS_TYPE,
|
||||
allocate_gpu_memory: bool,
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
num_inputs: u32,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_trivium_generate_keystream_64(
|
||||
streams: CudaStreamsFFI,
|
||||
keystream_output: *mut CudaRadixCiphertextFFI,
|
||||
key: *const CudaRadixCiphertextFFI,
|
||||
iv: *const CudaRadixCiphertextFFI,
|
||||
num_inputs: u32,
|
||||
num_steps: u32,
|
||||
mem_ptr: *mut i8,
|
||||
bsks: *const *mut ffi::c_void,
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_trivium_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
|
||||
}
|
||||
pub const KS_TYPE_BIG_TO_SMALL: KS_TYPE = 0;
|
||||
pub const KS_TYPE_SMALL_TO_BIG: KS_TYPE = 1;
|
||||
pub type KS_TYPE = ffi::c_uint;
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "cuda/include/integer/integer.h"
|
||||
#include "cuda/include/integer/rerand.h"
|
||||
#include "cuda/include/aes/aes.h"
|
||||
#include "cuda/include/trivium/trivium.h"
|
||||
#include "cuda/include/zk/zk.h"
|
||||
#include "cuda/include/keyswitch/keyswitch.h"
|
||||
#include "cuda/include/keyswitch/ks_enums.h"
|
||||
|
||||
@@ -133,6 +133,12 @@ path = "benches/integer/aes.rs"
|
||||
harness = false
|
||||
required-features = ["integer", "internal-keycache"]
|
||||
|
||||
[[bench]]
|
||||
name = "integer-trivium"
|
||||
path = "benches/integer/trivium.rs"
|
||||
harness = false
|
||||
required-features = ["integer", "internal-keycache"]
|
||||
|
||||
[[bench]]
|
||||
name = "integer-aes256"
|
||||
path = "benches/integer/aes256.rs"
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
mod aes;
|
||||
mod aes256;
|
||||
mod oprf;
|
||||
mod trivium;
|
||||
mod vector_find;
|
||||
|
||||
mod rerand;
|
||||
|
||||
tfhe-benchmark/benches/integer/trivium.rs (new file, 100 lines)
@@ -0,0 +1,100 @@
|
||||
use criterion::Criterion;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
pub mod cuda {
|
||||
use benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
use benchmark::utilities::{write_to_json, OperatorType};
|
||||
use criterion::{black_box, criterion_group, Criterion};
|
||||
use tfhe::core_crypto::gpu::CudaStreams;
|
||||
use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
|
||||
use tfhe::integer::gpu::CudaServerKey;
|
||||
use tfhe::integer::keycache::KEY_CACHE;
|
||||
use tfhe::integer::{IntegerKeyKind, RadixClientKey};
|
||||
use tfhe::keycache::NamedParam;
|
||||
use tfhe::shortint::AtomicPatternParameters;
|
||||
|
||||
pub fn cuda_trivium(c: &mut Criterion) {
|
||||
let bench_name = "integer::cuda::trivium";
|
||||
|
||||
let mut bench_group = c.benchmark_group(bench_name);
|
||||
bench_group
|
||||
.sample_size(15)
|
||||
.measurement_time(std::time::Duration::from_secs(60))
|
||||
.warm_up_time(std::time::Duration::from_secs(5));
|
||||
|
||||
let param = BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
let atomic_param: AtomicPatternParameters = param.into();
|
||||
|
||||
let key_bits = vec![0u64; 80];
|
||||
let iv_bits = vec![0u64; 80];
|
||||
|
||||
let param_name = param.name();
|
||||
|
||||
let streams = CudaStreams::new_multi_gpu();
|
||||
let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix);
|
||||
let sks = CudaServerKey::new(&cpu_cks, &streams);
|
||||
let cks = RadixClientKey::from((cpu_cks, 1));
|
||||
|
||||
let ct_key = cks.encrypt_bits_for_trivium(&key_bits);
|
||||
let ct_iv = cks.encrypt_bits_for_trivium(&iv_bits);
|
||||
|
||||
let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
|
||||
let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
|
||||
|
||||
{
|
||||
let num_steps = 64;
|
||||
let bench_id = format!("{bench_name}::{param_name}::generate_{num_steps}_bits");
|
||||
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
black_box(sks.trivium_generate_keystream(&d_key, &d_iv, num_steps, &streams));
|
||||
})
|
||||
});
|
||||
|
||||
write_to_json::<u64, _>(
|
||||
&bench_id,
|
||||
atomic_param,
|
||||
param.name(),
|
||||
"trivium_generation_64_bits",
|
||||
&OperatorType::Atomic,
|
||||
80,
|
||||
vec![atomic_param.message_modulus().0.ilog2(); 80],
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let num_steps = 512;
|
||||
let bench_id = format!("{bench_name}::{param_name}::generate_{num_steps}_bits");
|
||||
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
black_box(sks.trivium_generate_keystream(&d_key, &d_iv, num_steps, &streams));
|
||||
})
|
||||
});
|
||||
|
||||
write_to_json::<u64, _>(
|
||||
&bench_id,
|
||||
atomic_param,
|
||||
param.name(),
|
||||
"trivium_generation_512_bits",
|
||||
&OperatorType::Atomic,
|
||||
80,
|
||||
vec![atomic_param.message_modulus().0.ilog2(); 80],
|
||||
);
|
||||
}
|
||||
|
||||
bench_group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(gpu_trivium, cuda_trivium);
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
use cuda::gpu_trivium;
|
||||
|
||||
fn main() {
|
||||
#[cfg(feature = "gpu")]
|
||||
gpu_trivium();
|
||||
|
||||
Criterion::default().configure_from_args().final_summary();
|
||||
}
|
||||
@@ -5,8 +5,6 @@ use std::{env, fs};
|
||||
#[cfg(feature = "gpu")]
|
||||
use tfhe::core_crypto::gpu::{get_number_of_gpus, get_number_of_sms};
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
#[cfg(feature = "integer")]
|
||||
use tfhe::prelude::*;
|
||||
|
||||
#[cfg(feature = "boolean")]
|
||||
pub mod boolean_utils {
|
||||
|
||||
@@ -10423,3 +10423,98 @@ pub unsafe fn unchecked_small_scalar_mul_integer_async(
|
||||
carry_modulus.0 as u32,
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_trivium_generate_keystream<T: UnsignedInteger, B: Numeric>(
|
||||
streams: &CudaStreams,
|
||||
keystream_output: &mut CudaRadixCiphertext,
|
||||
key: &CudaRadixCiphertext,
|
||||
iv: &CudaRadixCiphertext,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
pbs_type: PBSType,
|
||||
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
|
||||
num_steps: u32,
|
||||
) {
|
||||
let mut keystream_degrees = keystream_output
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.degree.0)
|
||||
.collect();
|
||||
let mut keystream_noise_levels = keystream_output
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.noise_level.0)
|
||||
.collect();
|
||||
let mut cuda_ffi_keystream = prepare_cuda_radix_ffi(
|
||||
keystream_output,
|
||||
&mut keystream_degrees,
|
||||
&mut keystream_noise_levels,
|
||||
);
|
||||
|
||||
let mut key_degrees = key.info.blocks.iter().map(|b| b.degree.0).collect();
|
||||
let mut key_noise_levels = key.info.blocks.iter().map(|b| b.noise_level.0).collect();
|
||||
let cuda_ffi_key = prepare_cuda_radix_ffi(key, &mut key_degrees, &mut key_noise_levels);
|
||||
|
||||
let mut iv_degrees = iv.info.blocks.iter().map(|b| b.degree.0).collect();
|
||||
let mut iv_noise_levels = iv.info.blocks.iter().map(|b| b.noise_level.0).collect();
|
||||
let cuda_ffi_iv = prepare_cuda_radix_ffi(iv, &mut iv_degrees, &mut iv_noise_levels);
|
||||
|
||||
let num_inputs = (key.info.blocks.len() / 80) as u32;
|
||||
|
||||
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
|
||||
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
|
||||
scratch_cuda_trivium_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
num_inputs,
|
||||
);
|
||||
|
||||
cuda_trivium_generate_keystream_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_keystream,
|
||||
&raw const cuda_ffi_key,
|
||||
&raw const cuda_ffi_iv,
|
||||
num_inputs,
|
||||
num_steps,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
|
||||
cleanup_cuda_trivium_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
|
||||
update_noise_degree(keystream_output, &cuda_ffi_keystream);
|
||||
}
|
||||
|
||||
@@ -66,6 +66,7 @@ mod tests_noise_distribution;
|
||||
mod tests_signed;
|
||||
#[cfg(test)]
|
||||
mod tests_unsigned;
|
||||
mod trivium;
|
||||
|
||||
impl CudaServerKey {
|
||||
/// Create a trivial ciphertext filled with zeros on the GPU.
|
||||
|
||||
@@ -20,6 +20,7 @@ pub(crate) mod test_scalar_shift;
|
||||
pub(crate) mod test_scalar_sub;
|
||||
pub(crate) mod test_shift;
|
||||
pub(crate) mod test_sub;
|
||||
pub(crate) mod test_trivium;
|
||||
pub(crate) mod test_vector_comparisons;
|
||||
pub(crate) mod test_vector_find;
|
||||
|
||||
@@ -85,6 +86,47 @@ impl<F> GpuFunctionExecutor<F> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>
|
||||
for GpuFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
usize,
|
||||
&CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(
|
||||
&mut self,
|
||||
input: (&'a RadixCiphertext, &'a RadixCiphertext, usize),
|
||||
) -> RadixCiphertext {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1 =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
|
||||
let d_ctxt_2 =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
|
||||
|
||||
let gpu_result = (self.func)(
|
||||
&context.sks,
|
||||
&d_ctxt_1,
|
||||
&d_ctxt_2,
|
||||
input.2,
|
||||
&context.streams,
|
||||
);
|
||||
|
||||
gpu_result.to_radix_ciphertext(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F>
|
||||
FunctionExecutor<
|
||||
(&'a RadixCiphertext, &'a RadixCiphertext, u128, usize, usize),
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
use crate::integer::gpu::server_key::radix::tests_unsigned::{
|
||||
create_gpu_parameterized_test, GpuFunctionExecutor,
|
||||
};
|
||||
use crate::integer::gpu::CudaServerKey;
|
||||
use crate::integer::server_key::radix_parallel::tests_unsigned::test_trivium::{
|
||||
trivium_comparison_test, trivium_test_vector_1_test, trivium_test_vector_2_test,
|
||||
trivium_test_vector_3_test, trivium_test_vector_4_test,
|
||||
};
|
||||
use crate::shortint::parameters::{
|
||||
TestParameters, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
};
|
||||
|
||||
create_gpu_parameterized_test!(integer_trivium_test_vector_1 {
|
||||
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
|
||||
});
|
||||
|
||||
create_gpu_parameterized_test!(integer_trivium_test_vector_2 {
|
||||
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
|
||||
});
|
||||
|
||||
create_gpu_parameterized_test!(integer_trivium_test_vector_3 {
|
||||
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
|
||||
});
|
||||
|
||||
create_gpu_parameterized_test!(integer_trivium_test_vector_4 {
|
||||
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
|
||||
});
|
||||
|
||||
create_gpu_parameterized_test!(integer_trivium_comparison {
|
||||
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
|
||||
});
|
||||
|
||||
fn integer_trivium_test_vector_1<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
|
||||
trivium_test_vector_1_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_trivium_test_vector_2<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
|
||||
trivium_test_vector_2_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_trivium_test_vector_3<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
|
||||
trivium_test_vector_3_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_trivium_test_vector_4<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
|
||||
trivium_test_vector_4_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_trivium_comparison<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::trivium_generate_keystream);
|
||||
trivium_comparison_test(param, executor);
|
||||
}
|
||||
tfhe/src/integer/gpu/server_key/radix/trivium.rs (new file, 125 lines)
@@ -0,0 +1,125 @@
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::{
|
||||
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
|
||||
};
|
||||
use crate::integer::gpu::{cuda_backend_trivium_generate_keystream, LweBskGroupingFactor, PBSType};
|
||||
use crate::integer::{RadixCiphertext, RadixClientKey};
|
||||
use crate::shortint::Ciphertext;
|
||||
|
||||
impl RadixClientKey {
|
||||
pub fn encrypt_bits_for_trivium(&self, bits: &[u64]) -> RadixCiphertext {
|
||||
let mut blocks: Vec<Ciphertext> = Vec::with_capacity(bits.len());
|
||||
for &bit in bits {
|
||||
let mut ct = self.encrypt(bit);
|
||||
let block = ct.blocks.pop().unwrap();
|
||||
blocks.push(block);
|
||||
}
|
||||
RadixCiphertext::from(blocks)
|
||||
}
|
||||
|
||||
pub fn decrypt_bits_from_trivium(&self, encrypted_stream: &RadixCiphertext) -> Vec<u8> {
|
||||
let mut decrypted_bits = Vec::with_capacity(encrypted_stream.blocks.len());
|
||||
for block in &encrypted_stream.blocks {
|
||||
let tmp_radix = RadixCiphertext::from(vec![block.clone()]);
|
||||
let val: u64 = self.decrypt(&tmp_radix);
|
||||
decrypted_bits.push(val as u8);
|
||||
}
|
||||
decrypted_bits
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaServerKey {
|
||||
/// Generates a Trivium keystream homomorphically on the GPU.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `key` - The encrypted secret key.
|
||||
/// * `iv` - The encrypted initialization vector.
|
||||
/// * `num_steps` - The number of keystream bits to generate per input.
|
||||
/// * `streams` - The CUDA streams to use for execution.
|
||||
pub fn trivium_generate_keystream(
|
||||
&self,
|
||||
key: &CudaUnsignedRadixCiphertext,
|
||||
iv: &CudaUnsignedRadixCiphertext,
|
||||
num_steps: usize,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext {
|
||||
let num_key_bits = 80;
|
||||
let num_iv_bits = 80;
|
||||
|
||||
assert_eq!(
|
||||
key.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
num_key_bits,
|
||||
"Input key must contain {} encrypted bits, but contains {}",
|
||||
num_key_bits,
|
||||
key.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
);
|
||||
assert_eq!(
|
||||
iv.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
num_iv_bits,
|
||||
"Input IV must contain {} encrypted bits, but contains {}",
|
||||
num_iv_bits,
|
||||
iv.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
);
|
||||
|
||||
let num_output_bits = num_steps;
|
||||
let mut keystream: CudaUnsignedRadixCiphertext =
|
||||
self.create_trivial_zero_radix(num_output_bits, streams);
|
||||
|
||||
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
|
||||
panic!("Only the standard atomic pattern is supported on GPU")
|
||||
};
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_trivium_generate_keystream(
|
||||
streams,
|
||||
keystream.as_mut(),
|
||||
key.as_ref(),
|
||||
iv.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
num_steps as u32,
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_trivium_generate_keystream(
|
||||
streams,
|
||||
keystream.as_mut(),
|
||||
key.as_ref(),
|
||||
iv.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&computing_ks_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
computing_ks_key.decomposition_level_count(),
|
||||
computing_ks_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
num_steps as u32,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
keystream
|
||||
}
|
||||
}
|
||||
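Putting the pieces of this commit together, here is a minimal host-side sketch of the new GPU API, assembled from the benchmark and GPU test added in this commit. It assumes the integer, internal-keycache and gpu features and reuses the parameter and key-cache names from those files, which may need adjusting in other setups (illustration only):

use tfhe::core_crypto::gpu::CudaStreams;
use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use tfhe::integer::gpu::CudaServerKey;
use tfhe::integer::keycache::KEY_CACHE;
use tfhe::integer::{IntegerKeyKind, RadixClientKey};
use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
use tfhe::shortint::AtomicPatternParameters;

fn main() {
    let streams = CudaStreams::new_multi_gpu();
    let atomic_param: AtomicPatternParameters =
        PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into();
    let (cpu_cks, _) = KEY_CACHE.get_from_params(atomic_param, IntegerKeyKind::Radix);
    let sks = CudaServerKey::new(&cpu_cks, &streams);
    let cks = RadixClientKey::from((cpu_cks, 1));

    // 80-bit key and IV, one encrypted bit per radix block.
    let key_bits = vec![0u64; 80];
    let iv_bits = vec![0u64; 80];
    let ct_key = cks.encrypt_bits_for_trivium(&key_bits);
    let ct_iv = cks.encrypt_bits_for_trivium(&iv_bits);
    let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
    let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);

    // Generate 512 keystream bits on the GPU (num_steps is consumed in batches of 64).
    let d_keystream = sks.trivium_generate_keystream(&d_key, &d_iv, 512, &streams);
    let keystream = d_keystream.to_radix_ciphertext(&streams);
    let bits = cks.decrypt_bits_from_trivium(&keystream);
    assert_eq!(bits.len(), 512);
}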
@@ -28,6 +28,7 @@ pub(crate) mod test_shift;
|
||||
pub(crate) mod test_slice;
|
||||
pub(crate) mod test_sub;
|
||||
pub(crate) mod test_sum;
|
||||
pub(crate) mod test_trivium;
|
||||
pub(crate) mod test_vector_comparisons;
|
||||
pub(crate) mod test_vector_find;
|
||||
|
||||
|
||||
@@ -0,0 +1,260 @@
|
||||
#![cfg(feature = "gpu")]
|
||||
|
||||
use crate::integer::keycache::KEY_CACHE;
|
||||
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
|
||||
use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey};
|
||||
use crate::shortint::parameters::TestParameters;
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
|
||||
struct TriviumRef {
|
||||
a: Vec<u8>,
|
||||
b: Vec<u8>,
|
||||
c: Vec<u8>,
|
||||
}
|
||||
|
||||
impl TriviumRef {
|
||||
fn new(key: &[u8], iv: &[u8]) -> Self {
|
||||
let mut a = vec![0u8; 93];
|
||||
let mut b = vec![0u8; 84];
|
||||
let mut c = vec![0u8; 111];
|
||||
|
||||
for i in 0..80 {
|
||||
a[i] = key[79 - i];
|
||||
b[i] = iv[79 - i];
|
||||
}
|
||||
|
||||
c[108] = 1;
|
||||
c[109] = 1;
|
||||
c[110] = 1;
|
||||
|
||||
let mut triv = Self { a, b, c };
|
||||
for _ in 0..(18 * 64) {
|
||||
triv.next();
|
||||
}
|
||||
triv
|
||||
}
|
||||
|
||||
fn next(&mut self) -> u8 {
|
||||
let t1 = self.a[65] ^ self.a[92];
|
||||
let t2 = self.b[68] ^ self.b[83];
|
||||
let t3 = self.c[65] ^ self.c[110];
|
||||
|
||||
let out = t1 ^ t2 ^ t3;
|
||||
|
||||
let a_in = t3 ^ self.a[68] ^ (self.c[108] & self.c[109]);
|
||||
let b_in = t1 ^ self.b[77] ^ (self.a[90] & self.a[91]);
|
||||
let c_in = t2 ^ self.c[86] ^ (self.b[81] & self.b[82]);
|
||||
|
||||
self.a.pop();
|
||||
self.a.insert(0, a_in);
|
||||
self.b.pop();
|
||||
self.b.insert(0, b_in);
|
||||
self.c.pop();
|
||||
self.c.insert(0, c_in);
|
||||
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
fn get_hexadecimal_string_from_lsb_first_stream(a: &[u8]) -> String {
|
||||
assert!(a.len().is_multiple_of(8));
|
||||
let mut hexadecimal = String::new();
|
||||
for test in a.chunks(8) {
|
||||
let to_hex = |chunk: &[u8]| -> char {
|
||||
let mut val = 0u8;
|
||||
if chunk[0] == 1 {
|
||||
val |= 1;
|
||||
}
|
||||
if chunk[1] == 1 {
|
||||
val |= 2;
|
||||
}
|
||||
if chunk[2] == 1 {
|
||||
val |= 4;
|
||||
}
|
||||
if chunk[3] == 1 {
|
||||
val |= 8;
|
||||
}
|
||||
|
||||
match val {
|
||||
0..=9 => (val + b'0') as char,
|
||||
10..=15 => (val - 10 + b'A') as char,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
};
|
||||
|
||||
hexadecimal.push(to_hex(&test[4..8]));
|
||||
hexadecimal.push(to_hex(&test[0..4]));
|
||||
}
|
||||
hexadecimal
|
||||
}
|
||||
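For example, the keystream byte whose LSB-first bits are [1, 1, 0, 1, 1, 1, 1, 1] has low nibble 1 + 2 + 8 = 0xB and high nibble 0xF, and is printed high-nibble-first as "FB", the first byte of test vector 1 below. A standalone check of that conversion (illustration only):

fn to_hex_digit(chunk: &[u8]) -> char {
    // chunk holds 4 bits, least-significant first.
    let val = chunk[0] | (chunk[1] << 1) | (chunk[2] << 2) | (chunk[3] << 3);
    char::from_digit(val as u32, 16).unwrap().to_ascii_uppercase()
}

fn main() {
    let byte_lsb_first = [1u8, 1, 0, 1, 1, 1, 1, 1];
    let high = to_hex_digit(&byte_lsb_first[4..8]);
    let low = to_hex_digit(&byte_lsb_first[0..4]);
    assert_eq!(format!("{high}{low}"), "FB");
}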
|
||||
pub fn trivium_test_vector_1_test<P, E>(param: P, mut executor: E)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>,
|
||||
{
|
||||
let param = param.into();
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, 1));
|
||||
let sks = Arc::new(sks);
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
let key = vec![0u64; 80];
|
||||
let iv = vec![0u64; 80];
|
||||
|
||||
let expected_output_0_63 = "FBE0BF265859051B517A2E4E239FC97F563203161907CF2DE7A8790FA1B2E9CDF75292030268B7382B4C1A759AA2599A285549986E74805903801A4CB5A5D4F2";
|
||||
|
||||
let cpu_key = cks.encrypt_bits_for_trivium(&key);
|
||||
let cpu_iv = cks.encrypt_bits_for_trivium(&iv);
|
||||
|
||||
let num_steps = 512;
|
||||
let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps));
|
||||
|
||||
let decrypted_bits = cks.decrypt_bits_from_trivium(&output_radix);
|
||||
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
|
||||
|
||||
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
|
||||
}
|
||||
|
||||
pub fn trivium_test_vector_2_test<P, E>(param: P, mut executor: E)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>,
|
||||
{
|
||||
let param = param.into();
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, 1));
|
||||
let sks = Arc::new(sks);
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
let mut key = vec![0u64; 80];
|
||||
let iv = vec![0u64; 80];
|
||||
key[7] = 1;
|
||||
|
||||
let expected_output_0_63 = "38EB86FF730D7A9CAF8DF13A4420540DBB7B651464C87501552041C249F29A64D2FBF515610921EBE06C8F92CECF7F8098FF20CCCC6A62B97BE8EF7454FC80F9";
|
||||
|
||||
let cpu_key = cks.encrypt_bits_for_trivium(&key);
|
||||
let cpu_iv = cks.encrypt_bits_for_trivium(&iv);
|
||||
|
||||
let num_steps = 512;
|
||||
let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps));
|
||||
|
||||
let decrypted_bits = cks.decrypt_bits_from_trivium(&output_radix);
|
||||
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
|
||||
|
||||
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
|
||||
}
|
||||
|
||||
pub fn trivium_test_vector_3_test<P, E>(param: P, mut executor: E)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>,
|
||||
{
|
||||
let param = param.into();
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, 1));
|
||||
let sks = Arc::new(sks);
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
let key = vec![0u64; 80];
|
||||
let mut iv = vec![0u64; 80];
|
||||
iv[7] = 1;
|
||||
|
||||
let expected_output_0_63 = "F8901736640549E3BA7D42EA2D07B9F49233C18D773008BD755585B1A8CBAB86C1E9A9B91F1AD33483FD6EE3696D659C9374260456A36AAE11F033A519CBD5D7";
|
||||
|
||||
let cpu_key = cks.encrypt_bits_for_trivium(&key);
|
||||
let cpu_iv = cks.encrypt_bits_for_trivium(&iv);
|
||||
|
||||
let num_steps = 512;
|
||||
let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps));
|
||||
|
||||
let decrypted_bits = cks.decrypt_bits_from_trivium(&output_radix);
|
||||
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
|
||||
|
||||
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
|
||||
}
|
||||
|
||||
pub fn trivium_test_vector_4_test<P, E>(param: P, mut executor: E)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>,
|
||||
{
|
||||
let param = param.into();
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, 1));
|
||||
let sks = Arc::new(sks);
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
let key_string = "0053A6F94C9FF24598EB";
|
||||
let mut key = vec![0u64; 80];
|
||||
for i in (0..key_string.len()).step_by(2) {
|
||||
let mut val = u8::from_str_radix(&key_string[i..i + 2], 16).unwrap();
|
||||
for j in 0..8 {
|
||||
key[8 * (i >> 1) + j] = (val % 2) as u64;
|
||||
val >>= 1;
|
||||
}
|
||||
}
|
||||
|
||||
let iv_string = "0D74DB42A91077DE45AC";
|
||||
let mut iv = vec![0u64; 80];
|
||||
for i in (0..iv_string.len()).step_by(2) {
|
||||
let mut val = u8::from_str_radix(&iv_string[i..i + 2], 16).unwrap();
|
||||
for j in 0..8 {
|
||||
iv[8 * (i >> 1) + j] = (val % 2) as u64;
|
||||
val >>= 1;
|
||||
}
|
||||
}
|
||||
|
||||
let expected_output_0_63 = "F4CD954A717F26A7D6930830C4E7CF0819F80E03F25F342C64ADC66ABA7F8A8E6EAA49F23632AE3CD41A7BD290A0132F81C6D4043B6E397D7388F3A03B5FE358";
|
||||
|
||||
let cpu_key = cks.encrypt_bits_for_trivium(&key);
|
||||
let cpu_iv = cks.encrypt_bits_for_trivium(&iv);
|
||||
|
||||
let num_steps = 512;
|
||||
let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps));
|
||||
|
||||
let decrypted_bits = cks.decrypt_bits_from_trivium(&output_radix);
|
||||
let hex_string = get_hexadecimal_string_from_lsb_first_stream(&decrypted_bits);
|
||||
|
||||
assert_eq!(expected_output_0_63, &hex_string[0..64 * 2]);
|
||||
}
|
||||
|
||||
pub fn trivium_comparison_test<P, E>(param: P, mut executor: E)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
E: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext, usize), RadixCiphertext>,
|
||||
{
|
||||
let param = param.into();
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, 1));
|
||||
let sks = Arc::new(sks);
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
let num_runs = 5;
|
||||
let num_steps = 512;
|
||||
|
||||
for i in 0..num_runs {
|
||||
let mut rng = rand::thread_rng();
|
||||
let plain_key: Vec<u8> = (0..80).map(|_| rng.gen_range(0..=1)).collect();
|
||||
let plain_iv: Vec<u8> = (0..80).map(|_| rng.gen_range(0..=1)).collect();
|
||||
|
||||
let key_bits_u64: Vec<u64> = plain_key.iter().map(|&x| x as u64).collect();
|
||||
let iv_bits_u64: Vec<u64> = plain_iv.iter().map(|&x| x as u64).collect();
|
||||
|
||||
let cpu_key = cks.encrypt_bits_for_trivium(&key_bits_u64);
|
||||
let cpu_iv = cks.encrypt_bits_for_trivium(&iv_bits_u64);
|
||||
|
||||
let mut cpu_trivium = TriviumRef::new(&plain_key, &plain_iv);
|
||||
let mut cpu_output = Vec::with_capacity(num_steps);
|
||||
for _ in 0..num_steps {
|
||||
cpu_output.push(cpu_trivium.next());
|
||||
}
|
||||
|
||||
let output_radix = executor.execute((&cpu_key, &cpu_iv, num_steps));
|
||||
let fhe_output = cks.decrypt_bits_from_trivium(&output_radix);
|
||||
|
||||
assert_eq!(cpu_output.len(), fhe_output.len());
|
||||
assert_eq!(cpu_output, fhe_output, "Mismatch at iteration {i}");
|
||||
}
|
||||
}
|
||||