mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
feat(gpu): implement re-randomization
- exposed to integer and HL API - test on the HL API - benchmarks for GPU and CPU implementation
This commit is contained in:
@@ -84,6 +84,7 @@ fn main() {
|
||||
"cuda/include/ciphertext.h",
|
||||
"cuda/include/integer/compression/compression.h",
|
||||
"cuda/include/integer/integer.h",
|
||||
"cuda/include/integer/rerand.h",
|
||||
"cuda/include/aes/aes.h",
|
||||
"cuda/include/zk/zk.h",
|
||||
"cuda/include/keyswitch/keyswitch.h",
|
||||
|
||||
19
backends/tfhe-cuda-backend/cuda/include/integer/rerand.h
Normal file
19
backends/tfhe-cuda-backend/cuda/include/integer/rerand.h
Normal file
@@ -0,0 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include "integer.h"
|
||||
|
||||
extern "C" {
|
||||
uint64_t
|
||||
scratch_cuda_rerand_64(CudaStreamsFFI streams, int8_t **mem_ptr,
|
||||
uint32_t big_lwe_dimension, uint32_t small_lwe_dimension,
|
||||
uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t lwe_ciphertext_count, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, bool allocate_gpu_memory);
|
||||
|
||||
void cuda_rerand_64(
|
||||
CudaStreamsFFI streams, void *lwe_array,
|
||||
const void *lwe_flattened_encryptions_of_zero_compact_array_in,
|
||||
int8_t *mem_ptr, void *const *ksk);
|
||||
|
||||
void cleanup_cuda_rerand(CudaStreamsFFI streams, int8_t **mem_ptr_void);
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
#pragma once
|
||||
|
||||
#include "integer_utilities.h"
|
||||
#include "keyswitch/ks_enums.h"
|
||||
#include "zk/expand.cuh"
|
||||
#include "zk/zk_utilities.h"
|
||||
|
||||
template <typename Torus> struct int_rerand_mem {
|
||||
int_radix_params params;
|
||||
Torus *lwe_trivial_indexes;
|
||||
|
||||
Torus *tmp_zero_lwes;
|
||||
Torus *tmp_ksed_zero_lwes;
|
||||
uint32_t num_lwes;
|
||||
|
||||
bool gpu_memory_allocated;
|
||||
|
||||
expand_job<Torus> *d_expand_jobs;
|
||||
expand_job<Torus> *h_expand_jobs;
|
||||
|
||||
int_rerand_mem(CudaStreams streams, int_radix_params params,
|
||||
const uint32_t num_lwes, const bool allocate_gpu_memory,
|
||||
uint64_t &size_tracker)
|
||||
: params(params), num_lwes(num_lwes),
|
||||
gpu_memory_allocated(allocate_gpu_memory) {
|
||||
|
||||
tmp_zero_lwes = (Torus *)cuda_malloc_with_size_tracking_async(
|
||||
num_lwes * (params.big_lwe_dimension + 1) * sizeof(Torus),
|
||||
streams.stream(0), streams.gpu_index(0), size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
tmp_ksed_zero_lwes = (Torus *)cuda_malloc_with_size_tracking_async(
|
||||
num_lwes * (params.small_lwe_dimension + 1) * sizeof(Torus),
|
||||
streams.stream(0), streams.gpu_index(0), size_tracker,
|
||||
allocate_gpu_memory);
|
||||
|
||||
d_expand_jobs =
|
||||
static_cast<expand_job<Torus> *>(cuda_malloc_with_size_tracking_async(
|
||||
num_lwes * sizeof(expand_job<Torus>), streams.stream(0),
|
||||
streams.gpu_index(0), size_tracker, allocate_gpu_memory));
|
||||
|
||||
h_expand_jobs = static_cast<expand_job<Torus> *>(
|
||||
malloc(num_lwes * sizeof(expand_job<Torus>)));
|
||||
|
||||
auto h_lwe_trivial_indexes =
|
||||
static_cast<Torus *>(malloc(num_lwes * sizeof(Torus)));
|
||||
for (auto i = 0; i < num_lwes; ++i) {
|
||||
h_lwe_trivial_indexes[i] = i;
|
||||
}
|
||||
lwe_trivial_indexes = (Torus *)cuda_malloc_with_size_tracking_async(
|
||||
num_lwes * sizeof(Torus), streams.stream(0), streams.gpu_index(0),
|
||||
size_tracker, allocate_gpu_memory);
|
||||
cuda_memcpy_async_to_gpu(lwe_trivial_indexes, h_lwe_trivial_indexes,
|
||||
num_lwes * sizeof(Torus), streams.stream(0),
|
||||
streams.gpu_index(0));
|
||||
|
||||
streams.synchronize();
|
||||
|
||||
free(h_lwe_trivial_indexes);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
cuda_drop_with_size_tracking_async(tmp_zero_lwes, streams.stream(0),
|
||||
streams.gpu_index(0),
|
||||
gpu_memory_allocated);
|
||||
cuda_drop_with_size_tracking_async(tmp_ksed_zero_lwes, streams.stream(0),
|
||||
streams.gpu_index(0),
|
||||
gpu_memory_allocated);
|
||||
cuda_drop_with_size_tracking_async(d_expand_jobs, streams.stream(0),
|
||||
streams.gpu_index(0),
|
||||
gpu_memory_allocated);
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
free(h_expand_jobs);
|
||||
}
|
||||
};
|
||||
105
backends/tfhe-cuda-backend/cuda/src/integer/rerand.cu
Normal file
105
backends/tfhe-cuda-backend/cuda/src/integer/rerand.cu
Normal file
@@ -0,0 +1,105 @@
|
||||
#include "rerand.cuh"
|
||||
|
||||
extern "C" {
|
||||
uint64_t
|
||||
scratch_cuda_rerand_64(CudaStreamsFFI streams, int8_t **mem_ptr,
|
||||
uint32_t big_lwe_dimension, uint32_t small_lwe_dimension,
|
||||
uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t lwe_ciphertext_count, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, bool allocate_gpu_memory) {
|
||||
PUSH_RANGE("scratch rerand")
|
||||
int_radix_params params(PBS_TYPE::CLASSICAL, 0, 0, big_lwe_dimension,
|
||||
small_lwe_dimension, ks_level, ks_base_log, 0, 0, 0,
|
||||
message_modulus, carry_modulus,
|
||||
PBS_MS_REDUCTION_T::NO_REDUCTION);
|
||||
|
||||
uint64_t ret = scratch_cuda_rerand<uint64_t>(
|
||||
CudaStreams(streams), (int_rerand_mem<uint64_t> **)mem_ptr,
|
||||
lwe_ciphertext_count, params, allocate_gpu_memory);
|
||||
POP_RANGE()
|
||||
return ret;
|
||||
}
|
||||
|
||||
/* Executes the re-randomization procedure, adding encryptions of zero to each
|
||||
* element of an array of LWE ciphertexts. This method expects the encryptions
|
||||
* of zero to be provided as input in the format of a flattened compact
|
||||
* ciphertext list, generated using a compact public key.
|
||||
*/
|
||||
void cuda_rerand_64(
|
||||
CudaStreamsFFI streams, void *lwe_array,
|
||||
const void *lwe_flattened_encryptions_of_zero_compact_array_in,
|
||||
int8_t *mem_ptr, void *const *ksk) {
|
||||
|
||||
auto rerand_buffer = reinterpret_cast<int_rerand_mem<uint64_t> *>(mem_ptr);
|
||||
|
||||
switch (rerand_buffer->params.big_lwe_dimension) {
|
||||
case 256:
|
||||
rerand_inplace<uint64_t, AmortizedDegree<256>>(
|
||||
streams, static_cast<uint64_t *>(lwe_array),
|
||||
static_cast<const uint64_t *>(
|
||||
lwe_flattened_encryptions_of_zero_compact_array_in),
|
||||
(uint64_t **)(ksk), rerand_buffer);
|
||||
break;
|
||||
case 512:
|
||||
rerand_inplace<uint64_t, AmortizedDegree<512>>(
|
||||
streams, static_cast<uint64_t *>(lwe_array),
|
||||
static_cast<const uint64_t *>(
|
||||
lwe_flattened_encryptions_of_zero_compact_array_in),
|
||||
(uint64_t **)(ksk), rerand_buffer);
|
||||
break;
|
||||
case 1024:
|
||||
rerand_inplace<uint64_t, AmortizedDegree<1024>>(
|
||||
streams, static_cast<uint64_t *>(lwe_array),
|
||||
static_cast<const uint64_t *>(
|
||||
lwe_flattened_encryptions_of_zero_compact_array_in),
|
||||
(uint64_t **)(ksk), rerand_buffer);
|
||||
break;
|
||||
case 2048:
|
||||
rerand_inplace<uint64_t, AmortizedDegree<2048>>(
|
||||
streams, static_cast<uint64_t *>(lwe_array),
|
||||
static_cast<const uint64_t *>(
|
||||
lwe_flattened_encryptions_of_zero_compact_array_in),
|
||||
(uint64_t **)(ksk), rerand_buffer);
|
||||
break;
|
||||
case 4096:
|
||||
rerand_inplace<uint64_t, AmortizedDegree<4096>>(
|
||||
streams, static_cast<uint64_t *>(lwe_array),
|
||||
static_cast<const uint64_t *>(
|
||||
lwe_flattened_encryptions_of_zero_compact_array_in),
|
||||
(uint64_t **)(ksk), rerand_buffer);
|
||||
break;
|
||||
case 8192:
|
||||
rerand_inplace<uint64_t, AmortizedDegree<8192>>(
|
||||
streams, static_cast<uint64_t *>(lwe_array),
|
||||
static_cast<const uint64_t *>(
|
||||
lwe_flattened_encryptions_of_zero_compact_array_in),
|
||||
(uint64_t **)(ksk), rerand_buffer);
|
||||
break;
|
||||
case 16384:
|
||||
rerand_inplace<uint64_t, AmortizedDegree<16384>>(
|
||||
streams, static_cast<uint64_t *>(lwe_array),
|
||||
static_cast<const uint64_t *>(
|
||||
lwe_flattened_encryptions_of_zero_compact_array_in),
|
||||
(uint64_t **)(ksk), rerand_buffer);
|
||||
break;
|
||||
default:
|
||||
PANIC("CUDA error: lwe_dimension not supported."
|
||||
"Supported n's are powers of two"
|
||||
" in the interval [256..16384].");
|
||||
break;
|
||||
}
|
||||
|
||||
cuda_synchronize_stream(static_cast<cudaStream_t>(streams.streams[0]),
|
||||
streams.gpu_indexes[0]);
|
||||
}
|
||||
|
||||
// Destroys the re-randomization scratch object behind `*mem_ptr_void`:
// releases its buffers, deletes the object and nulls the caller's pointer.
void cleanup_cuda_rerand(CudaStreamsFFI streams, int8_t **mem_ptr_void) {
  PUSH_RANGE("cleanup rerand")
  auto rerand_mem = (int_rerand_mem<uint64_t> *)(*mem_ptr_void);

  rerand_mem->release(CudaStreams(streams));
  delete rerand_mem;

  // Leave no dangling handle on the caller's side.
  *mem_ptr_void = nullptr;
  POP_RANGE()
}
|
||||
}
|
||||
87
backends/tfhe-cuda-backend/cuda/src/integer/rerand.cuh
Normal file
87
backends/tfhe-cuda-backend/cuda/src/integer/rerand.cuh
Normal file
@@ -0,0 +1,87 @@
|
||||
#pragma once
|
||||
|
||||
#include "device.h"
|
||||
#include "integer/integer.h"
|
||||
#include "integer/radix_ciphertext.h"
|
||||
#include "integer/rerand.h"
|
||||
#include "integer/rerand_utilities.h"
|
||||
#include "utils/helper_profile.cuh"
|
||||
#include "utils/kernel_dimensions.cuh"
|
||||
#include "zk/zk_utilities.h"
|
||||
|
||||
template <typename Torus, class params>
|
||||
void rerand_inplace(
|
||||
CudaStreams const streams, Torus *lwe_array,
|
||||
const Torus *lwe_flattened_encryptions_of_zero_compact_array_in,
|
||||
Torus *const *ksk, int_rerand_mem<Torus> *mem_ptr) {
|
||||
auto zero_lwes = mem_ptr->tmp_zero_lwes;
|
||||
auto num_lwes = mem_ptr->num_lwes;
|
||||
auto ksed_zero_lwes = mem_ptr->tmp_ksed_zero_lwes;
|
||||
auto lwe_trivial_indexes = mem_ptr->lwe_trivial_indexes;
|
||||
auto ksk_params = mem_ptr->params;
|
||||
auto output_dimension = ksk_params.small_lwe_dimension;
|
||||
auto input_dimension = ksk_params.big_lwe_dimension;
|
||||
auto ks_level = ksk_params.ks_level;
|
||||
auto ks_base_log = ksk_params.ks_base_log;
|
||||
auto message_modulus = ksk_params.message_modulus;
|
||||
auto carry_modulus = ksk_params.carry_modulus;
|
||||
|
||||
GPU_ASSERT(sizeof(Torus) == 8,
|
||||
"Cuda error: expand is only supported on 64 bits");
|
||||
|
||||
// Expand encryptions of zero
|
||||
// Wraps the input into a flattened_compact_lwe_lists type
|
||||
auto compact_lwe_lists = flattened_compact_lwe_lists<Torus>(
|
||||
const_cast<Torus *>(lwe_flattened_encryptions_of_zero_compact_array_in),
|
||||
&num_lwes, (uint32_t)1, input_dimension);
|
||||
auto h_expand_jobs = mem_ptr->h_expand_jobs;
|
||||
auto d_expand_jobs = mem_ptr->d_expand_jobs;
|
||||
|
||||
auto output_index = 0;
|
||||
for (auto list_index = 0; list_index < compact_lwe_lists.num_compact_lists;
|
||||
++list_index) {
|
||||
auto list = compact_lwe_lists.get_device_compact_list(list_index);
|
||||
for (auto lwe_index = 0; lwe_index < list.total_num_lwes; ++lwe_index) {
|
||||
h_expand_jobs[output_index] =
|
||||
expand_job<Torus>(list.get_mask(), list.get_body(lwe_index));
|
||||
output_index++;
|
||||
}
|
||||
}
|
||||
cuda_memcpy_with_size_tracking_async_to_gpu(
|
||||
d_expand_jobs, h_expand_jobs,
|
||||
compact_lwe_lists.total_num_lwes * sizeof(expand_job<Torus>),
|
||||
streams.stream(0), streams.gpu_index(0), true);
|
||||
|
||||
host_lwe_expand<Torus, params>(streams.stream(0), streams.gpu_index(0),
|
||||
zero_lwes, d_expand_jobs, num_lwes);
|
||||
|
||||
// Keyswitch
|
||||
execute_keyswitch_async<Torus>(
|
||||
streams.get_ith(0), ksed_zero_lwes, lwe_trivial_indexes, zero_lwes,
|
||||
lwe_trivial_indexes, ksk, input_dimension, output_dimension, ks_base_log,
|
||||
ks_level, num_lwes);
|
||||
|
||||
// Add ks output to ct
|
||||
// Check sizes
|
||||
auto lwes_ffi = new CudaRadixCiphertextFFI;
|
||||
into_radix_ciphertext(lwes_ffi, lwe_array, num_lwes, output_dimension);
|
||||
auto ksed_zero_lwes_ffi = new CudaRadixCiphertextFFI;
|
||||
into_radix_ciphertext(ksed_zero_lwes_ffi, ksed_zero_lwes, num_lwes,
|
||||
output_dimension);
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), lwes_ffi,
|
||||
lwes_ffi, ksed_zero_lwes_ffi, num_lwes, message_modulus,
|
||||
carry_modulus);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__host__ uint64_t scratch_cuda_rerand(CudaStreams streams,
|
||||
int_rerand_mem<Torus> **mem_ptr,
|
||||
uint32_t num_lwes,
|
||||
int_radix_params params,
|
||||
bool allocate_gpu_memory) {
|
||||
|
||||
uint64_t size_tracker = 0;
|
||||
*mem_ptr = new int_rerand_mem<Torus>(streams, params, num_lwes,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
return size_tracker;
|
||||
}
|
||||
@@ -1755,6 +1755,32 @@ unsafe extern "C" {
|
||||
mem_ptr_void: *mut *mut i8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_rerand_64(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr: *mut *mut i8,
|
||||
big_lwe_dimension: u32,
|
||||
small_lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
lwe_ciphertext_count: u32,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
allocate_gpu_memory: bool,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_rerand_64(
|
||||
streams: CudaStreamsFFI,
|
||||
lwe_array: *mut ffi::c_void,
|
||||
lwe_flattened_encryptions_of_zero_compact_array_in: *const ffi::c_void,
|
||||
mem_ptr: *mut i8,
|
||||
ksk: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_rerand(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_integer_aes_encrypt_64(
|
||||
streams: CudaStreamsFFI,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "cuda/include/ciphertext.h"
|
||||
#include "cuda/include/integer/compression/compression.h"
|
||||
#include "cuda/include/integer/integer.h"
|
||||
#include "cuda/include/integer/rerand.h"
|
||||
#include "cuda/include/aes/aes.h"
|
||||
#include "cuda/include/zk/zk.h"
|
||||
#include "cuda/include/keyswitch/keyswitch.h"
|
||||
|
||||
@@ -96,6 +96,12 @@ path = "benches/integer/glwe_packing_compression.rs"
|
||||
harness = false
|
||||
required-features = ["integer", "pbs-stats", "internal-keycache"]
|
||||
|
||||
[[bench]]
|
||||
name = "integer-rerand"
|
||||
path = "benches/integer/rerand.rs"
|
||||
harness = false
|
||||
required-features = ["integer", "pbs-stats", "internal-keycache"]
|
||||
|
||||
[[bench]]
|
||||
name = "glwe_packing_compression_128b-integer-bench"
|
||||
path = "benches/integer/glwe_packing_compression_128b.rs"
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
mod aes;
|
||||
mod oprf;
|
||||
|
||||
mod rerand;
|
||||
|
||||
use benchmark::params::ParamsAndNumBlocksIter;
|
||||
use benchmark::utilities::{
|
||||
get_bench_type, throughput_num_threads, write_to_json, BenchmarkType, EnvConfig, OperatorType,
|
||||
|
||||
417
tfhe-benchmark/benches/integer/rerand.rs
Normal file
417
tfhe-benchmark/benches/integer/rerand.rs
Normal file
@@ -0,0 +1,417 @@
|
||||
use benchmark::params_aliases::{
|
||||
BENCH_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
BENCH_PARAM_KEYSWITCH_PKE_TO_BIG_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
BENCH_PARAM_PKE_TO_BIG_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128_ZKV1,
|
||||
};
|
||||
use benchmark::utilities::{
|
||||
get_bench_type, throughput_num_threads, write_to_json, BenchmarkType, OperatorType,
|
||||
};
|
||||
use criterion::{black_box, criterion_group, BatchSize, Criterion, Throughput};
|
||||
#[cfg(feature = "gpu")]
|
||||
use cuda::gpu_re_randomize_group;
|
||||
use rayon::iter::{IndexedParallelIterator, ParallelIterator};
|
||||
use rayon::prelude::{IntoParallelIterator, IntoParallelRefMutIterator};
|
||||
use tfhe::integer::ciphertext::{CompressedCiphertextListBuilder, ReRandomizationContext};
|
||||
use tfhe::integer::key_switching_key::{KeySwitchingKey, KeySwitchingKeyMaterial};
|
||||
use tfhe::integer::{gen_keys_radix, CompactPrivateKey, CompactPublicKey, RadixCiphertext};
|
||||
use tfhe::keycache::NamedParam;
|
||||
|
||||
fn execute_cpu_re_randomize(c: &mut Criterion, bit_size: usize) {
|
||||
let bench_name = "integer::re_randomize";
|
||||
let mut bench_group = c.benchmark_group(bench_name);
|
||||
bench_group
|
||||
.sample_size(15)
|
||||
.measurement_time(std::time::Duration::from_secs(30));
|
||||
|
||||
let param = BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
let comp_param = BENCH_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
let cpk_param = BENCH_PARAM_PKE_TO_BIG_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128_ZKV1;
|
||||
let ks_param = BENCH_PARAM_KEYSWITCH_PKE_TO_BIG_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
|
||||
let num_blocks = (bit_size as f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
|
||||
let (radix_cks, sks) = gen_keys_radix(param, num_blocks);
|
||||
let cks = radix_cks.as_ref();
|
||||
|
||||
let private_compression_key = cks.new_compression_private_key(comp_param);
|
||||
let (compressed_compression_key, compressed_decompression_key) =
|
||||
radix_cks.new_compressed_compression_decompression_keys(&private_compression_key);
|
||||
|
||||
let compression_key = compressed_compression_key.decompress();
|
||||
let decompression_key = compressed_decompression_key.decompress();
|
||||
|
||||
let cpk_private_key = CompactPrivateKey::new(cpk_param);
|
||||
let cpk = CompactPublicKey::new(&cpk_private_key);
|
||||
let ksk = KeySwitchingKey::new((&cpk_private_key, None), ((&cks), (&sks)), ks_param);
|
||||
let ksk = ksk.into_raw_parts();
|
||||
let (ksk_material, _, _) = ksk.into_raw_parts();
|
||||
let ksk_material = KeySwitchingKeyMaterial::from_raw_parts(ksk_material);
|
||||
|
||||
let rerand_domain_separator = *b"TFHE_Rrd";
|
||||
let compact_public_encryption_domain_separator = *b"TFHE_Enc";
|
||||
let metadata = b"bench".as_slice();
|
||||
|
||||
let bench_id;
|
||||
|
||||
match get_bench_type() {
|
||||
BenchmarkType::Latency => {
|
||||
// Encrypt and compress a single ciphertext
|
||||
let message = 42u64;
|
||||
let ct = cks.encrypt_radix(message, num_blocks);
|
||||
|
||||
let mut builder = CompressedCiphertextListBuilder::new();
|
||||
builder.push(ct);
|
||||
let compressed = builder.build(&compression_key);
|
||||
let decompressed: RadixCiphertext =
|
||||
compressed.get(0, &decompression_key).unwrap().unwrap();
|
||||
|
||||
let mut d_re_randomized = decompressed.clone();
|
||||
|
||||
bench_id = format!("{bench_name}::latency_u{bit_size}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
let mut re_randomizer_context = ReRandomizationContext::new(
|
||||
rerand_domain_separator,
|
||||
[metadata],
|
||||
compact_public_encryption_domain_separator,
|
||||
);
|
||||
|
||||
re_randomizer_context.add_ciphertext(&decompressed);
|
||||
re_randomizer_context.finalize()
|
||||
},
|
||||
|mut seed_gen| {
|
||||
d_re_randomized
|
||||
.re_randomize(
|
||||
&cpk,
|
||||
&ksk_material.as_view(),
|
||||
seed_gen.next_seed().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
_ = black_box(&d_re_randomized);
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
}
|
||||
BenchmarkType::Throughput => {
|
||||
let elements = throughput_num_threads(num_blocks, 1);
|
||||
bench_group.throughput(Throughput::Elements(elements));
|
||||
|
||||
// Pre-generate and compress ciphertexts for throughput test
|
||||
let decompressed_cts: Vec<RadixCiphertext> = (0..elements as usize)
|
||||
.into_par_iter()
|
||||
.map(|_| {
|
||||
let message = 42u64;
|
||||
let ct = cks.encrypt_radix(message, num_blocks);
|
||||
|
||||
let mut builder = CompressedCiphertextListBuilder::new();
|
||||
builder.push(ct);
|
||||
let compressed = builder.build(&compression_key);
|
||||
|
||||
compressed.get(0, &decompression_key).unwrap().unwrap()
|
||||
})
|
||||
.collect();
|
||||
|
||||
bench_id = format!("{bench_name}::throughput_u{bit_size}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
// Create a fresh context for each benchmark iteration
|
||||
let mut ctx = ReRandomizationContext::new(
|
||||
rerand_domain_separator,
|
||||
[metadata],
|
||||
compact_public_encryption_domain_separator,
|
||||
);
|
||||
|
||||
// Add all ciphertexts to the context
|
||||
for ct in &decompressed_cts {
|
||||
ctx.add_ciphertext(ct);
|
||||
}
|
||||
|
||||
// Return a new seed generator for this iteration
|
||||
(ctx.finalize(), decompressed_cts.clone())
|
||||
},
|
||||
|(mut seed_gen, mut cts_to_rerand)| {
|
||||
let seeds: Vec<_> = (0..cts_to_rerand.len())
|
||||
.map(|_| seed_gen.next_seed().unwrap())
|
||||
.collect();
|
||||
|
||||
cts_to_rerand
|
||||
.par_iter_mut()
|
||||
.zip(seeds.into_par_iter())
|
||||
.for_each(|(d_re_randomized, seed)| {
|
||||
d_re_randomized
|
||||
.re_randomize(&cpk, &ksk_material.as_view(), seed)
|
||||
.unwrap();
|
||||
|
||||
_ = black_box(&d_re_randomized);
|
||||
})
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
write_to_json::<u64, _>(
|
||||
&bench_id,
|
||||
(comp_param, param.into()),
|
||||
comp_param.name(),
|
||||
"re_randomize",
|
||||
&OperatorType::Atomic,
|
||||
bit_size as u32,
|
||||
vec![param.message_modulus.0.ilog2(); num_blocks],
|
||||
);
|
||||
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn cpu_re_randomize(c: &mut Criterion) {
|
||||
let bit_sizes = [2, 4, 8, 16, 32, 64, 128, 256];
|
||||
|
||||
for bit_size in bit_sizes.iter() {
|
||||
execute_cpu_re_randomize(c, *bit_size);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(cpu_re_randomize_group, cpu_re_randomize);
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
mod cuda {
|
||||
use benchmark::params_aliases::{
|
||||
BENCH_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
|
||||
BENCH_PARAM_KEYSWITCH_PKE_TO_BIG_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
BENCH_PARAM_PKE_TO_BIG_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128_ZKV1,
|
||||
};
|
||||
use benchmark::utilities::cuda_integer_utils::cuda_local_streams;
|
||||
use benchmark::utilities::{
|
||||
get_bench_type, throughput_num_threads, write_to_json, BenchmarkType, OperatorType,
|
||||
};
|
||||
use criterion::{black_box, criterion_group, BatchSize, Criterion, Throughput};
|
||||
use rayon::prelude::*;
|
||||
use tfhe::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
|
||||
use tfhe::integer::ciphertext::ReRandomizationContext;
|
||||
use tfhe::integer::gpu::ciphertext::compressed_ciphertext_list::CudaCompressedCiphertextListBuilder;
|
||||
use tfhe::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use tfhe::integer::key_switching_key::KeySwitchingKey;
|
||||
use tfhe::integer::{gen_keys_radix, CompactPrivateKey, CompactPublicKey};
|
||||
use tfhe::keycache::NamedParam;
|
||||
use tfhe::shortint::key_switching_key::CudaKeySwitchingKeyMaterial;
|
||||
|
||||
fn execute_gpu_re_randomize(c: &mut Criterion, bit_size: usize) {
|
||||
let bench_name = "integer::cuda::re_randomize";
|
||||
let mut bench_group = c.benchmark_group(bench_name);
|
||||
bench_group
|
||||
.sample_size(15)
|
||||
.measurement_time(std::time::Duration::from_secs(30));
|
||||
|
||||
let param = BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
|
||||
let comp_param = BENCH_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
let cpk_param = BENCH_PARAM_PKE_TO_BIG_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128_ZKV1;
|
||||
let ks_param = BENCH_PARAM_KEYSWITCH_PKE_TO_BIG_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
|
||||
let streams = CudaStreams::new_multi_gpu();
|
||||
|
||||
let num_blocks =
|
||||
(bit_size as f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
|
||||
let (radix_cks, sks) = gen_keys_radix(param, num_blocks);
|
||||
let cks = radix_cks.as_ref();
|
||||
|
||||
let private_compression_key = cks.new_compression_private_key(comp_param);
|
||||
let (cuda_compression_key, cuda_decompression_key) =
|
||||
radix_cks.new_cuda_compression_decompression_keys(&private_compression_key, &streams);
|
||||
|
||||
let cpk_private_key = CompactPrivateKey::new(cpk_param);
|
||||
let cpk = CompactPublicKey::new(&cpk_private_key);
|
||||
let ksk = KeySwitchingKey::new((&cpk_private_key, None), (&cks, &sks), ks_param);
|
||||
let d_ksk_material = CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
|
||||
|
||||
let rerand_domain_separator = *b"TFHE_Rrd";
|
||||
let compact_public_encryption_domain_separator = *b"TFHE_Enc";
|
||||
let metadata = b"bench".as_slice();
|
||||
|
||||
let bench_id;
|
||||
|
||||
match get_bench_type() {
|
||||
BenchmarkType::Latency => {
|
||||
// Encrypt and compress a single ciphertext
|
||||
let message = 42u64;
|
||||
let ct = cks.encrypt_radix(message, num_blocks);
|
||||
let d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &streams);
|
||||
|
||||
let mut builder = CudaCompressedCiphertextListBuilder::new();
|
||||
builder.push(d_ct, &streams);
|
||||
let compressed = builder.build(&cuda_compression_key, &streams);
|
||||
let d_decompressed: CudaUnsignedRadixCiphertext = compressed
|
||||
.get(0, &cuda_decompression_key, &streams)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let decompressed = d_decompressed.to_radix_ciphertext(&streams);
|
||||
|
||||
let mut d_re_randomized = d_decompressed.duplicate(&streams);
|
||||
|
||||
bench_id = format!("{bench_name}::latency_u{bit_size}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
let mut re_randomizer_context = ReRandomizationContext::new(
|
||||
rerand_domain_separator,
|
||||
[metadata],
|
||||
compact_public_encryption_domain_separator,
|
||||
);
|
||||
|
||||
re_randomizer_context.add_ciphertext(&decompressed);
|
||||
re_randomizer_context.finalize()
|
||||
},
|
||||
|mut seed_gen| {
|
||||
d_re_randomized
|
||||
.re_randomize(
|
||||
&cpk,
|
||||
&d_ksk_material,
|
||||
seed_gen.next_seed().unwrap(),
|
||||
&streams,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
_ = black_box(&d_re_randomized);
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
}
|
||||
BenchmarkType::Throughput => {
|
||||
let elements = throughput_num_threads(num_blocks, 1);
|
||||
bench_group.throughput(Throughput::Elements(elements));
|
||||
|
||||
let local_streams = cuda_local_streams(num_blocks, elements as usize);
|
||||
let num_gpus = get_number_of_gpus() as usize;
|
||||
|
||||
let d_ksk_material_vec: Vec<CudaKeySwitchingKeyMaterial> = (0..num_gpus)
|
||||
.map(|i| {
|
||||
let local_stream = &local_streams[i % local_streams.len()];
|
||||
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, local_stream)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Pre-generate and compress ciphertexts for throughput test
|
||||
let d_compressed_cts: Vec<CudaUnsignedRadixCiphertext> = (0..elements as usize)
|
||||
.into_par_iter()
|
||||
.map(|i| {
|
||||
let message = 42u64;
|
||||
let ct = cks.encrypt_radix(message, num_blocks);
|
||||
let local_stream = &local_streams[i % local_streams.len()];
|
||||
let d_ct =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, local_stream);
|
||||
|
||||
let mut builder = CudaCompressedCiphertextListBuilder::new();
|
||||
builder.push(d_ct, local_stream);
|
||||
let compressed = builder.build(&cuda_compression_key, local_stream);
|
||||
|
||||
compressed
|
||||
.get(0, &cuda_decompression_key, local_stream)
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Prepare decompressed ciphertexts once
|
||||
let h_decompressed_cts: Vec<_> = d_compressed_cts
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, d_ct)| {
|
||||
let local_stream = &local_streams[i % local_streams.len()];
|
||||
d_ct.to_radix_ciphertext(local_stream)
|
||||
})
|
||||
.collect();
|
||||
|
||||
bench_id = format!("{bench_name}::throughput_u{bit_size}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
// Create a fresh context for each benchmark iteration
|
||||
let mut ctx = ReRandomizationContext::new(
|
||||
rerand_domain_separator,
|
||||
[metadata],
|
||||
compact_public_encryption_domain_separator,
|
||||
);
|
||||
|
||||
// Add all ciphertexts to the context
|
||||
for ct in &h_decompressed_cts {
|
||||
ctx.add_ciphertext(ct);
|
||||
}
|
||||
|
||||
let d_cts_to_rerand = d_compressed_cts
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, d_ct)| {
|
||||
let local_stream = &local_streams[i % local_streams.len()];
|
||||
d_ct.duplicate(local_stream)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Return a new seed generator for this iteration
|
||||
(ctx.finalize(), h_decompressed_cts.clone(), d_cts_to_rerand)
|
||||
},
|
||||
|(mut seed_gen, h_cts_to_rerand, mut d_cts_to_rerand)| {
|
||||
let seeds: Vec<_> = (0..h_cts_to_rerand.len())
|
||||
.map(|_| seed_gen.next_seed().unwrap())
|
||||
.collect();
|
||||
|
||||
d_cts_to_rerand
|
||||
.par_iter_mut()
|
||||
.zip(seeds.into_par_iter())
|
||||
.enumerate()
|
||||
.for_each(|(i, (d_re_randomized, seed))| {
|
||||
let local_stream = &local_streams[i % local_streams.len()];
|
||||
let d_ksk = &d_ksk_material_vec[i % num_gpus];
|
||||
|
||||
d_re_randomized
|
||||
.re_randomize(&cpk, d_ksk, seed, local_stream)
|
||||
.unwrap();
|
||||
|
||||
_ = black_box(&d_re_randomized);
|
||||
})
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
write_to_json::<u64, _>(
|
||||
&bench_id,
|
||||
(comp_param, param.into()),
|
||||
comp_param.name(),
|
||||
"re_randomize",
|
||||
&OperatorType::Atomic,
|
||||
bit_size as u32,
|
||||
vec![param.message_modulus.0.ilog2(); num_blocks],
|
||||
);
|
||||
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn gpu_re_randomize(c: &mut Criterion) {
|
||||
let bit_sizes = [2, 4, 16, 32, 64, 128, 256];
|
||||
|
||||
for bit_size in bit_sizes.iter() {
|
||||
execute_gpu_re_randomize(c, *bit_size);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(gpu_re_randomize_group, gpu_re_randomize);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
#[cfg(feature = "gpu")]
|
||||
gpu_re_randomize_group();
|
||||
#[cfg(not(feature = "gpu"))]
|
||||
cpu_re_randomize_group();
|
||||
Criterion::default().configure_from_args().final_summary();
|
||||
}
|
||||
@@ -536,8 +536,8 @@ pub enum KeySwitchingKeyVersions {
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
pub enum ReRandomizationKeySwitchingKeyVersions {
|
||||
V0(ReRandomizationKeySwitchingKey),
|
||||
pub enum ReRandomizationKeySwitchingKeyVersions<KSK> {
|
||||
V0(ReRandomizationKeySwitchingKey<KSK>),
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
|
||||
@@ -2382,7 +2382,23 @@ impl ReRandomize for FheBool {
|
||||
Ok(())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(_cuda_key) => panic!("GPU does not support CPKReRandomize."),
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let Some(re_randomization_key) = cuda_key.re_randomization_cpk_casting_key() else {
|
||||
return Err(UninitializedReRandKey.into());
|
||||
};
|
||||
|
||||
let streams = &cuda_key.streams;
|
||||
self.ciphertext.as_gpu_mut(streams).re_randomize(
|
||||
&compact_public_key.key.key,
|
||||
re_randomization_key,
|
||||
seed,
|
||||
streams,
|
||||
)?;
|
||||
|
||||
self.re_randomization_metadata_mut().clear();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("HPU does not support CPKReRandomize.")
|
||||
|
||||
@@ -1198,7 +1198,23 @@ where
|
||||
Ok(())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(_cuda_key) => panic!("GPU does not support CPKReRandomize."),
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let Some(re_randomization_key) = cuda_key.re_randomization_cpk_casting_key() else {
|
||||
return Err(UninitializedReRandKey.into());
|
||||
};
|
||||
|
||||
let streams = &cuda_key.streams;
|
||||
self.ciphertext.as_gpu_mut(streams).re_randomize(
|
||||
&compact_public_key.key.key,
|
||||
re_randomization_key,
|
||||
seed,
|
||||
streams,
|
||||
)?;
|
||||
|
||||
self.re_randomization_metadata_mut().clear();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("HPU does not support CPKReRandomize.")
|
||||
|
||||
@@ -1708,7 +1708,23 @@ where
|
||||
Ok(())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(_cuda_key) => panic!("GPU does not support CPKReRandomize."),
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let Some(re_randomization_key) = cuda_key.re_randomization_cpk_casting_key() else {
|
||||
return Err(UninitializedReRandKey.into());
|
||||
};
|
||||
|
||||
let streams = &cuda_key.streams;
|
||||
self.ciphertext.as_gpu_mut(streams).re_randomize(
|
||||
&compact_public_key.key.key,
|
||||
re_randomization_key,
|
||||
seed,
|
||||
streams,
|
||||
)?;
|
||||
|
||||
self.re_randomization_metadata_mut().clear();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("HPU does not support CPKReRandomize.")
|
||||
|
||||
@@ -17,9 +17,9 @@ pub(crate) enum ReRandomizationKeyGenerationInfo<'a> {
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Versionize)]
|
||||
#[versionize(ReRandomizationKeySwitchingKeyVersions)]
|
||||
pub enum ReRandomizationKeySwitchingKey {
|
||||
pub enum ReRandomizationKeySwitchingKey<KSK> {
|
||||
UseCPKEncryptionKSK,
|
||||
DedicatedKSK(crate::integer::key_switching_key::KeySwitchingKeyMaterial),
|
||||
DedicatedKSK(KSK),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Versionize)]
|
||||
@@ -30,7 +30,10 @@ pub enum CompressedReRandomizationKeySwitchingKey {
|
||||
}
|
||||
|
||||
impl CompressedReRandomizationKeySwitchingKey {
|
||||
pub fn decompress(&self) -> ReRandomizationKeySwitchingKey {
|
||||
pub fn decompress(
|
||||
&self,
|
||||
) -> ReRandomizationKeySwitchingKey<crate::integer::key_switching_key::KeySwitchingKeyMaterial>
|
||||
{
|
||||
match self {
|
||||
Self::UseCPKEncryptionKSK => ReRandomizationKeySwitchingKey::UseCPKEncryptionKSK,
|
||||
Self::DedicatedKSK(compressed_key_switching_key_material) => {
|
||||
|
||||
@@ -353,8 +353,9 @@ pub struct IntegerServerKey {
|
||||
pub(crate) decompression_key: Option<DecompressionKey>,
|
||||
pub(crate) noise_squashing_key: Option<NoiseSquashingKey>,
|
||||
pub(crate) noise_squashing_compression_key: Option<NoiseSquashingCompressionKey>,
|
||||
pub(crate) cpk_re_randomization_key_switching_key_material:
|
||||
Option<ReRandomizationKeySwitchingKey>,
|
||||
pub(crate) cpk_re_randomization_key_switching_key_material: Option<
|
||||
ReRandomizationKeySwitchingKey<crate::integer::key_switching_key::KeySwitchingKeyMaterial>,
|
||||
>,
|
||||
}
|
||||
|
||||
impl IntegerServerKey {
|
||||
@@ -489,6 +490,29 @@ pub struct IntegerCudaServerKey {
|
||||
pub(crate) noise_squashing_compression_key: Option<
|
||||
crate::integer::gpu::list_compression::server_keys::CudaNoiseSquashingCompressionKey,
|
||||
>,
|
||||
pub(crate) cpk_re_randomization_key_switching_key_material: Option<
|
||||
ReRandomizationKeySwitchingKey<
|
||||
crate::integer::gpu::key_switching_key::CudaKeySwitchingKeyMaterial,
|
||||
>,
|
||||
>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
impl IntegerCudaServerKey {
    /// Returns the GPU keyswitching key material to use for re-randomization
    /// with a compact public key, if one is available.
    ///
    /// - `None` when no re-randomization key material is stored.
    /// - For `UseCPKEncryptionKSK`, whatever `cpk_key_switching_key_material`
    ///   holds (which may itself be `None`).
    /// - For `DedicatedKSK`, the dedicated material.
    pub(in crate::high_level_api) fn re_randomization_cpk_casting_key(
        &self,
    ) -> Option<&crate::integer::gpu::key_switching_key::CudaKeySwitchingKeyMaterial> {
        // Bail out with `None` right away when re-randomization was not configured.
        let rerand_key = self
            .cpk_re_randomization_key_switching_key_material
            .as_ref()?;

        match rerand_key {
            ReRandomizationKeySwitchingKey::UseCPKEncryptionKSK => {
                self.cpk_key_switching_key_material.as_ref()
            }
            ReRandomizationKeySwitchingKey::DedicatedKSK(dedicated) => Some(dedicated),
        }
    }
}
|
||||
|
||||
#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
|
||||
|
||||
@@ -68,7 +68,11 @@ impl ServerKey {
|
||||
Option<DecompressionKey>,
|
||||
Option<NoiseSquashingKey>,
|
||||
Option<NoiseSquashingCompressionKey>,
|
||||
Option<ReRandomizationKeySwitchingKey>,
|
||||
Option<
|
||||
ReRandomizationKeySwitchingKey<
|
||||
crate::integer::key_switching_key::KeySwitchingKeyMaterial,
|
||||
>,
|
||||
>,
|
||||
Tag,
|
||||
) {
|
||||
let IntegerServerKey {
|
||||
@@ -103,7 +107,11 @@ impl ServerKey {
|
||||
decompression_key: Option<DecompressionKey>,
|
||||
noise_squashing_key: Option<NoiseSquashingKey>,
|
||||
noise_squashing_compression_key: Option<NoiseSquashingCompressionKey>,
|
||||
cpk_re_randomization_key_switching_key_material: Option<ReRandomizationKeySwitchingKey>,
|
||||
cpk_re_randomization_key_switching_key_material: Option<
|
||||
ReRandomizationKeySwitchingKey<
|
||||
crate::integer::key_switching_key::KeySwitchingKeyMaterial,
|
||||
>,
|
||||
>,
|
||||
tag: Tag,
|
||||
) -> Self {
|
||||
Self {
|
||||
@@ -349,9 +357,36 @@ impl CompressedServerKey {
|
||||
CudaKeySwitchingKeyMaterial {
|
||||
lwe_keyswitch_key: d_ksk,
|
||||
destination_key: ksk_material.material.destination_key,
|
||||
cast_rshift: ksk_material.material.cast_rshift,
|
||||
}
|
||||
});
|
||||
|
||||
let cpk_re_randomization_key_switching_key_material = self
|
||||
.integer_key
|
||||
.cpk_re_randomization_key_switching_key_material
|
||||
.as_ref()
|
||||
.map(
|
||||
|cpk_re_randomization_ksk_material| match cpk_re_randomization_ksk_material {
|
||||
CompressedReRandomizationKeySwitchingKey::UseCPKEncryptionKSK => {
|
||||
ReRandomizationKeySwitchingKey::UseCPKEncryptionKSK
|
||||
}
|
||||
CompressedReRandomizationKeySwitchingKey::DedicatedKSK(dedicated_ksk) => {
|
||||
let ksk_material = dedicated_ksk.decompress();
|
||||
let d_ksk = CudaLweKeyswitchKey::from_lwe_keyswitch_key(
|
||||
&ksk_material.material.key_switching_key,
|
||||
&streams,
|
||||
);
|
||||
let d_ksk_material = CudaKeySwitchingKeyMaterial {
|
||||
lwe_keyswitch_key: d_ksk,
|
||||
destination_key: ksk_material.material.destination_key,
|
||||
cast_rshift: ksk_material.material.cast_rshift,
|
||||
};
|
||||
|
||||
ReRandomizationKeySwitchingKey::DedicatedKSK(d_ksk_material)
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let compression_key: Option<
|
||||
crate::integer::gpu::list_compression::server_keys::CudaCompressionKey,
|
||||
> = self
|
||||
@@ -411,6 +446,7 @@ impl CompressedServerKey {
|
||||
decompression_key,
|
||||
noise_squashing_key,
|
||||
noise_squashing_compression_key,
|
||||
cpk_re_randomization_key_switching_key_material,
|
||||
}),
|
||||
tag: self.tag.clone(),
|
||||
streams,
|
||||
@@ -453,6 +489,11 @@ impl CudaServerKey {
|
||||
pub fn gpu_indexes(&self) -> &[GpuIndex] {
|
||||
&self.key.key.key_switching_key.d_vec.gpu_indexes
|
||||
}
|
||||
pub(in crate::high_level_api) fn re_randomization_cpk_casting_key(
|
||||
&self,
|
||||
) -> Option<&CudaKeySwitchingKeyMaterial> {
|
||||
self.key.re_randomization_cpk_casting_key()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
|
||||
@@ -1,47 +1,23 @@
|
||||
use crate::high_level_api::prelude::*;
|
||||
use crate::high_level_api::{
|
||||
generate_keys, CompactPublicKey, CompressedCiphertextListBuilder, ConfigBuilder, FheBool,
|
||||
FheInt8, FheUint64, ReRandomizationContext,
|
||||
};
|
||||
use crate::set_server_key;
|
||||
use crate::shortint::parameters::{
|
||||
COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
PARAM_KEYSWITCH_TO_BIG_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
CompactPublicKey, CompressedCiphertextListBuilder, ConfigBuilder, FheBool, FheInt8, FheUint64,
|
||||
ReRandomizationContext,
|
||||
};
|
||||
use crate::shortint::parameters::v1_5::meta::cpu::V1_5_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_BIG_ZKV2_TUNIFORM_2M128;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::shortint::parameters::v1_5::meta::gpu::V1_5_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_BIG_ZKV2_TUNIFORM_2M128;
|
||||
use crate::shortint::parameters::{MetaParameters, ShortintKeySwitchingParameters};
|
||||
use crate::{set_server_key, ClientKey, CompressedServerKey};
|
||||
|
||||
#[test]
|
||||
fn test_re_rand() {
|
||||
let params = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
let cpk_params = (
|
||||
PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
);
|
||||
let comp_params = COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
let re_rand_ks_params = PARAM_KEYSWITCH_TO_BIG_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
|
||||
let config = ConfigBuilder::with_custom_parameters(params)
|
||||
.use_dedicated_compact_public_key_parameters(cpk_params)
|
||||
.enable_compression(comp_params)
|
||||
.enable_ciphertext_re_randomization(re_rand_ks_params)
|
||||
.build();
|
||||
|
||||
let (cks, sks) = generate_keys(config);
|
||||
let cpk = CompactPublicKey::new(&cks);
|
||||
|
||||
fn execute_re_rand_test(cks: &ClientKey, cpk: &CompactPublicKey) {
|
||||
let compact_public_encryption_domain_separator = *b"TFHE_Enc";
|
||||
let rerand_domain_separator = *b"TFHE_Rrd";
|
||||
|
||||
set_server_key(sks);
|
||||
|
||||
// Case where we want to compute FheUint64 + FheUint64 and re-randomize those inputs
|
||||
{
|
||||
let clear_a = rand::random::<u64>();
|
||||
let clear_b = rand::random::<u64>();
|
||||
let mut a = FheUint64::encrypt(clear_a, &cks);
|
||||
let mut b = FheUint64::encrypt(clear_b, &cks);
|
||||
let mut a = FheUint64::encrypt(clear_a, cks);
|
||||
let mut b = FheUint64::encrypt(clear_b, cks);
|
||||
|
||||
// Simulate a 256 bits hash added as metadata
|
||||
let rand_a: [u8; 256 / 8] = core::array::from_fn(|_| rand::random());
|
||||
@@ -77,14 +53,14 @@ fn test_re_rand() {
|
||||
|
||||
let mut seed_gen = re_rand_context.finalize();
|
||||
|
||||
a.re_randomize(&cpk, seed_gen.next_seed().unwrap()).unwrap();
|
||||
a.re_randomize(cpk, seed_gen.next_seed().unwrap()).unwrap();
|
||||
assert!(a.re_randomization_metadata().data().is_empty());
|
||||
|
||||
b.re_randomize(&cpk, seed_gen.next_seed().unwrap()).unwrap();
|
||||
b.re_randomize(cpk, seed_gen.next_seed().unwrap()).unwrap();
|
||||
assert!(b.re_randomization_metadata().data().is_empty());
|
||||
|
||||
let c = a + b;
|
||||
let dec: u64 = c.decrypt(&cks);
|
||||
let dec: u64 = c.decrypt(cks);
|
||||
|
||||
assert_eq!(clear_a.wrapping_add(clear_b), dec);
|
||||
}
|
||||
@@ -93,8 +69,8 @@ fn test_re_rand() {
|
||||
{
|
||||
let clear_a = rand::random::<i8>();
|
||||
let clear_b = rand::random::<i8>();
|
||||
let mut a = FheInt8::encrypt(clear_a, &cks);
|
||||
let mut b = FheInt8::encrypt(clear_b, &cks);
|
||||
let mut a = FheInt8::encrypt(clear_a, cks);
|
||||
let mut b = FheInt8::encrypt(clear_b, cks);
|
||||
|
||||
// Simulate a 256 bits hash added as metadata
|
||||
let rand_a: [u8; 256 / 8] = core::array::from_fn(|_| rand::random());
|
||||
@@ -131,14 +107,14 @@ fn test_re_rand() {
|
||||
|
||||
let mut seed_gen = re_rand_context.finalize();
|
||||
|
||||
a.re_randomize(&cpk, seed_gen.next_seed().unwrap()).unwrap();
|
||||
a.re_randomize(cpk, seed_gen.next_seed().unwrap()).unwrap();
|
||||
assert!(a.re_randomization_metadata().data().is_empty());
|
||||
|
||||
b.re_randomize(&cpk, seed_gen.next_seed().unwrap()).unwrap();
|
||||
b.re_randomize(cpk, seed_gen.next_seed().unwrap()).unwrap();
|
||||
assert!(b.re_randomization_metadata().data().is_empty());
|
||||
|
||||
let c = a + b;
|
||||
let dec: i8 = c.decrypt(&cks);
|
||||
let dec: i8 = c.decrypt(cks);
|
||||
|
||||
assert_eq!(clear_a.wrapping_add(clear_b), dec);
|
||||
}
|
||||
@@ -147,8 +123,8 @@ fn test_re_rand() {
|
||||
{
|
||||
for clear_a in [false, true] {
|
||||
for clear_b in [false, true] {
|
||||
let mut a = FheBool::encrypt(clear_a, &cks);
|
||||
let mut b = FheBool::encrypt(clear_b, &cks);
|
||||
let mut a = FheBool::encrypt(clear_a, cks);
|
||||
let mut b = FheBool::encrypt(clear_b, cks);
|
||||
|
||||
// Simulate a 256 bits hash added as metadata
|
||||
let rand_a: [u8; 256 / 8] = core::array::from_fn(|_| rand::random());
|
||||
@@ -184,17 +160,72 @@ fn test_re_rand() {
|
||||
|
||||
let mut seed_gen = re_rand_context.finalize();
|
||||
|
||||
a.re_randomize(&cpk, seed_gen.next_seed().unwrap()).unwrap();
|
||||
a.re_randomize(cpk, seed_gen.next_seed().unwrap()).unwrap();
|
||||
assert!(a.re_randomization_metadata().data().is_empty());
|
||||
|
||||
b.re_randomize(&cpk, seed_gen.next_seed().unwrap()).unwrap();
|
||||
b.re_randomize(cpk, seed_gen.next_seed().unwrap()).unwrap();
|
||||
assert!(b.re_randomization_metadata().data().is_empty());
|
||||
|
||||
let c = a & b;
|
||||
let dec: bool = c.decrypt(&cks);
|
||||
let dec: bool = c.decrypt(cks);
|
||||
|
||||
assert_eq!(clear_a && clear_b, dec);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn setup_re_rand_test(
|
||||
params: MetaParameters,
|
||||
) -> (crate::ClientKey, CompressedServerKey, CompactPublicKey) {
|
||||
let cpk_params = (
|
||||
params
|
||||
.dedicated_compact_public_key_parameters
|
||||
.unwrap()
|
||||
.pke_params,
|
||||
params
|
||||
.dedicated_compact_public_key_parameters
|
||||
.unwrap()
|
||||
.ksk_params,
|
||||
);
|
||||
let comp_params = params.compression_parameters.unwrap();
|
||||
let compute_params = params.compute_parameters;
|
||||
let ksk_params = ShortintKeySwitchingParameters::new(
|
||||
compute_params.ks_base_log(),
|
||||
compute_params.ks_level(),
|
||||
compute_params.encryption_key_choice(),
|
||||
);
|
||||
|
||||
let config = ConfigBuilder::with_custom_parameters(compute_params)
|
||||
.use_dedicated_compact_public_key_parameters(cpk_params)
|
||||
.enable_compression(comp_params)
|
||||
.enable_ciphertext_re_randomization(ksk_params)
|
||||
.build();
|
||||
|
||||
let cks = crate::ClientKey::generate(config);
|
||||
let sks = cks.generate_compressed_server_key();
|
||||
let cpk = CompactPublicKey::new(&cks);
|
||||
|
||||
(cks, sks, cpk)
|
||||
}
|
||||
|
||||
/// Runs the shared re-randomization scenario with the decompressed (CPU) server key.
#[test]
fn test_re_rand() {
    let (cks, sks, cpk) =
        setup_re_rand_test(V1_5_META_PARAM_CPU_2_2_KS_PBS_PKE_TO_BIG_ZKV2_TUNIFORM_2M128);

    let server_key = sks.decompress();
    set_server_key(server_key);

    execute_re_rand_test(&cks, &cpk);
}
|
||||
|
||||
/// Runs the shared re-randomization scenario with the server key decompressed to the GPU.
#[cfg(feature = "gpu")]
#[test]
fn test_gpu_re_rand() {
    let (cks, sks, cpk) = setup_re_rand_test(
        V1_5_META_PARAM_GPU_2_2_MULTI_BIT_GROUP_4_KS_PBS_PKE_TO_BIG_ZKV2_TUNIFORM_2M128,
    );

    let gpu_server_key = sks.decompress_to_gpu();
    set_server_key(gpu_server_key);

    execute_re_rand_test(&cks, &cpk);
}
|
||||
|
||||
@@ -115,7 +115,7 @@ fn test_ciphertext_re_randomization_after_compression() {
|
||||
|
||||
assert_ne!(decompressed, re_randomized);
|
||||
|
||||
let decrypted: i128 = cks.decrypt_signed_radix(&decompressed);
|
||||
let decrypted: i128 = cks.decrypt_signed_radix(&re_randomized);
|
||||
assert_eq!(decrypted, message);
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ pub mod compact_list;
|
||||
pub mod compressed_ciphertext_list;
|
||||
pub mod compressed_noise_squashed_ciphertext_list;
|
||||
pub mod info;
|
||||
pub mod re_randomization;
|
||||
pub mod squashed_noise;
|
||||
|
||||
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
|
||||
|
||||
202
tfhe/src/integer/gpu/ciphertext/re_randomization.rs
Normal file
202
tfhe/src/integer/gpu/ciphertext/re_randomization.rs
Normal file
@@ -0,0 +1,202 @@
|
||||
use crate::core_crypto::commons::generators::NoiseRandomGenerator;
|
||||
use crate::core_crypto::gpu::lwe_compact_ciphertext_list::CudaLweCompactCiphertextList;
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::{
|
||||
encrypt_lwe_compact_ciphertext_list_with_compact_public_key, LweCompactCiphertextList,
|
||||
PlaintextCount, PlaintextList,
|
||||
};
|
||||
use crate::integer::ciphertext::ReRandomizationSeed;
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::{
|
||||
CudaRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
|
||||
};
|
||||
use crate::integer::gpu::cuda_backend_rerand_assign;
|
||||
use crate::integer::gpu::key_switching_key::CudaKeySwitchingKeyMaterial;
|
||||
use crate::integer::CompactPublicKey;
|
||||
use crate::shortint::ciphertext::NoiseLevel;
|
||||
use crate::shortint::PBSOrder;
|
||||
use tfhe_csprng::generators::DefaultRandomGenerator;
|
||||
|
||||
impl CudaUnsignedRadixCiphertext {
|
||||
pub fn re_randomize(
|
||||
&mut self,
|
||||
compact_public_key: &CompactPublicKey,
|
||||
key_switching_key_material: &CudaKeySwitchingKeyMaterial,
|
||||
seed: ReRandomizationSeed,
|
||||
streams: &CudaStreams,
|
||||
) -> crate::Result<()> {
|
||||
self.ciphertext.re_randomize(
|
||||
compact_public_key,
|
||||
key_switching_key_material,
|
||||
seed,
|
||||
streams,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaSignedRadixCiphertext {
|
||||
pub fn re_randomize(
|
||||
&mut self,
|
||||
compact_public_key: &CompactPublicKey,
|
||||
key_switching_key_material: &CudaKeySwitchingKeyMaterial,
|
||||
seed: ReRandomizationSeed,
|
||||
streams: &CudaStreams,
|
||||
) -> crate::Result<()> {
|
||||
self.ciphertext.re_randomize(
|
||||
compact_public_key,
|
||||
key_switching_key_material,
|
||||
seed,
|
||||
streams,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaBooleanBlock {
|
||||
pub fn re_randomize(
|
||||
&mut self,
|
||||
compact_public_key: &CompactPublicKey,
|
||||
key_switching_key_material: &CudaKeySwitchingKeyMaterial,
|
||||
seed: ReRandomizationSeed,
|
||||
streams: &CudaStreams,
|
||||
) -> crate::Result<()> {
|
||||
self.0.re_randomize(
|
||||
compact_public_key,
|
||||
key_switching_key_material,
|
||||
seed,
|
||||
streams,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaRadixCiphertext {
|
||||
pub fn re_randomize(
|
||||
&mut self,
|
||||
compact_public_key: &CompactPublicKey,
|
||||
key_switching_key_material: &CudaKeySwitchingKeyMaterial,
|
||||
seed: ReRandomizationSeed,
|
||||
streams: &CudaStreams,
|
||||
) -> crate::Result<()> {
|
||||
let ksk_pbs_order = key_switching_key_material.destination_key.into_pbs_order();
|
||||
let ksk_output_lwe_size = key_switching_key_material
|
||||
.lwe_keyswitch_key
|
||||
.output_key_lwe_size();
|
||||
|
||||
if let Some(msg) = self.info.blocks.iter().find_map(|ct| {
|
||||
if ct.atomic_pattern.pbs_order() != ksk_pbs_order {
|
||||
println!("{:?} {:?}", ct.atomic_pattern.pbs_order(), ksk_pbs_order);
|
||||
Some(
|
||||
"Mismatched PBSOrder between Ciphertext being re-randomized and provided \
|
||||
KeySwitchingKeyMaterialView.",
|
||||
)
|
||||
} else if ksk_output_lwe_size != self.d_blocks.0.lwe_dimension.to_lwe_size() {
|
||||
Some(
|
||||
"Mismatched LweSwize between Ciphertext being re-randomized and provided \
|
||||
KeySwitchingKeyMaterialView.",
|
||||
)
|
||||
} else if ct.noise_level > NoiseLevel::NOMINAL {
|
||||
Some("Tried to re-randomize a Ciphertext with non-nominal NoiseLevel.")
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}) {
|
||||
return Err(crate::error!("{}", msg));
|
||||
}
|
||||
|
||||
if ksk_pbs_order != PBSOrder::KeyswitchBootstrap {
|
||||
// message is ok since we know that ksk order == cts order
|
||||
return Err(crate::error!(
|
||||
"Tried to re-randomize a Ciphertext with unsupported PBSOrder. \
|
||||
Required PBSOrder::KeyswitchBootstrap.",
|
||||
));
|
||||
}
|
||||
|
||||
if key_switching_key_material.cast_rshift != 0 {
|
||||
return Err(crate::error!(
|
||||
"Tried to re-randomize a Ciphertext using KeySwitchingKeyMaterialView \
|
||||
with non-zero cast_rshift, this is unsupported.",
|
||||
));
|
||||
}
|
||||
|
||||
if key_switching_key_material
|
||||
.lwe_keyswitch_key
|
||||
.input_key_lwe_size()
|
||||
!= self.d_blocks.lwe_dimension().to_lwe_size()
|
||||
{
|
||||
return Err(crate::error!(
|
||||
"Mismatched LweDimension between provided CompactPublicKey and \
|
||||
KeySwitchingKeyMaterialView input LweDimension.",
|
||||
));
|
||||
}
|
||||
|
||||
let mut encryption_generator =
|
||||
NoiseRandomGenerator::<DefaultRandomGenerator>::new_from_seed(seed.0);
|
||||
|
||||
let lwe_ciphertext_count = self.d_blocks.lwe_ciphertext_count();
|
||||
let mut encryption_of_zero = LweCompactCiphertextList::new(
|
||||
0,
|
||||
self.d_blocks.lwe_dimension().to_lwe_size(),
|
||||
lwe_ciphertext_count,
|
||||
self.d_blocks.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let plaintext_list = PlaintextList::new(
|
||||
0,
|
||||
PlaintextCount(encryption_of_zero.lwe_ciphertext_count().0),
|
||||
);
|
||||
|
||||
let cpk_encryption_noise_distribution = compact_public_key
|
||||
.key
|
||||
.parameters()
|
||||
.encryption_noise_distribution;
|
||||
|
||||
encrypt_lwe_compact_ciphertext_list_with_compact_public_key(
|
||||
&compact_public_key.key.key,
|
||||
&mut encryption_of_zero,
|
||||
&plaintext_list,
|
||||
cpk_encryption_noise_distribution,
|
||||
cpk_encryption_noise_distribution,
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
let d_zero_lwes = CudaLweCompactCiphertextList::from_lwe_compact_ciphertext_list(
|
||||
&encryption_of_zero,
|
||||
streams,
|
||||
);
|
||||
|
||||
let first_info = self.info.blocks.first().unwrap();
|
||||
let message_modulus = first_info.message_modulus;
|
||||
let carry_modulus = first_info.carry_modulus;
|
||||
|
||||
unsafe {
|
||||
cuda_backend_rerand_assign(
|
||||
streams,
|
||||
&mut self.d_blocks,
|
||||
&d_zero_lwes,
|
||||
&key_switching_key_material.lwe_keyswitch_key,
|
||||
message_modulus,
|
||||
carry_modulus,
|
||||
key_switching_key_material
|
||||
.lwe_keyswitch_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
key_switching_key_material
|
||||
.lwe_keyswitch_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
key_switching_key_material
|
||||
.lwe_keyswitch_key
|
||||
.decomposition_level_count(),
|
||||
key_switching_key_material
|
||||
.lwe_keyswitch_key
|
||||
.decomposition_base_log(),
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
);
|
||||
}
|
||||
|
||||
self.info.blocks.iter_mut().for_each(|ct| {
|
||||
ct.noise_level = NoiseLevel::NOMINAL;
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,8 @@
|
||||
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::integer::gpu::CudaServerKey;
|
||||
use crate::integer::key_switching_key::KeySwitchingKey;
|
||||
use crate::shortint::EncryptionKeyChoice;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CudaKeySwitchingKeyMaterial {
|
||||
pub(crate) lwe_keyswitch_key: CudaLweKeyswitchKey<u64>,
|
||||
pub(crate) destination_key: EncryptionKeyChoice,
|
||||
}
|
||||
use crate::integer::key_switching_key::{KeySwitchingKey, KeySwitchingKeyMaterialView};
|
||||
pub use crate::shortint::key_switching_key::CudaKeySwitchingKeyMaterial;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct CudaKeySwitchingKey<'key> {
|
||||
@@ -27,9 +20,30 @@ impl CudaKeySwitchingKeyMaterial {
|
||||
&key_switching_key_material.key_switching_key,
|
||||
streams,
|
||||
);
|
||||
let cast_rshift = key_switching_key_material.cast_rshift;
|
||||
|
||||
Self {
|
||||
lwe_keyswitch_key: d_lwe_keyswich_key,
|
||||
destination_key: key_switching_key_material.destination_key,
|
||||
cast_rshift,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_key_switching_key_material(
|
||||
key_switching_key_material: &KeySwitchingKeyMaterialView,
|
||||
streams: &CudaStreams,
|
||||
) -> Self {
|
||||
let lwe_keyswitch_key = CudaLweKeyswitchKey::from_lwe_keyswitch_key(
|
||||
key_switching_key_material.material.key_switching_key,
|
||||
streams,
|
||||
);
|
||||
let destination_key = key_switching_key_material.material.destination_key;
|
||||
let cast_rshift = key_switching_key_material.material.cast_rshift;
|
||||
|
||||
Self {
|
||||
lwe_keyswitch_key,
|
||||
cast_rshift,
|
||||
destination_key,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ pub mod zk;
|
||||
|
||||
use crate::core_crypto::gpu::lwe_bootstrap_key::CudaModulusSwitchNoiseReductionConfiguration;
|
||||
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
|
||||
use crate::core_crypto::gpu::lwe_compact_ciphertext_list::CudaLweCompactCiphertextList;
|
||||
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
|
||||
use crate::core_crypto::gpu::slice::{CudaSlice, CudaSliceMut};
|
||||
use crate::core_crypto::gpu::vec::CudaVec;
|
||||
use crate::core_crypto::gpu::{CudaStreams, PBSMSNoiseReductionType};
|
||||
@@ -5169,6 +5171,69 @@ pub(crate) unsafe fn cuda_backend_unchecked_cmux<T: UnsignedInteger, B: Numeric>
|
||||
update_noise_degree(radix_lwe_out, &cuda_ffi_radix_lwe_out);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_rerand_assign<T: UnsignedInteger>(
|
||||
streams: &CudaStreams,
|
||||
lwe_array: &mut CudaLweCiphertextList<T>,
|
||||
zero_lwes: &CudaLweCompactCiphertextList<T>,
|
||||
keyswitch_key: &CudaLweKeyswitchKey<T>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
num_blocks: u32,
|
||||
) {
|
||||
assert_eq!(
|
||||
streams.gpu_indexes[0],
|
||||
lwe_array.0.d_vec.gpu_index(0),
|
||||
"GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
|
||||
streams.gpu_indexes[0].get(),
|
||||
lwe_array.0.d_vec.gpu_index(0).get(),
|
||||
);
|
||||
assert_eq!(
|
||||
streams.gpu_indexes[0],
|
||||
zero_lwes.0.d_vec.gpu_index(0),
|
||||
"GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
|
||||
streams.gpu_indexes[0].get(),
|
||||
zero_lwes.0.d_vec.gpu_index(0).get(),
|
||||
);
|
||||
assert_eq!(
|
||||
streams.gpu_indexes[0],
|
||||
keyswitch_key.d_vec.gpu_index(0),
|
||||
"GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
|
||||
streams.gpu_indexes[0].get(),
|
||||
keyswitch_key.d_vec.gpu_index(0).get(),
|
||||
);
|
||||
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
scratch_cuda_rerand_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
true,
|
||||
);
|
||||
cuda_rerand_64(
|
||||
streams.ffi(),
|
||||
lwe_array.0.d_vec.as_mut_c_ptr(0),
|
||||
zero_lwes.0.d_vec.as_c_ptr(0),
|
||||
mem_ptr,
|
||||
keyswitch_key.d_vec.ptr.as_ptr(),
|
||||
);
|
||||
cleanup_cuda_rerand(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
}
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn cuda_backend_get_cmux_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
|
||||
@@ -7,7 +7,19 @@
|
||||
//! - Then a pbs is done to update the encoding, if the parameters do not have the same precision.
|
||||
//! This allows to apply a user provided function at the same time.
|
||||
|
||||
use super::atomic_pattern::AtomicPatternServerKey;
|
||||
use super::backward_compatibility::key_switching_key::{
|
||||
CompressedKeySwitchingKeyMaterialVersions, CompressedKeySwitchingKeyVersions,
|
||||
KeySwitchingKeyDestinationAtomicPatternVersions, KeySwitchingKeyMaterialVersions,
|
||||
KeySwitchingKeyVersions,
|
||||
};
|
||||
use super::server_key::{
|
||||
KS32ServerKeyView, ServerKeyView, ShortintBootstrappingKey, StandardServerKeyView,
|
||||
};
|
||||
use super::AtomicPatternKind;
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
|
||||
use crate::core_crypto::prelude::{
|
||||
keyswitch_lwe_ciphertext, CastFrom, CastInto, Cleartext, LweCiphertext, LweCiphertextOwned,
|
||||
LweKeyswitchKeyConformanceParams, LweKeyswitchKeyOwned, SeededLweKeyswitchKeyOwned,
|
||||
@@ -28,17 +40,6 @@ use rayon::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use super::atomic_pattern::AtomicPatternServerKey;
|
||||
use super::backward_compatibility::key_switching_key::{
|
||||
CompressedKeySwitchingKeyMaterialVersions, CompressedKeySwitchingKeyVersions,
|
||||
KeySwitchingKeyDestinationAtomicPatternVersions, KeySwitchingKeyMaterialVersions,
|
||||
KeySwitchingKeyVersions,
|
||||
};
|
||||
use super::server_key::{
|
||||
KS32ServerKeyView, ServerKeyView, ShortintBootstrappingKey, StandardServerKeyView,
|
||||
};
|
||||
use super::AtomicPatternKind;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
@@ -255,6 +256,15 @@ pub struct KeySwitchingKeyMaterialView<'key> {
|
||||
pub(crate) destination_atomic_pattern: KeySwitchingKeyDestinationAtomicPattern,
|
||||
}
|
||||
|
||||
/// Key switching key material whose LWE keyswitching key is held as a
/// `CudaLweKeyswitchKey`. Only compiled when the `gpu` feature is enabled.
#[cfg(feature = "gpu")]
#[allow(dead_code)]
#[derive(Clone)]
pub struct CudaKeySwitchingKeyMaterial {
    /// The LWE keyswitching key (64-bit words).
    pub(crate) lwe_keyswitch_key: CudaLweKeyswitchKey<u64>,
    /// Shift associated with the key switch.
    // NOTE(review): presumably the right shift applied when casting between
    // parameter sets, as in the CPU material — confirm.
    pub(crate) cast_rshift: i8,
    /// Which encryption key the key switch targets.
    pub(crate) destination_key: EncryptionKeyChoice,
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
pub struct KeySwitchingKeyView<'keys> {
|
||||
pub(crate) key_switching_key_material: KeySwitchingKeyMaterialView<'keys>,
|
||||
|
||||
Reference in New Issue
Block a user