Mirror of https://github.com/zama-ai/tfhe-rs.git, synced 2026-01-08 22:28:01 -05:00
feat(gpu): kreyvium
@@ -87,6 +87,7 @@ fn main() {
        "cuda/include/integer/rerand.h",
        "cuda/include/aes/aes.h",
        "cuda/include/trivium/trivium.h",
        "cuda/include/kreyvium/kreyvium.h",
        "cuda/include/zk/zk.h",
        "cuda/include/keyswitch/keyswitch.h",
        "cuda/include/keyswitch/ks_enums.h",
backends/tfhe-cuda-backend/cuda/include/kreyvium/kreyvium.h (new file, 24 lines)
@@ -0,0 +1,24 @@
#ifndef KREYVIUM_H
#define KREYVIUM_H

#include "../integer/integer.h"

extern "C" {
uint64_t scratch_cuda_kreyvium_64(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
    uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
    uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
    PBS_TYPE pbs_type, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs);

void cuda_kreyvium_generate_keystream_64(
    CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output,
    const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv,
    uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks,
    void *const *ksks);

void cleanup_cuda_kreyvium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void);
}

#endif
backends/tfhe-cuda-backend/cuda/include/kreyvium/kreyvium_utilities.h (new file, 256 lines)
@@ -0,0 +1,256 @@
#ifndef KREYVIUM_UTILITIES_H
#define KREYVIUM_UTILITIES_H
#include "../integer/integer_utilities.h"

template <typename Torus> struct int_kreyvium_lut_buffers {
  int_radix_lut<Torus> *and_lut;
  int_radix_lut<Torus> *flush_lut;

  int_kreyvium_lut_buffers(CudaStreams streams, const int_radix_params &params,
                           bool allocate_gpu_memory, uint32_t num_inputs,
                           uint64_t &size_tracker) {

    constexpr uint32_t BATCH_SIZE = 64;
    constexpr uint32_t MAX_AND_PER_STEP = 3;
    uint32_t total_lut_ops = num_inputs * BATCH_SIZE * MAX_AND_PER_STEP;

    this->and_lut = new int_radix_lut<Torus>(streams, params, 1, total_lut_ops,
                                             allocate_gpu_memory, size_tracker);

    std::function<Torus(Torus, Torus)> and_lambda =
        [](Torus a, Torus b) -> Torus { return a & b; };

    generate_device_accumulator_bivariate<Torus>(
        streams.stream(0), streams.gpu_index(0), this->and_lut->get_lut(0, 0),
        this->and_lut->get_degree(0), this->and_lut->get_max_degree(0),
        params.glwe_dimension, params.polynomial_size, params.message_modulus,
        params.carry_modulus, and_lambda, allocate_gpu_memory);

    auto active_streams_and =
        streams.active_gpu_subset(total_lut_ops, params.pbs_type);
    this->and_lut->broadcast_lut(active_streams_and);
    this->and_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);

    uint32_t total_flush_ops = num_inputs * BATCH_SIZE * 4;

    this->flush_lut = new int_radix_lut<Torus>(
        streams, params, 1, total_flush_ops, allocate_gpu_memory, size_tracker);

    std::function<Torus(Torus)> flush_lambda = [](Torus x) -> Torus {
      return x & 1;
    };

    generate_device_accumulator(
        streams.stream(0), streams.gpu_index(0), this->flush_lut->get_lut(0, 0),
        this->flush_lut->get_degree(0), this->flush_lut->get_max_degree(0),
        params.glwe_dimension, params.polynomial_size, params.message_modulus,
        params.carry_modulus, flush_lambda, allocate_gpu_memory);

    auto active_streams_flush =
        streams.active_gpu_subset(total_flush_ops, params.pbs_type);
    this->flush_lut->broadcast_lut(active_streams_flush);
    this->flush_lut->setup_gemm_batch_ks_temp_buffers(size_tracker);
  }

  void release(CudaStreams streams) {
    this->and_lut->release(streams);
    delete this->and_lut;
    this->and_lut = nullptr;

    this->flush_lut->release(streams);
    delete this->flush_lut;
    this->flush_lut = nullptr;

    cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
  }
};
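Reviewer note: the two lookup tables encode very simple cleartext functions, a bivariate AND (the bivariate apply packs the two operand bits into one block before the PBS) and a univariate "flush" that keeps only the least significant bit of a block. A minimal cleartext Rust sketch of what they compute, illustrative only and not part of this diff:

// Cleartext models of the two LUTs programmed above (illustrative only).
// `and_lut` receives two single-bit operands and returns their AND.
// `flush_lut` is applied after XORs have been emulated by plain additions:
// a block may then hold a value larger than 1, and `x & 1` reduces it back
// to a clean bit (the parity of the accumulated sum).
fn and_lut(a: u64, b: u64) -> u64 {
    a & b
}

fn flush_lut(x: u64) -> u64 {
    x & 1
}

fn main() {
    // Three XORed bits accumulate to 3 under plain addition; flushing
    // recovers the XOR (parity) value 1.
    let acc: u64 = 1 + 1 + 1;
    assert_eq!(flush_lut(acc), 1);
    assert_eq!(and_lut(1, 0), 0);
    println!("LUT models behave as expected");
}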
|
||||
|
||||
template <typename Torus> struct int_kreyvium_state_workspaces {
|
||||
CudaRadixCiphertextFFI *a_reg;
|
||||
CudaRadixCiphertextFFI *b_reg;
|
||||
CudaRadixCiphertextFFI *c_reg;
|
||||
|
||||
CudaRadixCiphertextFFI *k_reg;
|
||||
CudaRadixCiphertextFFI *iv_reg;
|
||||
|
||||
CudaRadixCiphertextFFI *shift_workspace;
|
||||
|
||||
CudaRadixCiphertextFFI *temp_t1;
|
||||
CudaRadixCiphertextFFI *temp_t2;
|
||||
CudaRadixCiphertextFFI *temp_t3;
|
||||
CudaRadixCiphertextFFI *new_a;
|
||||
CudaRadixCiphertextFFI *new_b;
|
||||
CudaRadixCiphertextFFI *new_c;
|
||||
|
||||
CudaRadixCiphertextFFI *packed_pbs_lhs;
|
||||
CudaRadixCiphertextFFI *packed_pbs_rhs;
|
||||
CudaRadixCiphertextFFI *packed_pbs_out;
|
||||
|
||||
CudaRadixCiphertextFFI *packed_flush_in;
|
||||
CudaRadixCiphertextFFI *packed_flush_out;
|
||||
|
||||
int_kreyvium_state_workspaces(CudaStreams streams,
|
||||
const int_radix_params ¶ms,
|
||||
bool allocate_gpu_memory, uint32_t num_inputs,
|
||||
uint64_t &size_tracker) {
|
||||
|
||||
this->a_reg = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->a_reg, 93 * num_inputs,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->b_reg = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->b_reg, 84 * num_inputs,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->c_reg = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->c_reg, 111 * num_inputs,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->k_reg = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->k_reg, 128 * num_inputs,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->iv_reg = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->iv_reg, 128 * num_inputs,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->shift_workspace = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->shift_workspace,
|
||||
128 * num_inputs, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
uint32_t batch_blocks = 64 * num_inputs;
|
||||
|
||||
this->temp_t1 = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->temp_t1, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->temp_t2 = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->temp_t2, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->temp_t3 = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->temp_t3, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->new_a = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->new_a, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->new_b = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->new_b, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->new_c = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->new_c, batch_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->packed_pbs_lhs = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_pbs_lhs,
|
||||
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->packed_pbs_rhs = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_pbs_rhs,
|
||||
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->packed_pbs_out = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_pbs_out,
|
||||
3 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->packed_flush_in = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_flush_in,
|
||||
4 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
this->packed_flush_out = new CudaRadixCiphertextFFI;
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), this->packed_flush_out,
|
||||
4 * batch_blocks, params.big_lwe_dimension, size_tracker,
|
||||
allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams, bool allocate_gpu_memory) {
|
||||
auto release_and_delete = [&](CudaRadixCiphertextFFI *&ptr) {
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
ptr, allocate_gpu_memory);
|
||||
delete ptr;
|
||||
ptr = nullptr;
|
||||
};
|
||||
|
||||
release_and_delete(this->a_reg);
|
||||
release_and_delete(this->b_reg);
|
||||
release_and_delete(this->c_reg);
|
||||
release_and_delete(this->k_reg);
|
||||
release_and_delete(this->iv_reg);
|
||||
release_and_delete(this->shift_workspace);
|
||||
release_and_delete(this->temp_t1);
|
||||
release_and_delete(this->temp_t2);
|
||||
release_and_delete(this->temp_t3);
|
||||
release_and_delete(this->new_a);
|
||||
release_and_delete(this->new_b);
|
||||
release_and_delete(this->new_c);
|
||||
release_and_delete(this->packed_pbs_lhs);
|
||||
release_and_delete(this->packed_pbs_rhs);
|
||||
release_and_delete(this->packed_pbs_out);
|
||||
release_and_delete(this->packed_flush_in);
|
||||
release_and_delete(this->packed_flush_out);
|
||||
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
}
|
||||
};
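Reviewer note: the register sizes mirror the Kreyvium/Trivium state, three shift registers of 93, 84 and 111 bits plus the 128-bit key and IV registers, each stored bitsliced so that num_inputs independent instances advance together. A minimal Rust sketch of the assumed block layout (one encrypted bit per radix block, instances interleaved per bit position; helper names are illustrative):

// Bitsliced layout used by the workspaces above: the block holding bit
// `bit_idx` of instance `instance` lives at `bit_idx * num_inputs + instance`.
fn block_index(bit_idx: usize, instance: usize, num_inputs: usize) -> usize {
    bit_idx * num_inputs + instance
}

fn main() {
    let num_inputs = 4; // four Kreyvium instances processed in parallel
    // The `a` register holds 93 bits per instance, i.e. 93 * num_inputs blocks.
    let a_reg_blocks = 93 * num_inputs;
    assert_eq!(a_reg_blocks, 372);
    // Bit 27 of instance 2 (one of the taps sliced later in kreyvium.cuh):
    assert_eq!(block_index(27, 2, num_inputs), 110);
}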
template <typename Torus> struct int_kreyvium_buffer {
  int_radix_params params;
  bool allocate_gpu_memory;
  uint32_t num_inputs;

  int_kreyvium_lut_buffers<Torus> *luts;
  int_kreyvium_state_workspaces<Torus> *state;

  int_kreyvium_buffer(CudaStreams streams, const int_radix_params &params,
                      bool allocate_gpu_memory, uint32_t num_inputs,
                      uint64_t &size_tracker) {
    this->params = params;
    this->allocate_gpu_memory = allocate_gpu_memory;
    this->num_inputs = num_inputs;

    this->luts = new int_kreyvium_lut_buffers<Torus>(
        streams, params, allocate_gpu_memory, num_inputs, size_tracker);

    this->state = new int_kreyvium_state_workspaces<Torus>(
        streams, params, allocate_gpu_memory, num_inputs, size_tracker);
  }

  void release(CudaStreams streams) {
    luts->release(streams);
    delete luts;
    luts = nullptr;

    state->release(streams, allocate_gpu_memory);
    delete state;
    state = nullptr;

    cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
  }
};

#endif
backends/tfhe-cuda-backend/cuda/src/kreyvium/kreyvium.cu (new file, 45 lines)
@@ -0,0 +1,45 @@
#include "../../include/kreyvium/kreyvium.h"
#include "kreyvium.cuh"

uint64_t scratch_cuda_kreyvium_64(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
    uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
    uint32_t grouping_factor, uint32_t message_modulus, uint32_t carry_modulus,
    PBS_TYPE pbs_type, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type, uint32_t num_inputs) {

  int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
                          glwe_dimension * polynomial_size, lwe_dimension,
                          ks_level, ks_base_log, pbs_level, pbs_base_log,
                          grouping_factor, message_modulus, carry_modulus,
                          noise_reduction_type);

  return scratch_cuda_kreyvium_encrypt<uint64_t>(
      CudaStreams(streams), (int_kreyvium_buffer<uint64_t> **)mem_ptr, params,
      allocate_gpu_memory, num_inputs);
}

void cuda_kreyvium_generate_keystream_64(
    CudaStreamsFFI streams, CudaRadixCiphertextFFI *keystream_output,
    const CudaRadixCiphertextFFI *key, const CudaRadixCiphertextFFI *iv,
    uint32_t num_inputs, uint32_t num_steps, int8_t *mem_ptr, void *const *bsks,
    void *const *ksks) {

  auto buffer = (int_kreyvium_buffer<uint64_t> *)mem_ptr;

  host_kreyvium_generate_keystream<uint64_t>(
      CudaStreams(streams), keystream_output, key, iv, num_inputs, num_steps,
      buffer, bsks, (uint64_t *const *)ksks);
}

void cleanup_cuda_kreyvium_64(CudaStreamsFFI streams, int8_t **mem_ptr_void) {

  int_kreyvium_buffer<uint64_t> *mem_ptr =
      (int_kreyvium_buffer<uint64_t> *)(*mem_ptr_void);

  mem_ptr->release(CudaStreams(streams));

  delete mem_ptr;
  *mem_ptr_void = nullptr;
}
backends/tfhe-cuda-backend/cuda/src/kreyvium/kreyvium.cuh (new file, 376 lines)
@@ -0,0 +1,376 @@
#ifndef KREYVIUM_CUH
#define KREYVIUM_CUH

#include "../../include/kreyvium/kreyvium_utilities.h"
#include "../integer/bitwise_ops.cuh"
#include "../integer/integer.cuh"
#include "../integer/radix_ciphertext.cuh"
#include "../integer/scalar_addition.cuh"
#include "../linearalgebra/addition.cuh"

template <typename Torus>
void reverse_bitsliced_radix_inplace_kreyvium(CudaStreams streams,
                                              int_kreyvium_buffer<Torus> *mem,
                                              CudaRadixCiphertextFFI *radix,
                                              uint32_t num_bits_in_reg) {
  uint32_t N = mem->num_inputs;
  CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;

  for (uint32_t i = 0; i < num_bits_in_reg; i++) {
    uint32_t src_start = i * N;
    uint32_t src_end = (i + 1) * N;

    uint32_t dest_start = (num_bits_in_reg - 1 - i) * N;
    uint32_t dest_end = (num_bits_in_reg - i) * N;

    copy_radix_ciphertext_slice_async<Torus>(
        streams.stream(0), streams.gpu_index(0), temp, dest_start, dest_end,
        radix, src_start, src_end);
  }

  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), radix, 0, num_bits_in_reg * N,
      temp, 0, num_bits_in_reg * N);
}
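Reviewer note: in cleartext terms this reverses the bit order of a bitsliced register while keeping the per-instance interleaving intact. A small Rust sketch under the layout assumed above (illustrative names, not part of the diff):

// Reverse the bit positions of a bitsliced register: the group of `n` blocks
// holding bit i moves to the slot of bit (num_bits - 1 - i), mirroring the
// slice copies performed on the GPU above.
fn reverse_bitsliced(reg: &mut [u64], num_bits: usize, n: usize) {
    let mut temp = vec![0u64; num_bits * n];
    for i in 0..num_bits {
        let dest = (num_bits - 1 - i) * n;
        temp[dest..dest + n].copy_from_slice(&reg[i * n..(i + 1) * n]);
    }
    reg.copy_from_slice(&temp);
}

fn main() {
    // Two instances (n = 2), three bit positions: [b0, b0', b1, b1', b2, b2'].
    let mut reg = vec![10, 20, 11, 21, 12, 22];
    reverse_bitsliced(&mut reg, 3, 2);
    assert_eq!(reg, vec![12, 22, 11, 21, 10, 20]);
}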
template <typename Torus>
__host__ __forceinline__ void
kreyvium_xor(CudaStreams streams, int_kreyvium_buffer<Torus> *mem,
             CudaRadixCiphertextFFI *out, const CudaRadixCiphertextFFI *lhs,
             const CudaRadixCiphertextFFI *rhs) {
  host_addition<Torus>(streams.stream(0), streams.gpu_index(0), out, lhs, rhs,
                       out->num_radix_blocks, mem->params.message_modulus,
                       mem->params.carry_modulus);
}
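Reviewer note: since each block carries a single bit, XOR is emulated by a plain homomorphic addition, and the parity is recovered later by the flush LUT (x & 1) before the running sum can exceed the plaintext space. A tiny cleartext sketch of that invariant (illustrative, assuming the 2-bit message / 2-bit carry parameters used by the test below):

// XOR on bits via addition followed by a parity "flush".
fn xor_via_addition(bits: &[u64]) -> u64 {
    let sum: u64 = bits.iter().sum(); // what repeated host_addition accumulates
    sum & 1 // what the flush LUT extracts
}

fn main() {
    assert_eq!(xor_via_addition(&[1, 0, 1]), 0);
    assert_eq!(xor_via_addition(&[1, 1, 1]), 1);
    // With message modulus 4 and carry modulus 4, a few bit additions stay
    // well below the plaintext limit before a flush is applied.
    assert!([1u64, 1, 1].iter().sum::<u64>() < 4 * 4);
}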
template <typename Torus>
__host__ void slice_reg_batch_kreyvium(CudaRadixCiphertextFFI *slice,
                                       const CudaRadixCiphertextFFI *reg,
                                       uint32_t start_bit_idx,
                                       uint32_t num_bits, uint32_t num_inputs) {
  as_radix_ciphertext_slice<Torus>(slice, reg, start_bit_idx * num_inputs,
                                   (start_bit_idx + num_bits) * num_inputs);
}

template <typename Torus>
__host__ void rotate_128_batch(CudaStreams streams,
                               int_kreyvium_buffer<Torus> *mem,
                               CudaRadixCiphertextFFI *reg) {
  uint32_t N = mem->num_inputs;
  uint32_t HALF_SIZE = 64 * N;
  CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;

  copy_radix_ciphertext_slice_async<Torus>(streams.stream(0),
                                           streams.gpu_index(0), temp, 0,
                                           HALF_SIZE, reg, 0, HALF_SIZE);

  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), reg, 0, HALF_SIZE, reg,
      HALF_SIZE, 2 * HALF_SIZE);

  copy_radix_ciphertext_slice_async<Torus>(streams.stream(0),
                                           streams.gpu_index(0), reg, HALF_SIZE,
                                           2 * HALF_SIZE, temp, 0, HALF_SIZE);
}
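Reviewer note: because 64 steps are batched at a time, rotating the 128-bit key and IV registers by 64 positions reduces to swapping their two halves, which the three slice copies above implement via the shift workspace. A cleartext sketch (illustrative):

// Rotating a 128-entry register by 64 positions is a swap of its halves.
fn rotate_by_64(reg: &mut [u64]) {
    assert_eq!(reg.len() % 2, 0);
    let half = reg.len() / 2;
    let temp: Vec<u64> = reg[..half].to_vec(); // save the first half
    reg.copy_within(half.., 0);                // move the second half down
    reg[half..].copy_from_slice(&temp);        // restore the saved half on top
}

fn main() {
    let mut reg: Vec<u64> = (0..128).collect();
    rotate_by_64(&mut reg);
    assert_eq!(reg[0], 64);
    assert_eq!(reg[127], 63);
}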
template <typename Torus>
__host__ void shift_and_insert_batch_kreyvium(CudaStreams streams,
                                              int_kreyvium_buffer<Torus> *mem,
                                              CudaRadixCiphertextFFI *reg,
                                              CudaRadixCiphertextFFI *new_bits,
                                              uint32_t reg_size,
                                              uint32_t num_inputs) {
  constexpr uint32_t BATCH = 64;
  CudaRadixCiphertextFFI *temp = mem->state->shift_workspace;
  uint32_t num_blocks_to_keep = (reg_size - BATCH) * num_inputs;

  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), temp, 0, num_blocks_to_keep, reg,
      BATCH * num_inputs, reg_size * num_inputs);

  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), temp, num_blocks_to_keep,
      reg_size * num_inputs, new_bits, 0, BATCH * num_inputs);

  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), reg, 0, reg_size * num_inputs,
      temp, 0, reg_size * num_inputs);
}
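Reviewer note: in cleartext terms, each batch of 64 steps discards the 64 oldest bit positions at the front of a register and appends the 64 freshly computed bits at the back. A Rust sketch of the same bookkeeping on a bitsliced register (illustrative):

// Drop the first `batch` bit positions of a bitsliced register and append
// `batch` new bit positions, exactly as the three slice copies above do.
fn shift_and_insert(reg: &mut Vec<u64>, new_bits: &[u64], batch: usize, n: usize) {
    assert_eq!(new_bits.len(), batch * n);
    reg.drain(..batch * n);           // forget the oldest 64 bit positions
    reg.extend_from_slice(new_bits);  // the new bits become the youngest ones
}

fn main() {
    let n = 1;
    let mut reg: Vec<u64> = (0..93).collect(); // a 93-bit register, one instance
    let new_bits: Vec<u64> = (100..164).collect(); // 64 new bits
    shift_and_insert(&mut reg, &new_bits, 64, n);
    assert_eq!(reg.len(), 93);
    assert_eq!(reg[0], 64);   // old bit position 64 is now the oldest
    assert_eq!(reg[92], 163); // the last inserted bit is the youngest
}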
template <typename Torus>
__host__ void
kreyvium_compute_64_steps(CudaStreams streams, int_kreyvium_buffer<Torus> *mem,
                          CudaRadixCiphertextFFI *output_dest,
                          void *const *bsks, uint64_t *const *ksks) {

  uint32_t N = mem->num_inputs;
  constexpr uint32_t BATCH = 64;
  uint32_t batch_size_blocks = BATCH * N;
  auto s = mem->state;

  CudaRadixCiphertextFFI a65_slice, a92_slice, a91_slice, a90_slice, a68_slice;
  slice_reg_batch_kreyvium<Torus>(&a65_slice, s->a_reg, 27, BATCH, N);
  slice_reg_batch_kreyvium<Torus>(&a92_slice, s->a_reg, 0, BATCH, N);
  slice_reg_batch_kreyvium<Torus>(&a91_slice, s->a_reg, 1, BATCH, N);
  slice_reg_batch_kreyvium<Torus>(&a90_slice, s->a_reg, 2, BATCH, N);
  slice_reg_batch_kreyvium<Torus>(&a68_slice, s->a_reg, 24, BATCH, N);

  CudaRadixCiphertextFFI b68_slice, b83_slice, b82_slice, b81_slice, b77_slice;
  slice_reg_batch_kreyvium<Torus>(&b68_slice, s->b_reg, 15, BATCH, N);
  slice_reg_batch_kreyvium<Torus>(&b83_slice, s->b_reg, 0, BATCH, N);
  slice_reg_batch_kreyvium<Torus>(&b82_slice, s->b_reg, 1, BATCH, N);
  slice_reg_batch_kreyvium<Torus>(&b81_slice, s->b_reg, 2, BATCH, N);
  slice_reg_batch_kreyvium<Torus>(&b77_slice, s->b_reg, 6, BATCH, N);

  CudaRadixCiphertextFFI c65_slice, c110_slice, c109_slice, c108_slice,
      c86_slice;
  slice_reg_batch_kreyvium<Torus>(&c65_slice, s->c_reg, 45, BATCH, N);
  slice_reg_batch_kreyvium<Torus>(&c110_slice, s->c_reg, 0, BATCH, N);
  slice_reg_batch_kreyvium<Torus>(&c109_slice, s->c_reg, 1, BATCH, N);
  slice_reg_batch_kreyvium<Torus>(&c108_slice, s->c_reg, 2, BATCH, N);
  slice_reg_batch_kreyvium<Torus>(&c86_slice, s->c_reg, 24, BATCH, N);

  CudaRadixCiphertextFFI k_slice, iv_slice;
  slice_reg_batch_kreyvium<Torus>(&k_slice, s->k_reg, 0, BATCH, N);
  slice_reg_batch_kreyvium<Torus>(&iv_slice, s->iv_reg, 0, BATCH, N);

  kreyvium_xor(streams, mem, s->temp_t1, &a65_slice, &a92_slice);
  kreyvium_xor(streams, mem, s->temp_t2, &b68_slice, &b83_slice);
  kreyvium_xor(streams, mem, s->temp_t3, &c65_slice, &c110_slice);
  kreyvium_xor(streams, mem, s->temp_t3, s->temp_t3, &k_slice);

  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs, 0,
      batch_size_blocks, &c109_slice, 0, batch_size_blocks);
  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs,
      batch_size_blocks, 2 * batch_size_blocks, &a91_slice, 0,
      batch_size_blocks);
  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), s->packed_pbs_lhs,
      2 * batch_size_blocks, 3 * batch_size_blocks, &b82_slice, 0,
      batch_size_blocks);

  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs, 0,
      batch_size_blocks, &c108_slice, 0, batch_size_blocks);
  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs,
      batch_size_blocks, 2 * batch_size_blocks, &a90_slice, 0,
      batch_size_blocks);
  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), s->packed_pbs_rhs,
      2 * batch_size_blocks, 3 * batch_size_blocks, &b81_slice, 0,
      batch_size_blocks);

  integer_radix_apply_bivariate_lookup_table<Torus>(
      streams, s->packed_pbs_out, s->packed_pbs_lhs, s->packed_pbs_rhs, bsks,
      ksks, mem->luts->and_lut, 3 * batch_size_blocks,
      mem->params.message_modulus);

  CudaRadixCiphertextFFI and_res_a, and_res_b, and_res_c;
  as_radix_ciphertext_slice<Torus>(&and_res_a, s->packed_pbs_out, 0,
                                   batch_size_blocks);
  as_radix_ciphertext_slice<Torus>(&and_res_b, s->packed_pbs_out,
                                   batch_size_blocks, 2 * batch_size_blocks);
  as_radix_ciphertext_slice<Torus>(&and_res_c, s->packed_pbs_out,
                                   2 * batch_size_blocks,
                                   3 * batch_size_blocks);

  kreyvium_xor(streams, mem, s->new_a, &and_res_a, &a68_slice);
  kreyvium_xor(streams, mem, s->new_b, &and_res_b, &b77_slice);
  kreyvium_xor(streams, mem, s->new_b, s->new_b, s->temp_t1);
  kreyvium_xor(streams, mem, s->new_c, &and_res_c, &c86_slice);
  kreyvium_xor(streams, mem, s->new_c, s->new_c, s->temp_t2);

  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), s->packed_flush_in, 0,
      batch_size_blocks, s->new_a, 0, batch_size_blocks);
  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
      batch_size_blocks, 2 * batch_size_blocks, s->new_b, 0, batch_size_blocks);
  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
      2 * batch_size_blocks, 3 * batch_size_blocks, s->temp_t3, 0,
      batch_size_blocks);

  integer_radix_apply_univariate_lookup_table<Torus>(
      streams, s->packed_flush_out, s->packed_flush_in, bsks, ksks,
      mem->luts->flush_lut, 3 * batch_size_blocks);

  CudaRadixCiphertextFFI flushed_new_a_pre, flushed_new_b_pre, flushed_temp_t3;
  as_radix_ciphertext_slice<Torus>(&flushed_new_a_pre, s->packed_flush_out, 0,
                                   batch_size_blocks);
  as_radix_ciphertext_slice<Torus>(&flushed_new_b_pre, s->packed_flush_out,
                                   batch_size_blocks, 2 * batch_size_blocks);
  as_radix_ciphertext_slice<Torus>(&flushed_temp_t3, s->packed_flush_out,
                                   2 * batch_size_blocks,
                                   3 * batch_size_blocks);

  kreyvium_xor(streams, mem, s->new_a, &flushed_new_a_pre, &flushed_temp_t3);
  kreyvium_xor(streams, mem, s->new_b, &flushed_new_b_pre, &iv_slice);

  if (output_dest != nullptr) {
    host_addition<Torus>(streams.stream(0), streams.gpu_index(0), output_dest,
                         s->temp_t1, s->temp_t2, output_dest->num_radix_blocks,
                         mem->params.message_modulus,
                         mem->params.carry_modulus);
    host_addition<Torus>(
        streams.stream(0), streams.gpu_index(0), output_dest, output_dest,
        &flushed_temp_t3, output_dest->num_radix_blocks,
        mem->params.message_modulus, mem->params.carry_modulus);
  }

  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), s->packed_flush_in, 0,
      batch_size_blocks, s->new_a, 0, batch_size_blocks);
  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
      batch_size_blocks, 2 * batch_size_blocks, s->new_b, 0, batch_size_blocks);
  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
      2 * batch_size_blocks, 3 * batch_size_blocks, s->new_c, 0,
      batch_size_blocks);

  uint32_t total_flush_blocks = 3 * batch_size_blocks;

  if (output_dest != nullptr) {
    copy_radix_ciphertext_slice_async<Torus>(
        streams.stream(0), streams.gpu_index(0), s->packed_flush_in,
        3 * batch_size_blocks, 4 * batch_size_blocks, output_dest, 0,
        batch_size_blocks);
    total_flush_blocks += batch_size_blocks;
  }

  integer_radix_apply_univariate_lookup_table<Torus>(
      streams, s->packed_flush_out, s->packed_flush_in, bsks, ksks,
      mem->luts->flush_lut, total_flush_blocks);

  CudaRadixCiphertextFFI flushed_a, flushed_b, flushed_c;
  as_radix_ciphertext_slice<Torus>(&flushed_a, s->packed_flush_out, 0,
                                   batch_size_blocks);
  as_radix_ciphertext_slice<Torus>(&flushed_b, s->packed_flush_out,
                                   batch_size_blocks, 2 * batch_size_blocks);
  as_radix_ciphertext_slice<Torus>(&flushed_c, s->packed_flush_out,
                                   2 * batch_size_blocks,
                                   3 * batch_size_blocks);

  shift_and_insert_batch_kreyvium(streams, mem, s->a_reg, &flushed_a, 93, N);
  shift_and_insert_batch_kreyvium(streams, mem, s->b_reg, &flushed_b, 84, N);
  shift_and_insert_batch_kreyvium(streams, mem, s->c_reg, &flushed_c, 111, N);

  rotate_128_batch(streams, mem, s->k_reg);
  rotate_128_batch(streams, mem, s->iv_reg);

  if (output_dest != nullptr) {
    CudaRadixCiphertextFFI flushed_out;
    as_radix_ciphertext_slice<Torus>(&flushed_out, s->packed_flush_out,
                                     3 * batch_size_blocks,
                                     4 * batch_size_blocks);
    copy_radix_ciphertext_slice_async<Torus>(
        streams.stream(0), streams.gpu_index(0), output_dest, 0,
        batch_size_blocks, &flushed_out, 0, batch_size_blocks);
  }
}
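Reviewer note: the slice offsets follow the Kreyvium/Trivium tap positions, with each register stored oldest-bit-first (offset 0 of the 93-bit register is s93, offset 27 is s66, and so on). A cleartext, single-instance Rust sketch of one Kreyvium step using the same tap names, given here as a reference model of the public specification rather than code from this diff:

// One Kreyvium step on a cleartext state, following the public specification.
// a, b, c are the 93/84/111-bit registers; k_star and iv_star are the rotating
// 128-bit key/IV registers. Indexing is 0-based: a[i] corresponds to s_{i+1}.
fn kreyvium_step(
    a: &mut Vec<u8>, b: &mut Vec<u8>, c: &mut Vec<u8>,
    k_star: &mut Vec<u8>, iv_star: &mut Vec<u8>,
) -> u8 {
    let t1 = a[65] ^ a[92];              // s66 ^ s93   (slice offsets 27 and 0 above)
    let t2 = b[68] ^ b[83];              // s162 ^ s177 (offsets 15 and 0)
    let t3 = c[65] ^ c[110] ^ k_star[0]; // s243 ^ s288 ^ K* (offsets 45 and 0)
    let z = t1 ^ t2 ^ t3;                // keystream bit

    let new_a = t3 ^ (c[108] & c[109]) ^ a[68];            // feeds register a
    let new_b = t1 ^ (a[90] & a[91]) ^ b[77] ^ iv_star[0]; // feeds register b
    let new_c = t2 ^ (b[81] & b[82]) ^ c[86];              // feeds register c

    // Shift each register by one and insert the new bit at position 0.
    a.truncate(a.len() - 1); a.insert(0, new_a);
    b.truncate(b.len() - 1); b.insert(0, new_b);
    c.truncate(c.len() - 1); c.insert(0, new_c);

    // Rotate the key and IV shadow registers by one position.
    k_star.rotate_left(1);
    iv_star.rotate_left(1);

    z
}

fn main() {
    let (mut a, mut b) = (vec![0u8; 93], vec![0u8; 84]);
    let mut c = vec![0u8; 111];
    let (mut k, mut iv) = (vec![0u8; 128], vec![0u8; 128]);
    k[0] = 1; // a single key bit set makes t3, and hence z, equal to 1
    let z = kreyvium_step(&mut a, &mut b, &mut c, &mut k, &mut iv);
    assert_eq!(z, 1);
}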
template <typename Torus>
__host__ void kreyvium_init(CudaStreams streams,
                            int_kreyvium_buffer<Torus> *mem,
                            CudaRadixCiphertextFFI const *key_bitsliced,
                            CudaRadixCiphertextFFI const *iv_bitsliced,
                            void *const *bsks, uint64_t *const *ksks) {
  uint32_t N = mem->num_inputs;
  auto s = mem->state;

  CudaRadixCiphertextFFI src_key_slice;
  slice_reg_batch_kreyvium<Torus>(&src_key_slice, key_bitsliced, 0, 128, N);

  CudaRadixCiphertextFFI dest_k_reg_slice;
  slice_reg_batch_kreyvium<Torus>(&dest_k_reg_slice, s->k_reg, 0, 128, N);
  copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
                                     &dest_k_reg_slice, &src_key_slice);

  CudaRadixCiphertextFFI k_source_for_a;
  slice_reg_batch_kreyvium<Torus>(&k_source_for_a, s->k_reg, 35, 93, N);

  CudaRadixCiphertextFFI dest_a_slice;
  slice_reg_batch_kreyvium<Torus>(&dest_a_slice, s->a_reg, 0, 93, N);
  copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
                                     &dest_a_slice, &k_source_for_a);

  reverse_bitsliced_radix_inplace_kreyvium<Torus>(streams, mem, s->k_reg, 128);

  CudaRadixCiphertextFFI src_iv_slice;
  slice_reg_batch_kreyvium<Torus>(&src_iv_slice, iv_bitsliced, 0, 128, N);

  CudaRadixCiphertextFFI dest_iv_reg_slice;
  slice_reg_batch_kreyvium<Torus>(&dest_iv_reg_slice, s->iv_reg, 0, 128, N);
  copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
                                     &dest_iv_reg_slice, &src_iv_slice);

  CudaRadixCiphertextFFI iv_source_for_b;
  slice_reg_batch_kreyvium<Torus>(&iv_source_for_b, s->iv_reg, 44, 84, N);

  CudaRadixCiphertextFFI dest_b_slice;
  slice_reg_batch_kreyvium<Torus>(&dest_b_slice, s->b_reg, 0, 84, N);
  copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
                                     &dest_b_slice, &iv_source_for_b);

  CudaRadixCiphertextFFI iv_source_for_c;
  slice_reg_batch_kreyvium<Torus>(&iv_source_for_c, s->iv_reg, 0, 44, N);

  CudaRadixCiphertextFFI dest_c_iv_part;
  slice_reg_batch_kreyvium<Torus>(&dest_c_iv_part, s->c_reg, 67, 44, N);
  copy_radix_ciphertext_async<Torus>(streams.stream(0), streams.gpu_index(0),
                                     &dest_c_iv_part, &iv_source_for_c);

  reverse_bitsliced_radix_inplace_kreyvium<Torus>(streams, mem, s->iv_reg, 128);

  CudaRadixCiphertextFFI dest_c_ones;
  slice_reg_batch_kreyvium<Torus>(&dest_c_ones, s->c_reg, 1, 66, N);
  host_add_scalar_one_inplace<Torus>(streams, &dest_c_ones,
                                     mem->params.message_modulus,
                                     mem->params.carry_modulus);

  integer_radix_apply_univariate_lookup_table<Torus>(
      streams, &dest_c_ones, &dest_c_ones, bsks, ksks, mem->luts->flush_lut,
      dest_c_ones.num_radix_blocks);

  for (int i = 0; i < 18; i++) {
    kreyvium_compute_64_steps(streams, mem, nullptr, bsks, ksks);
  }
}
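Reviewer note: this mirrors the Kreyvium loading phase as I read it against the specification: register a receives 93 key bits, register b receives 84 IV bits, register c receives the remaining 44 IV bits together with 66 ones and a single zero bit, and the K*/IV* shadow registers hold the key and IV in reversed bit order. The 18 batched iterations then cover the specified 4 x 288 = 1152 warm-up rounds during which no keystream is emitted. A small sanity sketch of that round count (illustrative):

fn main() {
    // Kreyvium (like Trivium) runs 4 full state cycles of 288 rounds before
    // producing keystream; the GPU code batches them as 18 chunks of 64 steps.
    let warm_up_rounds = 4 * 288;
    assert_eq!(warm_up_rounds, 1152);
    assert_eq!(warm_up_rounds / 64, 18);
    assert_eq!(93 + 84 + 111, 288); // the three register lengths sum to 288
}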
template <typename Torus>
__host__ void host_kreyvium_generate_keystream(
    CudaStreams streams, CudaRadixCiphertextFFI *keystream_output,
    CudaRadixCiphertextFFI const *key_bitsliced,
    CudaRadixCiphertextFFI const *iv_bitsliced, uint32_t num_inputs,
    uint32_t num_steps, int_kreyvium_buffer<Torus> *mem, void *const *bsks,
    uint64_t *const *ksks) {

  kreyvium_init(streams, mem, key_bitsliced, iv_bitsliced, bsks, ksks);

  uint32_t num_batches = num_steps / 64;
  for (uint32_t i = 0; i < num_batches; i++) {
    CudaRadixCiphertextFFI batch_out_slice;
    slice_reg_batch_kreyvium<Torus>(&batch_out_slice, keystream_output, i * 64,
                                    64, num_inputs);
    kreyvium_compute_64_steps(streams, mem, &batch_out_slice, bsks, ksks);
  }
}

template <typename Torus>
uint64_t scratch_cuda_kreyvium_encrypt(CudaStreams streams,
                                       int_kreyvium_buffer<Torus> **mem_ptr,
                                       int_radix_params params,
                                       bool allocate_gpu_memory,
                                       uint32_t num_inputs) {

  uint64_t size_tracker = 0;
  *mem_ptr = new int_kreyvium_buffer<Torus>(
      streams, params, allocate_gpu_memory, num_inputs, size_tracker);
  return size_tracker;
}

#endif
@@ -2531,6 +2531,42 @@ unsafe extern "C" {
unsafe extern "C" {
    pub fn cleanup_cuda_trivium_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}
unsafe extern "C" {
    pub fn scratch_cuda_kreyvium_64(
        streams: CudaStreamsFFI,
        mem_ptr: *mut *mut i8,
        glwe_dimension: u32,
        polynomial_size: u32,
        lwe_dimension: u32,
        ks_level: u32,
        ks_base_log: u32,
        pbs_level: u32,
        pbs_base_log: u32,
        grouping_factor: u32,
        message_modulus: u32,
        carry_modulus: u32,
        pbs_type: PBS_TYPE,
        allocate_gpu_memory: bool,
        noise_reduction_type: PBS_MS_REDUCTION_T,
        num_inputs: u32,
    ) -> u64;
}
unsafe extern "C" {
    pub fn cuda_kreyvium_generate_keystream_64(
        streams: CudaStreamsFFI,
        keystream_output: *mut CudaRadixCiphertextFFI,
        key: *const CudaRadixCiphertextFFI,
        iv: *const CudaRadixCiphertextFFI,
        num_inputs: u32,
        num_steps: u32,
        mem_ptr: *mut i8,
        bsks: *const *mut ffi::c_void,
        ksks: *const *mut ffi::c_void,
    );
}
unsafe extern "C" {
    pub fn cleanup_cuda_kreyvium_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}
pub const KS_TYPE_BIG_TO_SMALL: KS_TYPE = 0;
pub const KS_TYPE_SMALL_TO_BIG: KS_TYPE = 1;
pub type KS_TYPE = ffi::c_uint;
@@ -5,6 +5,7 @@
#include "cuda/include/integer/rerand.h"
#include "cuda/include/aes/aes.h"
#include "cuda/include/trivium/trivium.h"
#include "cuda/include/kreyvium/kreyvium.h"
#include "cuda/include/zk/zk.h"
#include "cuda/include/keyswitch/keyswitch.h"
#include "cuda/include/keyswitch/ks_enums.h"
@@ -10518,3 +10518,93 @@ pub(crate) unsafe fn cuda_backend_trivium_generate_keystream<T: UnsignedInteger,
    update_noise_degree(keystream_output, &cuda_ffi_keystream);
}

#[allow(clippy::too_many_arguments)]
pub(crate) unsafe fn cuda_backend_kreyvium_generate_keystream<T: UnsignedInteger, B: Numeric>(
    streams: &CudaStreams,
    keystream_output: &mut CudaRadixCiphertext,
    key: &CudaRadixCiphertext,
    iv: &CudaRadixCiphertext,
    bootstrapping_key: &CudaVec<B>,
    keyswitch_key: &CudaVec<T>,
    message_modulus: MessageModulus,
    carry_modulus: CarryModulus,
    glwe_dimension: GlweDimension,
    polynomial_size: PolynomialSize,
    lwe_dimension: LweDimension,
    ks_level: DecompositionLevelCount,
    ks_base_log: DecompositionBaseLog,
    pbs_level: DecompositionLevelCount,
    pbs_base_log: DecompositionBaseLog,
    grouping_factor: LweBskGroupingFactor,
    pbs_type: PBSType,
    ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
    num_steps: u32,
) {
    let mut keystream_degrees = keystream_output
        .info
        .blocks
        .iter()
        .map(|b| b.degree.0)
        .collect();
    let mut keystream_noise_levels = keystream_output
        .info
        .blocks
        .iter()
        .map(|b| b.noise_level.0)
        .collect();
    let mut cuda_ffi_keystream = prepare_cuda_radix_ffi(
        keystream_output,
        &mut keystream_degrees,
        &mut keystream_noise_levels,
    );

    let mut key_degrees = key.info.blocks.iter().map(|b| b.degree.0).collect();
    let mut key_noise_levels = key.info.blocks.iter().map(|b| b.noise_level.0).collect();
    let cuda_ffi_key = prepare_cuda_radix_ffi(key, &mut key_degrees, &mut key_noise_levels);

    let mut iv_degrees = iv.info.blocks.iter().map(|b| b.degree.0).collect();
    let mut iv_noise_levels = iv.info.blocks.iter().map(|b| b.noise_level.0).collect();
    let cuda_ffi_iv = prepare_cuda_radix_ffi(iv, &mut iv_degrees, &mut iv_noise_levels);

    let num_inputs = (key.info.blocks.len() / 128) as u32;

    let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);

    let mut mem_ptr: *mut i8 = std::ptr::null_mut();

    scratch_cuda_kreyvium_64(
        streams.ffi(),
        std::ptr::addr_of_mut!(mem_ptr),
        glwe_dimension.0 as u32,
        polynomial_size.0 as u32,
        lwe_dimension.0 as u32,
        ks_level.0 as u32,
        ks_base_log.0 as u32,
        pbs_level.0 as u32,
        pbs_base_log.0 as u32,
        grouping_factor.0 as u32,
        message_modulus.0 as u32,
        carry_modulus.0 as u32,
        pbs_type as u32,
        true,
        noise_reduction_type as u32,
        num_inputs,
    );

    cuda_kreyvium_generate_keystream_64(
        streams.ffi(),
        &raw mut cuda_ffi_keystream,
        &raw const cuda_ffi_key,
        &raw const cuda_ffi_iv,
        num_inputs,
        num_steps,
        mem_ptr,
        bootstrapping_key.ptr.as_ptr(),
        keyswitch_key.ptr.as_ptr(),
    );

    cleanup_cuda_kreyvium_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));

    update_noise_degree(keystream_output, &cuda_ffi_keystream);
}
tfhe/src/integer/gpu/server_key/radix/kreyvium.rs (new file, 187 lines)
@@ -0,0 +1,187 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::{
    CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::integer::gpu::{
    cuda_backend_kreyvium_generate_keystream, LweBskGroupingFactor, PBSType,
};
use crate::integer::{RadixCiphertext, RadixClientKey};
use crate::shortint::Ciphertext;

impl RadixClientKey {
    pub fn encrypt_bits_for_kreyvium(&self, bits: &[u64]) -> RadixCiphertext {
        let mut blocks: Vec<Ciphertext> = Vec::with_capacity(bits.len());
        for &bit in bits {
            let mut ct = self.encrypt(bit);
            let block = ct.blocks.pop().unwrap();
            blocks.push(block);
        }
        RadixCiphertext::from(blocks)
    }

    pub fn decrypt_bits_from_kreyvium(&self, encrypted_stream: &RadixCiphertext) -> Vec<u8> {
        let mut decrypted_bits = Vec::with_capacity(encrypted_stream.blocks.len());
        for block in &encrypted_stream.blocks {
            let tmp_radix = RadixCiphertext::from(vec![block.clone()]);
            let val: u64 = self.decrypt(&tmp_radix);
            decrypted_bits.push(val as u8);
        }
        decrypted_bits
    }
}
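Reviewer note: these helpers take and return one bit per radix block; the key, IV and keystream are treated as flat bit vectors, and the test below expands hex strings into bits least-significant-bit-first within each byte. A standalone Rust sketch of that byte/bit conversion, independent of any tfhe-rs types (illustrative):

// Expand bytes to bits (LSB first within each byte) and pack them back,
// matching the `parse_hex` closure and the repacking loop in the test below.
fn bytes_to_bits(bytes: &[u8]) -> Vec<u64> {
    bytes
        .iter()
        .flat_map(|&byte| (0..8).map(move |j| u64::from((byte >> j) & 1)))
        .collect()
}

fn bits_to_bytes(bits: &[u64]) -> Vec<u8> {
    bits.chunks(8)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .fold(0u8, |byte, (j, &b)| byte | (((b & 1) as u8) << j))
        })
        .collect()
}

fn main() {
    let bytes = vec![0x00, 0x53, 0xA6];
    let bits = bytes_to_bits(&bytes);
    assert_eq!(bits.len(), 24);
    assert_eq!(bits_to_bytes(&bits), bytes);
}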
impl CudaServerKey {
    pub fn kreyvium_generate_keystream(
        &self,
        key: &CudaUnsignedRadixCiphertext,
        iv: &CudaUnsignedRadixCiphertext,
        num_steps: usize,
        streams: &CudaStreams,
    ) -> CudaUnsignedRadixCiphertext {
        let num_key_bits = 128;
        let num_iv_bits = 128;

        assert_eq!(
            key.as_ref().d_blocks.lwe_ciphertext_count().0,
            num_key_bits,
            "Input key must contain {} encrypted bits, but contains {}",
            num_key_bits,
            key.as_ref().d_blocks.lwe_ciphertext_count().0
        );
        assert_eq!(
            iv.as_ref().d_blocks.lwe_ciphertext_count().0,
            num_iv_bits,
            "Input IV must contain {} encrypted bits, but contains {}",
            num_iv_bits,
            iv.as_ref().d_blocks.lwe_ciphertext_count().0
        );

        let num_output_bits = num_steps;
        let mut keystream: CudaUnsignedRadixCiphertext =
            self.create_trivial_zero_radix(num_output_bits, streams);

        let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
            panic!("Only the standard atomic pattern is supported on GPU")
        };

        unsafe {
            match &self.bootstrapping_key {
                CudaBootstrappingKey::Classic(d_bsk) => {
                    cuda_backend_kreyvium_generate_keystream(
                        streams,
                        keystream.as_mut(),
                        key.as_ref(),
                        iv.as_ref(),
                        &d_bsk.d_vec,
                        &computing_ks_key.d_vec,
                        self.message_modulus,
                        self.carry_modulus,
                        d_bsk.glwe_dimension,
                        d_bsk.polynomial_size,
                        d_bsk.input_lwe_dimension,
                        computing_ks_key.decomposition_level_count(),
                        computing_ks_key.decomposition_base_log(),
                        d_bsk.decomp_level_count,
                        d_bsk.decomp_base_log,
                        LweBskGroupingFactor(0),
                        PBSType::Classical,
                        d_bsk.ms_noise_reduction_configuration.as_ref(),
                        num_steps as u32,
                    );
                }
                CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
                    cuda_backend_kreyvium_generate_keystream(
                        streams,
                        keystream.as_mut(),
                        key.as_ref(),
                        iv.as_ref(),
                        &d_multibit_bsk.d_vec,
                        &computing_ks_key.d_vec,
                        self.message_modulus,
                        self.carry_modulus,
                        d_multibit_bsk.glwe_dimension,
                        d_multibit_bsk.polynomial_size,
                        d_multibit_bsk.input_lwe_dimension,
                        computing_ks_key.decomposition_level_count(),
                        computing_ks_key.decomposition_base_log(),
                        d_multibit_bsk.decomp_level_count,
                        d_multibit_bsk.decomp_base_log,
                        d_multibit_bsk.grouping_factor,
                        PBSType::MultiBit,
                        None,
                        num_steps as u32,
                    );
                }
            }
        }
        keystream
    }
}
#[cfg(test)]
mod tests {
    use crate::core_crypto::gpu::CudaStreams;
    use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
    use crate::integer::gpu::CudaServerKey;
    use crate::integer::keycache::KEY_CACHE;
    use crate::integer::{IntegerKeyKind, RadixClientKey};
    use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;

    #[test]
    fn test_gpu_kreyvium_correctness() {
        let streams = CudaStreams::new_multi_gpu();

        let param = PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;

        let (raw_cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
        let cpu_cks = RadixClientKey::from((raw_cks, 1));

        let sks = CudaServerKey::new(&cpu_cks, &streams);

        let key_hex = "0053A6F94C9FF24598EB000000000000";
        let iv_hex = "0D74DB42A91077DE45AC000000000000";
        let expected_out_hex = "D1F0303482061111";

        let parse_hex = |s: &str| -> Vec<u64> {
            let mut bits = Vec::new();
            for i in (0..s.len()).step_by(2) {
                let byte = u8::from_str_radix(&s[i..i + 2], 16).unwrap();
                for j in 0..8 {
                    bits.push(((byte >> j) & 1) as u64);
                }
            }
            bits
        };

        let key_bits = parse_hex(key_hex);
        let iv_bits = parse_hex(iv_hex);

        assert_eq!(key_bits.len(), 128);
        assert_eq!(iv_bits.len(), 128);

        let encrypted_key = cpu_cks.encrypt_bits_for_kreyvium(&key_bits);
        let encrypted_iv = cpu_cks.encrypt_bits_for_kreyvium(&iv_bits);

        let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&encrypted_key, &streams);
        let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&encrypted_iv, &streams);

        let d_keystream = sks.kreyvium_generate_keystream(&d_key, &d_iv, 64, &streams);

        let keystream = d_keystream.to_radix_ciphertext(&streams);
        let decrypted_bits = cpu_cks.decrypt_bits_from_kreyvium(&keystream);

        let mut result_hex = String::new();
        for chunk in decrypted_bits.chunks(8) {
            let mut byte = 0u8;
            for (j, &b) in chunk.iter().enumerate() {
                if b == 1 {
                    byte |= 1 << j;
                }
            }
            result_hex.push_str(&format!("{:02X}", byte));
        }

        assert_eq!(result_hex, expected_out_hex);
    }
}
@@ -58,6 +58,7 @@ mod vector_find;
mod aes;
mod aes256;
mod kreyvium;
#[cfg(test)]
mod tests_long_run;
#[cfg(test)]